A few more MMAIntrinsics (iree-org#19099)
Support some MMAIntrinsics that were omitted for no particular reason
and that seem useful. In particular: the f64 intrinsic present in
CDNA2/3, mixed-f8-type support in CDNA3, and bf16 support in CDNA2 and
RDNA3.

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
bjacob authored Nov 12, 2024
1 parent 5b9c4d9 commit 31e7343
Showing 5 changed files with 133 additions and 8 deletions.
6 changes: 3 additions & 3 deletions compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -15,7 +15,7 @@
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647],
@@ -26,10 +26,10 @@
// GFX941-SAME: features = "+sramecc,-xnack"

// GFX940: target = #iree_gpu.target<arch = "gfx940",
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],

// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>]
// GFX1100-SAME: mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_F32_16x16x16_BF16>, <WMMA_BF16_16x16x16_BF16>, <WMMA_I32_16x16x16_I8>]
// GFX1100-SAME: subgroup_size_choices = [32, 64]

stream.executable public @reduce_dispatch {
63 changes: 63 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -216,9 +216,13 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
Type f16 = Float16Type::get(context);
Type bf16 = BFloat16Type::get(context);
Type f32 = Float32Type::get(context);
Type f64 = Float64Type::get(context);
Type i8 = IntegerType::get(context, 8);
Type i32 = IntegerType::get(context, 32);
switch (intrinsic) {
case MMAIntrinsic::MFMA_F64_16x16x4_F64: {
return {f64, f64, f64};
}
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return {f32, f32, f32};
}
@@ -228,6 +232,12 @@
case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return {f16, f16, f32};
}
case MMAIntrinsic::MFMA_F32_16x16x8_BF16: {
return {bf16, bf16, f32};
}
case MMAIntrinsic::MFMA_F32_32x32x4_BF16: {
return {bf16, bf16, f32};
}
case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
return {bf16, bf16, f32};
}
@@ -240,6 +250,12 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ: {
return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: {
return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
}
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return {i8, i8, i32};
}
@@ -258,6 +274,12 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {f16, f16, f16};
}
case MMAIntrinsic::WMMA_F32_16x16x16_BF16: {
return {bf16, bf16, f32};
}
case MMAIntrinsic::WMMA_BF16_16x16x16_BF16: {
return {bf16, bf16, bf16};
}
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return {i8, i8, i32};
}
@@ -505,6 +527,43 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::MFMA_F64_16x16x4_F64:
switch (fragment) {
case MMAFragment::Lhs:
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16},
/*element=*/{1, 1}};
case MMAFragment::Rhs:
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{1, 1}};
case MMAFragment::Acc:
return {/*outer=*/{4, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{1, 1}};
}
case MMAIntrinsic::MFMA_F32_16x16x8_BF16: {
switch (fragment) {
case MMAFragment::Lhs:
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16},
/*element=*/{1, 2}};
case MMAFragment::Rhs:
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{2, 1}};
case MMAFragment::Acc:
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{4, 1}};
}
}
case MMAIntrinsic::MFMA_F32_32x32x4_BF16:
switch (fragment) {
case MMAFragment::Lhs:
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32},
/*element=*/{1, 2}};
case MMAFragment::Rhs:
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
/*element=*/{2, 1}};
case MMAFragment::Acc:
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
@@ -535,6 +594,8 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
switch (fragment) {
case MMAFragment::Lhs:
@@ -560,6 +621,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
/*element=*/{4, 1}};
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F32_16x16x16_BF16:
case MMAIntrinsic::WMMA_I32_16x16x16_I8:
switch (fragment) {
case MMAFragment::Lhs:
@@ -573,6 +635,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
/*element=*/{1, 1}};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_BF16_16x16x16_BF16:
switch (fragment) {
case MMAFragment::Lhs:
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 0},
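The layout tuples above are easiest to read by multiplying them out. Below is a standalone sketch (illustrative only: the struct is a simplified stand-in for MMASingleSubgroupLayout, and the tstrides fields are omitted) checking that outer * thread * element per dimension reproduces the fragment shapes implied by the intrinsic names (LHS MxK, RHS KxN, ACC MxN) for the newly added MFMA_F64_16x16x4_F64 and MFMA_F32_16x16x8_BF16 layouts.

#include <array>
#include <cassert>
#include <cstdio>

// Hypothetical, simplified mirror of MMASingleSubgroupLayout: per-dimension
// counts of outer tiles, threads, and contiguous elements per thread.
struct Layout {
  std::array<int, 2> outer;
  std::array<int, 2> thread;
  std::array<int, 2> element;
};

// The shape covered by one subgroup is outer * thread * element per dimension.
static std::array<int, 2> coveredShape(const Layout &l) {
  return {l.outer[0] * l.thread[0] * l.element[0],
          l.outer[1] * l.thread[1] * l.element[1]};
}

int main() {
  // MFMA_F64_16x16x4_F64: A is 16x4, B is 4x16, C is 16x16; the thread counts
  // multiply to the 64-lane subgroup size in every fragment.
  Layout f64Lhs{{1, 1}, {16, 4}, {1, 1}};
  Layout f64Rhs{{1, 1}, {4, 16}, {1, 1}};
  Layout f64Acc{{4, 1}, {4, 16}, {1, 1}};
  assert((coveredShape(f64Lhs) == std::array<int, 2>{16, 4}));
  assert((coveredShape(f64Rhs) == std::array<int, 2>{4, 16}));
  assert((coveredShape(f64Acc) == std::array<int, 2>{16, 16}));

  // MFMA_F32_16x16x8_BF16: A is 16x8, B is 8x16, C is 16x16.
  Layout bf16Lhs{{1, 1}, {16, 4}, {1, 2}};
  Layout bf16Rhs{{1, 1}, {4, 16}, {2, 1}};
  Layout bf16Acc{{1, 1}, {4, 16}, {4, 1}};
  assert((coveredShape(bf16Lhs) == std::array<int, 2>{16, 8}));
  assert((coveredShape(bf16Rhs) == std::array<int, 2>{8, 16}));
  assert((coveredShape(bf16Acc) == std::array<int, 2>{16, 16}));

  std::puts("all subgroup layouts cover their fragment shapes");
  return 0;
}

In both cases the ACC layout works out to 4 values per lane (outer times element), i.e. 256 accumulator elements spread across 64 lanes.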
18 changes: 18 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -146,17 +146,26 @@ def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x1021>;
def MFMA_I32_16x16x16_I8 : I32EnumAttrCase<"MFMA_I32_16x16x16_I8", 0x10C0>;
def MFMA_I32_32x32x8_I8 : I32EnumAttrCase<"MFMA_I32_32x32x8_I8", 0x10C1>;

// Introduced in CDNA2
def MFMA_F32_16x16x8_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x8_BF16", 0x1120>;
def MFMA_F32_32x32x4_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x4_BF16", 0x1121>;
def MFMA_F64_16x16x4_F64 : I32EnumAttrCase<"MFMA_F64_16x16x4_F64", 0x1100>;

// Introduced in CDNA3
def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x1220>;
def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x1221>;
def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x1230>;
def MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ", 0x1231>;
def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x1232>;
def MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ", 0x1233>;
def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x12C0>;
def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 0x12C1>;

// Introduced in RDNA3
def WMMA_F32_16x16x16_F16 : I32EnumAttrCase<"WMMA_F32_16x16x16_F16", 0x1820>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 0x1821>;
def WMMA_F32_16x16x16_BF16 : I32EnumAttrCase<"WMMA_F32_16x16x16_BF16", 0x1822>;
def WMMA_BF16_16x16x16_BF16 : I32EnumAttrCase<"WMMA_BF16_16x16x16_BF16", 0x1823>;
def WMMA_I32_16x16x16_I8 : I32EnumAttrCase<"WMMA_I32_16x16x16_I8", 0x18C0>;

// NV intrinsics
@@ -172,17 +181,26 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
MFMA_I32_16x16x16_I8,
MFMA_I32_32x32x8_I8,

// Introduced in CDNA2
MFMA_F32_16x16x8_BF16,
MFMA_F32_32x32x4_BF16,
MFMA_F64_16x16x4_F64,

// Introduced in CDNA3
MFMA_F32_16x16x16_BF16,
MFMA_F32_32x32x8_BF16,
MFMA_F32_16x16x32_F8E5M2FNUZ,
MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ,
MFMA_F32_16x16x32_F8E4M3FNUZ,
MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ,
MFMA_I32_16x16x32_I8,
MFMA_I32_32x32x16_I8,

// RDNA3 intrinsics
WMMA_F32_16x16x16_F16,
WMMA_F16_16x16x16_F16,
WMMA_F32_16x16x16_BF16,
WMMA_BF16_16x16x16_BF16,
WMMA_I32_16x16x16_I8,

// NV intrinsics
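One incidental observation about the enum values added above, based purely on the constants as listed rather than any documented encoding scheme: intrinsics introduced in the same architecture generation share a high byte (0x10 for the CDNA1-era cases, 0x11 for CDNA2, 0x12 for CDNA3, 0x18 for RDNA3). A small self-contained check, with values copied from this file:

#include <cassert>
#include <cstdint>

int main() {
  // Values copied verbatim from IREEGPUEnums.td above.
  constexpr uint32_t MFMA_F32_32x32x8_F16 = 0x1021;                     // CDNA1
  constexpr uint32_t MFMA_F64_16x16x4_F64 = 0x1100;                     // CDNA2
  constexpr uint32_t MFMA_F32_16x16x8_BF16 = 0x1120;                    // CDNA2
  constexpr uint32_t MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ = 0x1231;  // CDNA3
  constexpr uint32_t MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ = 0x1233;  // CDNA3
  constexpr uint32_t WMMA_F32_16x16x16_BF16 = 0x1822;                   // RDNA3
  constexpr uint32_t WMMA_BF16_16x16x16_BF16 = 0x1823;                  // RDNA3

  // The high byte groups the cases by the generation that introduced them.
  auto generation = [](uint32_t v) { return v >> 8; };
  assert(generation(MFMA_F32_32x32x8_F16) == 0x10);
  assert(generation(MFMA_F64_16x16x4_F64) == 0x11);
  assert(generation(MFMA_F32_16x16x8_BF16) == 0x11);
  assert(generation(MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ) == 0x12);
  assert(generation(MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ) == 0x12);
  assert(generation(WMMA_F32_16x16x16_BF16) == 0x18);
  assert(generation(WMMA_BF16_16x16x16_BF16) == 0x18);
  return 0;
}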
24 changes: 19 additions & 5 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -133,13 +133,19 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,

const WgpDetails *getCDNA3WgpDetails() {
static const MMAIntrinsic cdna3MMAOps[] = {
// Introduced in CDNA1, still present in CDNA3
MMAIntrinsic::MFMA_F32_16x16x4_F32,
MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16,
// Introduced in CDNA2, still present in CDNA3
MMAIntrinsic::MFMA_F64_16x16x4_F64,
// Introduced in CDNA3
MMAIntrinsic::MFMA_F32_16x16x16_BF16,
MMAIntrinsic::MFMA_F32_32x32x8_BF16,
MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ,
MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ,
MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ,
MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ,
MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ,
MMAIntrinsic::MFMA_I32_16x16x32_I8,
MMAIntrinsic::MFMA_I32_32x32x16_I8,
};
@@ -162,10 +168,16 @@ const WgpDetails *getCDNA3WgpDetails() {

const WgpDetails *getCDNA2WgpDetails() {
static const MMAIntrinsic cdna2MMAOps[] = {
// Introduced in CDNA1
MMAIntrinsic::MFMA_F32_16x16x4_F32,
MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16,
MMAIntrinsic::MFMA_I32_16x16x16_I8,
MMAIntrinsic::MFMA_I32_32x32x8_I8,
// Introduced in CDNA2
MMAIntrinsic::MFMA_F32_16x16x8_BF16,
MMAIntrinsic::MFMA_F32_32x32x4_BF16,
MMAIntrinsic::MFMA_F64_16x16x4_F64,
};
static const WgpDetails cdna2Wgp = {allComputeBits,
allStorageBits,
@@ -183,8 +195,9 @@ const WgpDetails *getCDNA2WgpDetails() {

const WgpDetails *getCDNA1WgpDetails() {
static const MMAIntrinsic cdna1MMAOps[] = {
MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16,
MMAIntrinsic::MFMA_F32_16x16x4_F32, MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16, MMAIntrinsic::MFMA_I32_16x16x16_I8,
MMAIntrinsic::MFMA_I32_32x32x8_I8,
};
static const WgpDetails cdna1Wgp = {allComputeBits,
allStorageBits,
Expand All @@ -202,9 +215,10 @@ const WgpDetails *getCDNA1WgpDetails() {

const WgpDetails *getRDNA3WgpDetails() {
static const MMAIntrinsic rdna3MMAOps[] = {
MMAIntrinsic::WMMA_F32_16x16x16_F16,
MMAIntrinsic::WMMA_F16_16x16x16_F16,
MMAIntrinsic::WMMA_F32_16x16x16_F16, MMAIntrinsic::WMMA_F16_16x16x16_F16,
MMAIntrinsic::WMMA_F32_16x16x16_BF16, MMAIntrinsic::WMMA_BF16_16x16x16_BF16,
MMAIntrinsic::WMMA_I32_16x16x16_I8,
};
static const WgpDetails rdna3Wgp = {allComputeBits,
allStorageBits,
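A minimal sketch of how static per-architecture tables such as cdna2MMAOps above are meant to be consumed: a target advertises exactly the intrinsics listed for it, so asking whether a capability like the new f64 MFMA is available reduces to a membership check on that list. The enum and helper below are illustrative stand-ins, not IREE's actual API.

#include <algorithm>
#include <cstdio>
#include <initializer_list>

// Illustrative stand-in for the real MMAIntrinsic enum; only the cases used
// below are listed, and the numeric values are not meaningful.
enum class MMAIntrinsic {
  MFMA_F32_16x16x4_F32,
  MFMA_F32_16x16x16_F16,
  MFMA_F32_32x32x8_F16,
  MFMA_I32_16x16x16_I8,
  MFMA_I32_32x32x8_I8,
  MFMA_F32_16x16x8_BF16,
  MFMA_F32_32x32x4_BF16,
  MFMA_F64_16x16x4_F64,
};

// Hypothetical helper: does a target's intrinsic table contain a given entry?
static bool supports(std::initializer_list<MMAIntrinsic> table,
                     MMAIntrinsic want) {
  return std::find(table.begin(), table.end(), want) != table.end();
}

int main() {
  // Mirrors the cdna2MMAOps table above: CDNA1 carry-overs plus the CDNA2 additions.
  auto cdna2 = {MMAIntrinsic::MFMA_F32_16x16x4_F32,
                MMAIntrinsic::MFMA_F32_16x16x16_F16,
                MMAIntrinsic::MFMA_F32_32x32x8_F16,
                MMAIntrinsic::MFMA_I32_16x16x16_I8,
                MMAIntrinsic::MFMA_I32_32x32x8_I8,
                MMAIntrinsic::MFMA_F32_16x16x8_BF16,
                MMAIntrinsic::MFMA_F32_32x32x4_BF16,
                MMAIntrinsic::MFMA_F64_16x16x4_F64};
  std::printf("CDNA2 f64 MFMA available: %s\n",
              supports(cdna2, MMAIntrinsic::MFMA_F64_16x16x4_F64) ? "yes" : "no");
  return 0;
}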
30 changes: 30 additions & 0 deletions tests/e2e/matmul/CMakeLists.txt
@@ -1741,6 +1741,36 @@ iree_generated_e2e_runner_test(
"requires-gpu-cdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_cdna3_dt_f64
TEST_TYPE
matmul
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f64"
"--acc_type=f64"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
"rocm"
DRIVERS
"hip"
COMPILER_FLAGS
${IREE_HIP_TEST_COMPILER_FLAGS}
"--iree-opt-data-tiling"
"--iree-global-opt-experimental-rocm-data-tiling"
"--iree-global-opt-enable-early-materialization=true"
"--iree-input-demote-f64-to-f32=false"
LABELS
"noasan"
"nomsan"
"notsan"
"noubsan"
"requires-gpu-cdna3"
)

endif()

elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")
