diff --git a/mex/cpfloat.c b/mex/cpfloat.c index 1e5dd62..15a717c 100644 --- a/mex/cpfloat.c +++ b/mex/cpfloat.c @@ -42,6 +42,7 @@ void mexFunction(int nlhs, fpopts->emin = -14; fpopts->emax = 15; fpopts->explim = CPFLOAT_EXPRANGE_TARG; + fpopts->infinity = CPFLOAT_INF_USE; fpopts->round = CPFLOAT_RND_NE; fpopts->saturation = CPFLOAT_SAT_NO; fpopts->subnormal = CPFLOAT_SUBN_USE; @@ -57,6 +58,7 @@ void mexFunction(int nlhs, /* Parse second argument and populate fpopts structure. */ if (nrhs > 1) { bool is_subn_rnd_default = false; + bool is_inf_no_default = false; if(!mxIsEmpty(prhs[1]) && !mxIsStruct(prhs[1])) { mexErrMsgIdAndTxt("cpfloat:invalidstruct", "Second argument must be a struct."); @@ -83,6 +85,7 @@ void mexFunction(int nlhs, fpopts->precision = 4; fpopts->emin = -6; fpopts->emax = 8; + is_inf_no_default = true; } else if (!strcmp(fpopts->format, "q52") || !strcmp(fpopts->format, "fp8-e5m2") || !strcmp(fpopts->format, "E5M2")) { @@ -137,6 +140,7 @@ void mexFunction(int nlhs, mexErrMsgIdAndTxt("cpfloat:invalidformat", "Invalid floating-point format specified."); } + /* Set default values to be compatible with MATLAB chop. */ tmp = mxGetField(prhs[1], 0, "subnormal"); if (tmp != NULL) { @@ -147,9 +151,8 @@ void mexFunction(int nlhs, } else { if (is_subn_rnd_default) fpopts->subnormal = CPFLOAT_SUBN_RND; /* Default for bfloat16. */ - else - fpopts->subnormal = CPFLOAT_SUBN_USE; } + tmp = mxGetField(prhs[1], 0, "explim"); if (tmp != NULL) { if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0) @@ -157,6 +160,18 @@ void mexFunction(int nlhs, else if (mxGetClassID(tmp) == mxDOUBLE_CLASS) fpopts->explim = *((double *)mxGetData(tmp)); } + + tmp = mxGetField(prhs[1], 0, "infinity"); + if (tmp != NULL) { + if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0) + fpopts->infinity = CPFLOAT_INF_USE; + else if (mxGetClassID(tmp) == mxDOUBLE_CLASS) + fpopts->infinity = *((double *)mxGetData(tmp)); + } else { + if (is_inf_no_default) + fpopts->infinity = CPFLOAT_INF_NO; /* Default for E4M5. */ + } + tmp = mxGetField(prhs[1], 0, "round"); if (tmp != NULL) { if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0) @@ -164,15 +179,15 @@ void mexFunction(int nlhs, else if (mxGetClassID(tmp) == mxDOUBLE_CLASS) fpopts->round = *((double *)mxGetData(tmp)); } + tmp = mxGetField(prhs[1], 0, "saturation"); if (tmp != NULL) { if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0) fpopts->saturation = CPFLOAT_SAT_NO; else if (mxGetClassID(tmp) == mxDOUBLE_CLASS) fpopts->saturation = *((double *)mxGetData(tmp)); - } else { - fpopts->saturation = CPFLOAT_SAT_NO; } + tmp = mxGetField(prhs[1], 0, "subnormal"); if (tmp != NULL) { if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0) @@ -313,11 +328,11 @@ void mexFunction(int nlhs, /* Allocate and return second output. */ if (nlhs > 1) { - const char* field_names[] = {"format", "params", "explim", + const char* field_names[] = {"format", "params", "explim", "infinity", "round", "saturation", "subnormal", "flip", "p"}; mwSize dims[2] = {1, 1}; - plhs[1] = mxCreateStructArray(2, dims, 8, field_names); + plhs[1] = mxCreateStructArray(2, dims, 9, field_names); mxSetFieldByNumber(plhs[1], 0, 0, mxCreateString(fpopts->format)); mxArray *outparams = mxCreateDoubleMatrix(1,3,mxREAL); @@ -332,30 +347,35 @@ void mexFunction(int nlhs, outexplimptr[0] = fpopts->explim; mxSetFieldByNumber(plhs[1], 0, 2, outexplim); + mxArray *outinfinity = mxCreateDoubleMatrix(1, 1, mxREAL); + double *outinfinityptr = mxGetData(outinfinity); + outinfinityptr[0] = fpopts->infinity; + mxSetFieldByNumber(plhs[1], 0, 3, outinfinity); + mxArray *outround = mxCreateDoubleMatrix(1,1,mxREAL); double *outroundptr = mxGetData(outround); outroundptr[0] = fpopts->round; - mxSetFieldByNumber(plhs[1], 0, 3, outround); + mxSetFieldByNumber(plhs[1], 0, 4, outround); mxArray *outsaturation = mxCreateDoubleMatrix(1,1,mxREAL); double *outsaturationptr = mxGetData(outsaturation); outsaturationptr[0] = fpopts->saturation; - mxSetFieldByNumber(plhs[1], 0, 4, outsaturation); + mxSetFieldByNumber(plhs[1], 0, 5, outsaturation); mxArray *outsubnormal = mxCreateDoubleMatrix(1,1,mxREAL); double *outsubnormalptr = mxGetData(outsubnormal); outsubnormalptr[0] = fpopts->subnormal; - mxSetFieldByNumber(plhs[1], 0, 5, outsubnormal); + mxSetFieldByNumber(plhs[1], 0, 6, outsubnormal); mxArray *outflip = mxCreateDoubleMatrix(1,1,mxREAL); double *outflipptr = mxGetData(outflip); outflipptr[0] = fpopts->flip; - mxSetFieldByNumber(plhs[1], 0, 6, outflip); + mxSetFieldByNumber(plhs[1], 0, 7, outflip); mxArray *outp = mxCreateDoubleMatrix(1,1,mxREAL); double *outpptr = mxGetData(outp); outpptr[0] = fpopts->p; - mxSetFieldByNumber(plhs[1], 0, 7, outp); + mxSetFieldByNumber(plhs[1], 0, 8, outp); } if (nlhs > 2) diff --git a/mex/cpfloat.m b/mex/cpfloat.m index 2dbf9b0..3e27dea 100644 --- a/mex/cpfloat.m +++ b/mex/cpfloat.m @@ -45,6 +45,11 @@ % this field is set to 0, and the exponent range of the format specified in % FPOPTS.format otherwise. The default value for this field is 1. % +% * The scalar FPOPTS.infinity specifies whether infinities are supported. The +% target floating-point format will support infinities if this field is set +% to 1, and they will be replaced by NaNs otherwise. The default value for +% this field is 0 if the target format is 'E4M3' and 1 otherwise. +% % * The scalar FPOPTS.round specifies the rounding mode. Possible values are: % -1 for round-to-nearest with ties-to-away; % 0 for round-to-nearest with ties-to-zero; diff --git a/src/cpfloat_definitions.h b/src/cpfloat_definitions.h index 354d7b1..9493f4d 100644 --- a/src/cpfloat_definitions.h +++ b/src/cpfloat_definitions.h @@ -9,6 +9,7 @@ * defines the enumerated types * * + @ref cpfloat_explim_t, + * + @ref cpfloat_infinity_t, * + @ref cpfloat_rounding_t, * + @ref cpfloat_saturation_t, * + @ref cpfloat_softerr_t, @@ -63,6 +64,16 @@ typedef enum { CPFLOAT_EXPRANGE_TARG = 1 } cpfloat_explim_t; +/** + * @brief Infinity support modes available in CPFloat. + */ +typedef enum { + /** Use infinities in target format. */ + CPFLOAT_INF_NO = 0, + /** Replace infinities with NaNs in target format. */ + CPFLOAT_INF_USE = 1, +} cpfloat_infinity_t; + /** * @brief Rounding modes available in CPFloat. */ @@ -234,6 +245,14 @@ typedef struct { * `CPFLOAT_EXPRANGE_STOR`. */ cpfloat_explim_t explim; + /** + * @brief Support for infinities in target format. + * + * @details If this field is set to `CPFLOAT_INF_USE`, the target format + * supports signed infinities. If the field is set to `CPFLOAT_INF_NO`, + * infinities are replaced with a quiet NaN. + */ + cpfloat_infinity_t infinity; /** * @brief Rounding mode to be used for the conversion. * diff --git a/src/cpfloat_template.h b/src/cpfloat_template.h index 28cf71c..0767139 100644 --- a/src/cpfloat_template.h +++ b/src/cpfloat_template.h @@ -322,7 +322,14 @@ static inline FPPARAMS COMPUTE_GLOBAL_PARAMS(const optstruct *fpopts, FPTYPE xmax = ldexp(1., emax) * (2-ldexp(1., 1-precision)); FPTYPE xbnd = ldexp(1., emax) * (2-ldexp(1., -precision)); - FPTYPE ofvalue = (fpopts->saturation == CPFLOAT_SAT_USE) ? xmax : INFINITY; + /* + * Here, fpopts->saturation takes precedence over fpopts->infinity. Therefore, + * when saturation arithmetic is used, infinities are not produced even when + * the target format supports them. + */ + FPTYPE ofvalue = (fpopts->saturation == CPFLOAT_SAT_USE) ? + xmax : + (fpopts->infinity == CPFLOAT_INF_USE)?INFINITY:NAN; /* Bitmasks. */ INTTYPE leadmask = FULLMASK << (DEFPREC-precision); /* To keep. */ diff --git a/test/cpfloat_test.m b/test/cpfloat_test.m index b436e63..d44b4fd 100644 --- a/test/cpfloat_test.m +++ b/test/cpfloat_test.m @@ -126,6 +126,18 @@ assert_eq(fp.subnormal,1) assert_eq(fp.params, [53 -1022 1023]) + clear fp + fp.format = 'E4M3'; + [~,options] = cpfloat(pi,fp); + assert_eq(options.format,'E4M3') + assert_eq(options.infinity,0) + assert_eq(options.params, [4 -6 8]) + [~,fp] = cpfloat; + assert_eq(fp.format,'E4M3') + assert_eq(fp.infinity,0) + assert_eq(fp.params, [4 -6 8]) + + clear fp fp.format = 'bfloat16'; [~,options] = cpfloat(pi,fp); @@ -323,6 +335,9 @@ end % Infinities tests. + [~,fpopts] = cpfloat; + prev_infinity = fpopts.infinity; + options.infinity = 1; options.saturation = 0; for j = 1:6 options.round = j; @@ -400,6 +415,7 @@ c = cpfloat(x,options); c_expected = [0 0 x(3:5) inf 1 1]; assert_eq(c,c_expected) + options.infinity = prev_infinity; % Smallest normal number and spacing between the subnormal numbers. y = xmin; delta = xmin*2^(1-p); diff --git a/test/cpfloat_test.ts b/test/cpfloat_test.ts index df68de5..a30bd39 100644 --- a/test/cpfloat_test.ts +++ b/test/cpfloat_test.ts @@ -55,6 +55,7 @@ void fpopts_setup(void) { fpopts = malloc(sizeof(optstruct)); fpopts->round = 1; fpopts->explim = CPFLOAT_EXPRANGE_STOR; + fpopts->infinity = CPFLOAT_INF_USE; fpopts->saturation = CPFLOAT_SAT_NO; fpopts->flip = 0;