Skip to content

Commit

Permalink
Add support for formats without infinities
Browse files Browse the repository at this point in the history
  • Loading branch information
mfasi committed Jun 2, 2024
1 parent 99665da commit 9ae6e80
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 158 deletions.
42 changes: 31 additions & 11 deletions mex/cpfloat.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ void mexFunction(int nlhs,
fpopts->emin = -14;
fpopts->emax = 15;
fpopts->explim = CPFLOAT_EXPRANGE_TARG;
fpopts->infinity = CPFLOAT_INF_USE;
fpopts->round = CPFLOAT_RND_NE;
fpopts->saturation = CPFLOAT_SAT_NO;
fpopts->subnormal = CPFLOAT_SUBN_USE;
Expand All @@ -57,6 +58,7 @@ void mexFunction(int nlhs,
/* Parse second argument and populate fpopts structure. */
if (nrhs > 1) {
bool is_subn_rnd_default = false;
bool is_inf_no_default = false;
if(!mxIsEmpty(prhs[1]) && !mxIsStruct(prhs[1])) {
mexErrMsgIdAndTxt("cpfloat:invalidstruct",
"Second argument must be a struct.");
Expand All @@ -83,6 +85,7 @@ void mexFunction(int nlhs,
fpopts->precision = 4;
fpopts->emin = -6;
fpopts->emax = 8;
is_inf_no_default = true;
} else if (!strcmp(fpopts->format, "q52") ||
!strcmp(fpopts->format, "fp8-e5m2") ||
!strcmp(fpopts->format, "E5M2")) {
Expand Down Expand Up @@ -137,6 +140,7 @@ void mexFunction(int nlhs,
mexErrMsgIdAndTxt("cpfloat:invalidformat",
"Invalid floating-point format specified.");
}

/* Set default values to be compatible with MATLAB chop. */
tmp = mxGetField(prhs[1], 0, "subnormal");
if (tmp != NULL) {
Expand All @@ -147,32 +151,43 @@ void mexFunction(int nlhs,
} else {
if (is_subn_rnd_default)
fpopts->subnormal = CPFLOAT_SUBN_RND; /* Default for bfloat16. */
else
fpopts->subnormal = CPFLOAT_SUBN_USE;
}

tmp = mxGetField(prhs[1], 0, "explim");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->explim = 1;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->explim = *((double *)mxGetData(tmp));
}

tmp = mxGetField(prhs[1], 0, "infinity");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->infinity = CPFLOAT_INF_USE;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->infinity = *((double *)mxGetData(tmp));
} else {
if (is_inf_no_default)
fpopts->infinity = CPFLOAT_INF_NO; /* Default for E4M5. */
}

tmp = mxGetField(prhs[1], 0, "round");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->round = CPFLOAT_RND_NE;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->round = *((double *)mxGetData(tmp));
}

tmp = mxGetField(prhs[1], 0, "saturation");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->saturation = CPFLOAT_SAT_NO;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->saturation = *((double *)mxGetData(tmp));
} else {
fpopts->saturation = CPFLOAT_SAT_NO;
}

tmp = mxGetField(prhs[1], 0, "subnormal");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
Expand Down Expand Up @@ -313,11 +328,11 @@ void mexFunction(int nlhs,

/* Allocate and return second output. */
if (nlhs > 1) {
const char* field_names[] = {"format", "params", "explim",
const char* field_names[] = {"format", "params", "explim", "infinity",
"round", "saturation", "subnormal",
"flip", "p"};
mwSize dims[2] = {1, 1};
plhs[1] = mxCreateStructArray(2, dims, 8, field_names);
plhs[1] = mxCreateStructArray(2, dims, 9, field_names);
mxSetFieldByNumber(plhs[1], 0, 0, mxCreateString(fpopts->format));

mxArray *outparams = mxCreateDoubleMatrix(1,3,mxREAL);
Expand All @@ -332,30 +347,35 @@ void mexFunction(int nlhs,
outexplimptr[0] = fpopts->explim;
mxSetFieldByNumber(plhs[1], 0, 2, outexplim);

mxArray *outinfinity = mxCreateDoubleMatrix(1, 1, mxREAL);
double *outinfinityptr = mxGetData(outinfinity);
outinfinityptr[0] = fpopts->infinity;
mxSetFieldByNumber(plhs[1], 0, 3, outinfinity);

mxArray *outround = mxCreateDoubleMatrix(1,1,mxREAL);
double *outroundptr = mxGetData(outround);
outroundptr[0] = fpopts->round;
mxSetFieldByNumber(plhs[1], 0, 3, outround);
mxSetFieldByNumber(plhs[1], 0, 4, outround);

mxArray *outsaturation = mxCreateDoubleMatrix(1,1,mxREAL);
double *outsaturationptr = mxGetData(outsaturation);
outsaturationptr[0] = fpopts->saturation;
mxSetFieldByNumber(plhs[1], 0, 4, outsaturation);
mxSetFieldByNumber(plhs[1], 0, 5, outsaturation);

mxArray *outsubnormal = mxCreateDoubleMatrix(1,1,mxREAL);
double *outsubnormalptr = mxGetData(outsubnormal);
outsubnormalptr[0] = fpopts->subnormal;
mxSetFieldByNumber(plhs[1], 0, 5, outsubnormal);
mxSetFieldByNumber(plhs[1], 0, 6, outsubnormal);

mxArray *outflip = mxCreateDoubleMatrix(1,1,mxREAL);
double *outflipptr = mxGetData(outflip);
outflipptr[0] = fpopts->flip;
mxSetFieldByNumber(plhs[1], 0, 6, outflip);
mxSetFieldByNumber(plhs[1], 0, 7, outflip);

mxArray *outp = mxCreateDoubleMatrix(1,1,mxREAL);
double *outpptr = mxGetData(outp);
outpptr[0] = fpopts->p;
mxSetFieldByNumber(plhs[1], 0, 7, outp);
mxSetFieldByNumber(plhs[1], 0, 8, outp);

}
if (nlhs > 2)
Expand Down
5 changes: 5 additions & 0 deletions mex/cpfloat.m
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@
% this field is set to 0, and the exponent range of the format specified in
% FPOPTS.format otherwise. The default value for this field is 1.
%
% * The scalar FPOPTS.infinity specifies whether infinities are supported. The
% target floating-point format will support infinities if this field is set
% to 1, and they will be replaced by NaNs otherwise. The default value for
% this field is 0 if the target format is 'E4M3' and 1 otherwise.
%
% * The scalar FPOPTS.round specifies the rounding mode. Possible values are:
% -1 for round-to-nearest with ties-to-away;
% 0 for round-to-nearest with ties-to-zero;
Expand Down
19 changes: 19 additions & 0 deletions src/cpfloat_definitions.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
* defines the enumerated types
*
* + @ref cpfloat_explim_t,
* + @ref cpfloat_infinity_t,
* + @ref cpfloat_rounding_t,
* + @ref cpfloat_saturation_t,
* + @ref cpfloat_softerr_t,
Expand Down Expand Up @@ -63,6 +64,16 @@ typedef enum {
CPFLOAT_EXPRANGE_TARG = 1
} cpfloat_explim_t;

/**
* @brief Infinity support modes available in CPFloat.
*/
typedef enum {
/** Use infinities in target format. */
CPFLOAT_INF_NO = 0,
/** Replace infinities with NaNs in target format. */
CPFLOAT_INF_USE = 1,
} cpfloat_infinity_t;

/**
* @brief Rounding modes available in CPFloat.
*/
Expand Down Expand Up @@ -234,6 +245,14 @@ typedef struct {
* `CPFLOAT_EXPRANGE_STOR`.
*/
cpfloat_explim_t explim;
/**
* @brief Support for infinities in target format.
*
* @details If this field is set to `CPFLOAT_INF_USE`, the target format
* supports signed infinities. If the field is set to `CPFLOAT_INF_NO`,
* infinities are replaced with a quiet NaN.
*/
cpfloat_infinity_t infinity;
/**
* @brief Rounding mode to be used for the conversion.
*
Expand Down
15 changes: 12 additions & 3 deletions src/cpfloat_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ typedef struct {
cpfloat_precision_t precision;
cpfloat_exponent_t emin;
cpfloat_exponent_t emax;
cpfloat_infinity_t infinity;
cpfloat_rounding_t round;
cpfloat_saturation_t saturation;
cpfloat_subnormal_t subnormal;
Expand Down Expand Up @@ -322,13 +323,19 @@ static inline FPPARAMS COMPUTE_GLOBAL_PARAMS(const optstruct *fpopts,

FPTYPE xmax = ldexp(1., emax) * (2-ldexp(1., 1-precision));
FPTYPE xbnd = ldexp(1., emax) * (2-ldexp(1., -precision));
FPTYPE ofvalue = (fpopts->saturation == CPFLOAT_SAT_USE) ? xmax : INFINITY;
/*
* Here, fpopts->saturation takes precedence over fpopts->infinity. Therefore,
* when saturation arithmetic is used, infinities are not produced even when
* the target format supports them.
*/
FPTYPE ofvalue = (fpopts->saturation == CPFLOAT_SAT_USE) ? xmax :
(fpopts->infinity == CPFLOAT_INF_USE ? INFINITY : NAN);

/* Bitmasks. */
INTTYPE leadmask = FULLMASK << (DEFPREC-precision); /* To keep. */
INTTYPE trailmask = leadmask ^ FULLMASK; /* To discard. */

FPPARAMS params = {precision, emin, emax, fpopts->round,
FPPARAMS params = {precision, emin, emax, fpopts->infinity, fpopts->round,
fpopts->saturation, fpopts->subnormal,
ftzthreshold, ofvalue, xmin, xmax, xbnd,
leadmask, trailmask, NULL, NULL};
Expand Down Expand Up @@ -656,7 +663,9 @@ static inline void UPDATE_LOCAL_PARAMS(const FPTYPE *A,
numelem, p, lp) \
PARALLEL_STRING(PARALLEL) \
{ \
if (p->emax == DEFEMAX && p->saturation == CPFLOAT_SAT_NO) { \
if (p->emax == DEFEMAX \
&& p->saturation == CPFLOAT_SAT_NO \
&& p->infinity == CPFLOAT_INF_USE) { \
FOR_STRING(PARALLEL) \
for (size_t i=0; i<numelem; i++) { \
DEPARENTHESIZE_MAYBE(PREPROC) \
Expand Down
16 changes: 16 additions & 0 deletions test/cpfloat_test.m
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,18 @@
assert_eq(fp.subnormal,1)
assert_eq(fp.params, [53 -1022 1023])

clear fp
fp.format = 'E4M3';
[~,options] = cpfloat(pi,fp);
assert_eq(options.format,'E4M3')
assert_eq(options.infinity,0)
assert_eq(options.params, [4 -6 8])
[~,fp] = cpfloat;
assert_eq(fp.format,'E4M3')
assert_eq(fp.infinity,0)
assert_eq(fp.params, [4 -6 8])


clear fp
fp.format = 'bfloat16';
[~,options] = cpfloat(pi,fp);
Expand Down Expand Up @@ -323,6 +335,9 @@
end

% Infinities tests.
[~,fpopts] = cpfloat;
prev_infinity = fpopts.infinity;
options.infinity = 1;
options.saturation = 0;
for j = 1:6
options.round = j;
Expand Down Expand Up @@ -400,6 +415,7 @@
c = cpfloat(x,options);
c_expected = [0 0 x(3:5) inf 1 1];
assert_eq(c,c_expected)
options.infinity = prev_infinity;

% Smallest normal number and spacing between the subnormal numbers.
y = xmin; delta = xmin*2^(1-p);
Expand Down
Loading

0 comments on commit 9ae6e80

Please sign in to comment.