Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make E3M4 format fully compliant with the OFP8 specification #10

Merged
merged 3 commits into from
Jun 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 72 additions & 21 deletions mex/cpfloat.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ void mexFunction(int nlhs,
fpopts->precision = 11;
fpopts->emin = -14;
fpopts->emax = 15;
fpopts->subnormal = CPFLOAT_SUBN_USE;
fpopts->explim = CPFLOAT_EXPRANGE_TARG;
fpopts->infinity = CPFLOAT_INF_USE;
fpopts->round = CPFLOAT_RND_NE;
fpopts->flip = CPFLOAT_NO_SOFTERR;
fpopts->saturation = CPFLOAT_SAT_NO;
fpopts->subnormal = CPFLOAT_SUBN_USE;

fpopts->flip = CPFLOAT_SOFTERR_NO;
fpopts->p = 0.5;

fpopts->bitseed = NULL;
fpopts->randseedf = NULL;
fpopts->randseed = NULL;
Expand All @@ -54,6 +58,7 @@ void mexFunction(int nlhs,
/* Parse second argument and populate fpopts structure. */
if (nrhs > 1) {
bool is_subn_rnd_default = false;
bool is_inf_no_default = false;
if(!mxIsEmpty(prhs[1]) && !mxIsStruct(prhs[1])) {
mexErrMsgIdAndTxt("cpfloat:invalidstruct",
"Second argument must be a struct.");
Expand All @@ -62,7 +67,7 @@ void mexFunction(int nlhs,

if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
/* Use default format, for compatibility with chop. */
/* Set default format, for compatibility with chop. */
strcpy(fpopts->format, "h");
else if (mxGetClassID(tmp) == mxCHAR_CLASS)
strcpy(fpopts->format, mxArrayToString(tmp));
Expand All @@ -80,6 +85,7 @@ void mexFunction(int nlhs,
fpopts->precision = 4;
fpopts->emin = -6;
fpopts->emax = 8;
is_inf_no_default = true;
} else if (!strcmp(fpopts->format, "q52") ||
!strcmp(fpopts->format, "fp8-e5m2") ||
!strcmp(fpopts->format, "E5M2")) {
Expand Down Expand Up @@ -134,6 +140,7 @@ void mexFunction(int nlhs,
mexErrMsgIdAndTxt("cpfloat:invalidformat",
"Invalid floating-point format specified.");
}

/* Set default values to be compatible with MATLAB chop. */
tmp = mxGetField(prhs[1], 0, "subnormal");
if (tmp != NULL) {
Expand All @@ -144,27 +151,60 @@ void mexFunction(int nlhs,
} else {
if (is_subn_rnd_default)
fpopts->subnormal = CPFLOAT_SUBN_RND; /* Default for bfloat16. */
else
fpopts->subnormal = CPFLOAT_SUBN_USE;
}

tmp = mxGetField(prhs[1], 0, "explim");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->explim = 1;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->explim = *((double *)mxGetData(tmp));
}

tmp = mxGetField(prhs[1], 0, "infinity");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->infinity = CPFLOAT_INF_USE;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->infinity = *((double *)mxGetData(tmp));
} else {
if (is_inf_no_default)
fpopts->infinity = CPFLOAT_INF_NO; /* Default for E4M5. */
}

tmp = mxGetField(prhs[1], 0, "round");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->round = CPFLOAT_RND_NE;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->round = *((double *)mxGetData(tmp));
}

tmp = mxGetField(prhs[1], 0, "saturation");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->saturation = CPFLOAT_SAT_NO;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->saturation = *((double *)mxGetData(tmp));
}

tmp = mxGetField(prhs[1], 0, "subnormal");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->subnormal = CPFLOAT_SUBN_USE;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->subnormal = *((double *)mxGetData(tmp));
} else {
if (is_subn_rnd_default)
fpopts->subnormal = CPFLOAT_SUBN_RND; /* Default for bfloat16. */
else
fpopts->subnormal = CPFLOAT_SUBN_USE;
}

tmp = mxGetField(prhs[1], 0, "flip");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->flip = CPFLOAT_NO_SOFTERR;
fpopts->flip = CPFLOAT_SOFTERR_NO;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->flip = *((double *)mxGetData(tmp));
}
Expand Down Expand Up @@ -288,10 +328,11 @@ void mexFunction(int nlhs,

/* Allocate and return second output. */
if (nlhs > 1) {
const char* field_names[] = {"format", "params", "subnormal", "round",
"flip", "p", "explim"};
const char* field_names[] = {"format", "params", "explim", "infinity",
"round", "saturation", "subnormal",
"flip", "p"};
mwSize dims[2] = {1, 1};
plhs[1] = mxCreateStructArray(2, dims, 7, field_names);
plhs[1] = mxCreateStructArray(2, dims, 9, field_names);
mxSetFieldByNumber(plhs[1], 0, 0, mxCreateString(fpopts->format));

mxArray *outparams = mxCreateDoubleMatrix(1,3,mxREAL);
Expand All @@ -301,30 +342,40 @@ void mexFunction(int nlhs,
outparamsptr[2] = fpopts->emax;
mxSetFieldByNumber(plhs[1], 0, 1, outparams);

mxArray *outsubnormal = mxCreateDoubleMatrix(1,1,mxREAL);
double *outsubnormalptr = mxGetData(outsubnormal);
outsubnormalptr[0] = fpopts->subnormal;
mxSetFieldByNumber(plhs[1], 0, 2, outsubnormal);
mxArray *outexplim = mxCreateDoubleMatrix(1, 1, mxREAL);
double *outexplimptr = mxGetData(outexplim);
outexplimptr[0] = fpopts->explim;
mxSetFieldByNumber(plhs[1], 0, 2, outexplim);

mxArray *outinfinity = mxCreateDoubleMatrix(1, 1, mxREAL);
double *outinfinityptr = mxGetData(outinfinity);
outinfinityptr[0] = fpopts->infinity;
mxSetFieldByNumber(plhs[1], 0, 3, outinfinity);

mxArray *outround = mxCreateDoubleMatrix(1,1,mxREAL);
double *outroundptr = mxGetData(outround);
outroundptr[0] = fpopts->round;
mxSetFieldByNumber(plhs[1], 0, 3, outround);
mxSetFieldByNumber(plhs[1], 0, 4, outround);

mxArray *outsaturation = mxCreateDoubleMatrix(1,1,mxREAL);
double *outsaturationptr = mxGetData(outsaturation);
outsaturationptr[0] = fpopts->saturation;
mxSetFieldByNumber(plhs[1], 0, 5, outsaturation);

mxArray *outsubnormal = mxCreateDoubleMatrix(1,1,mxREAL);
double *outsubnormalptr = mxGetData(outsubnormal);
outsubnormalptr[0] = fpopts->subnormal;
mxSetFieldByNumber(plhs[1], 0, 6, outsubnormal);

mxArray *outflip = mxCreateDoubleMatrix(1,1,mxREAL);
double *outflipptr = mxGetData(outflip);
outflipptr[0] = fpopts->flip;
mxSetFieldByNumber(plhs[1], 0, 4, outflip);
mxSetFieldByNumber(plhs[1], 0, 7, outflip);

mxArray *outp = mxCreateDoubleMatrix(1,1,mxREAL);
double *outpptr = mxGetData(outp);
outpptr[0] = fpopts->p;
mxSetFieldByNumber(plhs[1], 0, 5, outp);

mxArray *outexplim = mxCreateDoubleMatrix(1,1,mxREAL);
double *outexplimptr = mxGetData(outexplim);
outexplimptr[0] = fpopts->explim;
mxSetFieldByNumber(plhs[1], 0, 6, outexplim);
mxSetFieldByNumber(plhs[1], 0, 8, outp);

}
if (nlhs > 2)
Expand Down
20 changes: 15 additions & 5 deletions mex/cpfloat.m
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,17 @@
% the target format, respectively. The default value of this field is
% the vector [11,-14,15].
%
% * The scalar FPOPTS.subnormal specifies the support for subnormal numbers.
% The target floating-point format will not support subnormal numbers if
% this field is set to 0, and will support them otherwise. The default value
% for this field is 0 if the target format is 'bfloat16' and 1 otherwise.
%
% * The scalar FPOPTS.explim specifies the support for an extended exponent
% range. The target floating-point format will have the exponent range of
% the storage format ('single' or 'double', depending on the class of X) if
% this field is set to 0, and the exponent range of the format specified in
% FPOPTS.format otherwise. The default value for this field is 1.
%
% * The scalar FPOPTS.infinity specifies whether infinities are supported. The
% target floating-point format will support infinities if this field is set
% to 1, and they will be replaced by NaNs otherwise. The default value for
% this field is 0 if the target format is 'E4M3' and 1 otherwise.
%
% * The scalar FPOPTS.round specifies the rounding mode. Possible values are:
% -1 for round-to-nearest with ties-to-away;
% 0 for round-to-nearest with ties-to-zero;
Expand All @@ -63,6 +63,16 @@
% Any other value results in no rounding. The default value for this field
% is 1.
%
% * The scalar FPOPTS.saturation specifies whether saturation arithmetic is in
% use. On overflow, the target floating-point format will use the largest
% representable floating-point if this field is set to 0, and infinity
% otherwise. The default value for this field is 0.

% * The scalar FPOPTS.subnormal specifies the support for subnormal numbers.
% The target floating-point format will not support subnormal numbers if
% this field is set to 0, and will support them otherwise. The default value
% for this field is 0 if the target format is 'bfloat16' and 1 otherwise.
%
% * The scalar FPOPTS.flip specifies whether the function should simulate the
% occurrence of a single bit flip striking the floating-point representation
% of elements of Y. Possible values are:
Expand Down
2 changes: 1 addition & 1 deletion src/cpfloat_binary32.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ static inline int cpf_fmaf(float *X, const float *A, const float *B,
#define INTSUFFIX U

#define DEFPREC 24
#define DEFEMAX 127
#define DEFEMIN -126
#define DEFEMAX 127
#define NLEADBITS 9
#define NBITS 32
#define FULLMASK 0xFFFFFFFFU
Expand Down
2 changes: 1 addition & 1 deletion src/cpfloat_binary64.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ static inline int cpf_fma(double *X, const double *A, const double *B,
#define INTTYPE uint64_t
#define INTSUFFIX ULL
#define DEFPREC 53
#define DEFEMAX 1023
#define DEFEMIN -1022
#define DEFEMAX 1023
#define NLEADBITS 12
#define NBITS 64
#define FULLMASK 0xFFFFFFFFFFFFFFFFULL
Expand Down
Loading
Loading