Skip to content

Commit

Permalink
PR #19096: Add F4E2M1FN and F8E8M0FNU types
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla/xla#19096

This PR adds F4E2M1FN primitive type (4-bit float with 2 bits exponent and 1 bit mantissa), F8E8M0FNU primitive type (8-bit float with 8 bits exponent, no mantissa and no sign) and enables loads/stores in the same way S4/U4 type is implemented.

This will enable using microscaling (MX) formats ([RFC](openxla/xla#18085)), such as MXFP4.

```c...

PiperOrigin-RevId: 709153611
  • Loading branch information
jaeyoo authored and copybara-github committed Dec 23, 2024
1 parent 56337d8 commit 16500ff
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 4 deletions.
1 change: 0 additions & 1 deletion tsl/platform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,6 @@ cc_library(
deps = [
"@ml_dtypes//:float8",
"@ml_dtypes//:intn",
"@ml_dtypes//:mxfloat",
],
)

Expand Down
3 changes: 0 additions & 3 deletions tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,15 @@ limitations under the License.

#include "ml_dtypes/include/float8.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes

namespace tsl {
using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn;
using float8_e3m4 = ::ml_dtypes::float8_e3m4;
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
using float8_e5m2 = ::ml_dtypes::float8_e5m2;
using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz;
using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu;

using int1 = ::ml_dtypes::int1;
using uint1 = ::ml_dtypes::uint1;
Expand Down

0 comments on commit 16500ff

Please sign in to comment.