From c08362a052829f746990c8424ce0704137bcc048 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Mon, 21 Oct 2024 15:23:25 -0400 Subject: [PATCH] GPU target parameters for data tiling. (#18839) This replaces some constants what were hardcoded in GPUMaterializeEncoding.cpp by actual GPU target parameters. The logic in `getSwizzle` was doing wonky things with its own local `const int targetPreferredLoadBitWidth = 128;`, using it in a helper function inferring interleaving dimensions. That was all dating back to early days -- that was effectively trying to infer which inner-most dimensions to skip to get at the first non-Internal dimension... so that is one more thing that we can fix now that we have `TileSwizzle::Dim::Kind`. See `getInnermostNonInternalDimIdx`. The heuristic in `chooseDataTiledMMAAttr` becomes much more robust, and tested more extensively by `gpu_materialize_encoding.mlir`, now that we can pass arbitrary parameters in ad-hoc `#iree_gpu.target` attributes, see the test updates. It's unfortunately verbose (one screenful of MLIR code for each testcase) because each has to be a complete function with `flow.dispatch` ops, but that's a separate problem. --------- Signed-off-by: Benoit Jacob --- .../ROCM/test/target_device_features.mlir | 2 +- .../Common/GPU/GPUMaterializeEncoding.cpp | 118 ++++- .../GPU/test/gpu_materialize_encoding.mlir | 446 ++++++++++++++++++ .../Dialect/GPU/IR/GPUTileSwizzleUtils.cpp | 43 +- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.td | 12 +- .../Dialect/GPU/TargetUtils/KnownTargets.cpp | 13 +- 6 files changed, 569 insertions(+), 65 deletions(-) diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir index b7a5ab68a014..726f52551744 100644 --- a/compiler/plugins/target/ROCM/test/target_device_features.mlir +++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir @@ -18,7 +18,7 @@ // GFX942-SAME: mma = [, , , , , ], // GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], // GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, -// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>, +// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647], // MI300X: chip = > // MI300A: chip = > diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp index 5f0d60660b84..de80cc966402 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp @@ -4,6 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include #include "iree/compiler/Codegen/Common/EncodingUtils.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" @@ -54,6 +55,9 @@ static std::optional chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target, IREE::Encoding::EncodingAttr encoding) { using namespace IREE::GPU; + if (!target) { + return std::nullopt; + } MLIRContext *ctx = target.getContext(); // @@ -85,16 +89,16 @@ chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target, // Step 2: Select the unrolling factors for the generic case where there is no // narrow dimension. 
// - // These hardcoded constants should become functions querying `target`. - // - // Target ISA preferred load instruction size, in bits. - const int kLoadInstructionBits = 128; - // Target ISA preferred number of subgroups per block to get full utilization. - const int kNumSubgroups = 4; - // Number of register space bits to use for accumulators. Should typically be - // between 50% and 80% of total available register space, as the accumulator - // tends to be larger than the A and B matrix tiles. - const int kMaxAccumulatorRegisterBits = 256 * 32; + IREE::GPU::TargetWgpAttr wgp = target.getWgp(); + if (!wgp.getMaxLoadInstructionBits() || !wgp.getVgprSpaceBits() || + !wgp.getSimdsPerWgp()) { + // Missing workgroup parameters: data tiling not supported on this target. + return std::nullopt; + } + + auto sizeInBits = [](VectorType type) -> int { + return type.getElementTypeBitWidth() * type.getNumElements(); + }; MMAAttr intrinsicMma = MMAAttr::get(ctx, *intrinsic); auto [intrinsicA, intrinsicB, intrinsicC] = intrinsicMma.getABCVectorTypes(); @@ -102,22 +106,82 @@ chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target, // the target ISA's vector loads. For instance, if the ISA has 128-bit loads // and each intrinsic consumes only 32 bits from A and B, then we want to set // unrollK=4 to turn 4 separate 32-bit loads into one 128-bit load. - const int unrollK = - kLoadInstructionBits / - std::min( - intrinsicA.getElementTypeBitWidth() * intrinsicA.getNumElements(), - intrinsicB.getElementTypeBitWidth() * intrinsicB.getNumElements()); + int intrinsicLoadBits = + std::min(sizeInBits(intrinsicA), sizeInBits(intrinsicB)); + if (*wgp.getMaxLoadInstructionBits() % intrinsicLoadBits != 0) { + // Never seen that case: the ISA does not have a suitable load instruction + // to feed that intrinsic?! + return std::nullopt; + } + const int unrollK = *wgp.getMaxLoadInstructionBits() / intrinsicLoadBits; + // The total amount of unrolling along the M and N dimensions is normally // limited only by the number of available registers, since larger M and N // yields higher arithmetic intensity. Here, we do not yet distinguish between // plain unrolling (more instructions on each thread) and - // unrolling-to-subgroups (more threads). - const int totalUnrollMN = - kMaxAccumulatorRegisterBits / - (intrinsicC.getElementTypeBitWidth() * intrinsicC.getNumElements()); - const int totalUnrollM = static_cast( - std::floor(std::sqrt(static_cast(totalUnrollMN)))); - const int totalUnrollN = totalUnrollMN / totalUnrollM; + // unrolling-to-subgroups (more threads), since expanding to more subgroups + // correspondingly divides the available register space between this many + // subgroups, making it cancel out of the equation here. + // + // We need to solve for two variables here, unroll_m and unroll_n, constrained + // by one quadratic equation expressing that the A, B and C tiles must fit in + // VGPR space. Since we have only 1 constraint for two variables, we + // self-impose a second constraint for now: that the unrolling shape should be + // square, i.e. unrollM == unrollN. + // TODO(#18850): that is suboptimal for narrow cases. + // + // Now we have only one variable, call it x, to solve for. 
+ + // The register space taken is: + // A-tile: x * unrollK * sizeInBits(intrinsicA) + // B-tile: x * unrollK * sizeInBits(intrinsicB) + // C-tile: x^2 * sizeInBits(intrinsicC) + // So the equation to solve is: + // x^2 * sizeInBits(intrinsicC) + // + x * unrollK * (sizeInBits(intrinsicA) + sizeInBits(intrinsicB)) + // == wgp.getVgprSpaceBits() + float c2 = sizeInBits(intrinsicC); + float c1 = unrollK * (sizeInBits(intrinsicA) + sizeInBits(intrinsicB)); + float c0 = -*wgp.getVgprSpaceBits(); // negative by construction. + // Now the equation to solve is: c2 * x^2 + c1 * x + c0 == 0. + float discriminant = c1 * c1 - 4 * c0 * c2; // positive, because c0 < 0. + // x = unique positive solution. + float x = (-c1 + std::sqrt(discriminant)) / (2 * c2); + +#ifndef NDEBUG + // Self-check quadratic solver. 10 epsilon is just a crude upper bound; + // In practice, cancellation results in check == 0 in current cases. + float check = c2 * x * x + c1 * x + c0; + assert(std::abs(check) < 10 * FLT_EPSILON * std::abs(c0)); +#endif + + // Now, looking geometrically at our unrolling space along the M and N + // dimensions, we solve the following problem in the (M,N)-plane: approximate + // a square of side length `x`, by a rectangle of side lengths `totalUnrollM` + // and `totalUnrollN`, under the constraints: + // 1. totalUnrollM * totalUnrollN <= x * x + // * Reason: by construction of x, any larger area would exceed the + // wgp.getVgprSpaceBits() budget) + // 2. totalUnrollM and totalUnrollN are powers of 2. + // * Reason: that is a self-imposed constraint for now to avoid prematurely + // entering excessing fine-tuning of unrolling factors. Also, since below + // we will put all the unroll-to-subgroups in the N dimension, that + // requires totalUnrollN to be a multiple of wgp.getSimdsPerWgp(), + // which is typically a power of 2, specifically 4. + // TODO(#18851): we will not always put all the unroll-to-subgroups on N. + // 3. totalUnrollN >= totalUnrollM. + // * Reason: Just like the previous constraint, that is also motivated by + // the code below currently putting all the unroll-to-subgroups in the N + // dimension, which requires a sufficiently large totalUnrollN. + // TODO(#18851): we will not always put all the unroll-to-subgroups on N. + // + // Set totalUnrollN = round x to nearest power of two, break ties away from 0 + // per specification of std::round. + int totalUnrollN = std::exp2(std::round(std::log2(x))); + // Based on above constraint 1: + float unroundedMaxTotalUnrollM = x * x / totalUnrollN; + int totalUnrollM = std::exp2(std::floor(std::log2(unroundedMaxTotalUnrollM))); + // Now we introduce unroll-to-subgroups. It doesn't change the overall tile // size, as it increases the number of subgroups but correspondingly decreases // the number of registers available to each subgroups. In other words, the @@ -125,16 +189,18 @@ chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target, // overall number of registers, not with how they are split between subgroups. // // For now for simplicity we put all the unroll-to-subgroups in the N - // dimension. That might be suboptimal, revisit later. That does simplify the - // below adjustments for narrow M/N, as we don't need to think about - // unroll-to-subgroups when making the narrowing adjustment. + // dimension. TODO(#18851): revisit that. + // + // That does simplify the below adjustments for narrow M/N, as we don't need + // to think about unroll-to-subgroups when making the narrowing adjustment. 
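  // Worked example (illustrative): take the gfx942 values registered in
  // KnownTargets.cpp (max_load_instruction_bits = 128, simds_per_wgp = 4,
  // vgpr_space_bits = 512 * 32 = 16384) and MFMA_I32_16x16x32_I8, assuming
  // per-thread tile sizes sizeInBits(intrinsicA) = sizeInBits(intrinsicB) = 64
  // and sizeInBits(intrinsicC) = 128. Then:
  //   unrollK = 128 / 64 = 2
  //   c2 = 128, c1 = 2 * (64 + 64) = 256, c0 = -16384
  //   x = (-256 + sqrt(256 * 256 + 4 * 128 * 16384)) / (2 * 128) ~= 10.4
  //   totalUnrollN = 2^round(log2(10.4)) = 8
  //   totalUnrollM = 2^floor(log2(10.4 * 10.4 / 8)) = 8
  // and, with all unroll-to-subgroups placed on N below, unrollM = 8,
  // unrollN = 2, unrollNToSubgroups = 4, unrollK = 2, matching the
  // unroll8x8x2 cases in gpu_materialize_encoding.mlir.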
int unrollMToSubgroups = 1; - int unrollNToSubgroups = kNumSubgroups; + int unrollNToSubgroups = *wgp.getSimdsPerWgp(); int unrollM = totalUnrollM / unrollMToSubgroups; int unrollN = totalUnrollN / unrollNToSubgroups; // // Step 3: Adjust the unrolling factors when there is a narrow dimension. + // TODO(#18850): dealing with narrow cases as a fix-up is suboptimal. // IREE::Encoding::MatmulNarrowDim narrowDim = IREE::Encoding::getMatmulNarrowDim(encoding); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir index fd97eaf051d5..1c3a0589815d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir @@ -429,6 +429,8 @@ func.func @batch_matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32() { // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout // CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]] +// ----- + //----------------------------------------------------------------------------- // 2. MFMA_I32_16x16x32_I8 //----------------------------------------------------------------------------- @@ -622,3 +624,447 @@ func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8() { // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout // CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]] + +// ----- + +//------------------------------------------------------------------------- +// 3. Custom target parameters to test more MaterializeEncoding heuristics. +//------------------------------------------------------------------------- + +// Custom {max_load_instruction_bits = 64} => implied default {unroll_k = 1} (omitted in output) instead of {unroll_k = 2}. 
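// Each case below clones the gfx942 target and overrides a single wgp
// parameter in an ad-hoc #iree_gpu.target attribute. For reference (a derived
// illustration of the heuristic above, not a CHECK expectation): with the
// stock gfx942 parameters (max_load_instruction_bits = 128, simds_per_wgp = 4,
// vgpr_space_bits = 16384), MFMA_I32_16x16x32_I8 gets unroll_m = 8,
// unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2 (cf. the unroll8x8x2
// cases above). That is the baseline that each override below perturbs.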
+ +#target_gfx942_except_max_load_instruction_bits_64 = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + iree.gpu.target = #iree_gpu.target< + arch = "gfx942", features = "", wgp = < + compute = fp64|fp32|fp16|int64|int32|int16|int8, + storage = b64|b32|b16|b8, + subgroup = shuffle|arithmetic, + dot = dp4xi8toi32, + mma = [, , , , , ], + subgroup_size_choices = [64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 64, + simds_per_wgp = 4, + vgpr_space_bits = 16384 + > + >, + ukernels = "none" +}> + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64() attributes {hal.executable.target = #target_gfx942_except_max_load_instruction_bits_64} { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} + +// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64 +// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout + +// ----- + +// Custom {max_load_instruction_bits = 256} => {unroll_k = 4} instead of {unroll_k = 2}. 
+ +#target_gfx942_except_max_load_instruction_bits_256 = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + iree.gpu.target = #iree_gpu.target< + arch = "gfx942", features = "", wgp = < + compute = fp64|fp32|fp16|int64|int32|int16|int8, + storage = b64|b32|b16|b8, + subgroup = shuffle|arithmetic, + dot = dp4xi8toi32, + mma = [, , , , , ], + subgroup_size_choices = [64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 256, + simds_per_wgp = 4, + vgpr_space_bits = 16384 + > + >, + ukernels = "none" +}> + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64() attributes {hal.executable.target = #target_gfx942_except_max_load_instruction_bits_256} { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} + +// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64 +// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout + +// ----- + +// Custom {simds_per_wgp = 1} => implied default {unroll_n_to_subgroups = 1} (omitted in output) and {unroll_n = 8} instead of {unroll_n_to_subgroups = 4}. 
+ +#target_gfx942_except_simds_per_wgp_1 = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + iree.gpu.target = #iree_gpu.target< + arch = "gfx942", features = "", wgp = < + compute = fp64|fp32|fp16|int64|int32|int16|int8, + storage = b64|b32|b16|b8, + subgroup = shuffle|arithmetic, + dot = dp4xi8toi32, + mma = [, , , , , ], + subgroup_size_choices = [64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, + simds_per_wgp = 1, + vgpr_space_bits = 16384 + > + >, + ukernels = "none" +}> + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_simds_per_wgp_1() attributes {hal.executable.target = #target_gfx942_except_simds_per_wgp_1} { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} + +// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_simds_per_wgp_1 +// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout + +// ----- + +// Custom 2x smaller {vgpr_space_bits = 8192} => smaller unroll_m and unroll_n + +#target_gfx942_except_vgpr_space_bits_8192 = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + iree.gpu.target = #iree_gpu.target< + arch = "gfx942", features = "", wgp = < + compute = fp64|fp32|fp16|int64|int32|int16|int8, + storage = b64|b32|b16|b8, + subgroup = shuffle|arithmetic, + dot = dp4xi8toi32, + mma = [, , , , , ], + subgroup_size_choices = [64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, + simds_per_wgp = 4, + 
vgpr_space_bits = 8192 + > + >, + ukernels = "none" +}> + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_8192() attributes {hal.executable.target = #target_gfx942_except_vgpr_space_bits_8192} { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} + +// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_8192 +// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout + +// ----- + +// Custom 4x smaller {vgpr_space_bits = 4096} => smaller unroll_m and unroll_n + +#target_gfx942_except_vgpr_space_bits_4096 = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + iree.gpu.target = #iree_gpu.target< + arch = "gfx942", features = "", wgp = < + compute = fp64|fp32|fp16|int64|int32|int16|int8, + storage = b64|b32|b16|b8, + subgroup = shuffle|arithmetic, + dot = dp4xi8toi32, + mma = [, , , , , ], + subgroup_size_choices = [64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, + simds_per_wgp = 4, + vgpr_space_bits = 4096 + > + >, + ukernels = "none" +}> + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_4096() attributes {hal.executable.target = #target_gfx942_except_vgpr_space_bits_4096} { + %c0 
= arith.constant 0 : index + %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} + +// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_4096 +// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout + +// ----- + +// Custom smaller {vgpr_space_bits = 32768} => larger unroll_m and/or unroll_n + +#target_gfx942_except_vgpr_space_bits_32768 = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + iree.gpu.target = #iree_gpu.target< + arch = "gfx942", features = "", wgp = < + compute = fp64|fp32|fp16|int64|int32|int16|int8, + storage = b64|b32|b16|b8, + subgroup = shuffle|arithmetic, + dot = dp4xi8toi32, + mma = [, , , , , ], + subgroup_size_choices = [64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, + simds_per_wgp = 4, + vgpr_space_bits = 32768 + > + >, + ukernels = "none" +}> + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_32768() attributes {hal.executable.target = #target_gfx942_except_vgpr_space_bits_32768} { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan 
layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} + +// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_32768 +// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout + +// ----- diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp index 33ddd044d588..7ef46a6c0d9a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp @@ -144,38 +144,19 @@ TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic, return swizzle; } -// Returns the index of the dimension whose flattened size (flattening inner -// dimensions into it) matches the given `targetSize`. This is used to compute -// interleaving indices. -// -// Example: -// Input shape = [16, 8, 4, 4] -// Input targetSize = 16 -// -> Return 2, because the tail of the shape starting at index 2 is [4, 4], -// whose product equals targetSize. -static int64_t -getDimIdxForTargetSize(const TileSwizzle::ExpandShapeDimVectorType &shape, - int64_t targetSize) { - int interleaveAt = 0; - int size = 1; - for (interleaveAt = shape.size() - 1; interleaveAt >= 0; --interleaveAt) { - assert(size <= targetSize); - assert((targetSize % size) == 0); - if (size == targetSize) { - break; +static int getInnermostNonInternalDimIdx( + const TileSwizzle::ExpandShapeDimVectorType &shape) { + for (int idx = shape.size() - 1; idx >= 0; --idx) { + if (shape[idx].kind != TileSwizzle::Dim::Kind::Internal) { + return idx; } - size *= shape[interleaveAt].size; } - return interleaveAt; + assert(false && "all dimensions are internal!"); + return 0; } TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma, IREE::GPU::MMAFragment fragment) { - auto [aType, bType, cType] = mma.getABCElementTypes(); - int aBits = aType.getIntOrFloatBitWidth(); - int bBits = bType.getIntOrFloatBitWidth(); - // TODO(bjacob): Should be looked up from GPU target, instead of hard-coded. - const int targetPreferredLoadBitWidth = 128; auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment); using Kind = TileSwizzle::Dim::Kind; switch (fragment) { @@ -184,9 +165,8 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma, // Unroll on K with interleaving, then on M. 
if (mma.getUnrollK() > 1) { unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic); - int interleavingIdx = getDimIdxForTargetSize( - swizzle.expandShape[1], - targetPreferredLoadBitWidth / (mma.getUnrollK() * aBits)); + int interleavingIdx = + getInnermostNonInternalDimIdx(swizzle.expandShape[1]); interleave(swizzle, 1, interleavingIdx); } if (mma.getUnrollM() > 1) { @@ -202,9 +182,8 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma, // Unroll on K with interleaving, then on N. if (mma.getUnrollK() > 1) { unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic); - int interleavingIdx = getDimIdxForTargetSize( - swizzle.expandShape[1], - targetPreferredLoadBitWidth / (mma.getUnrollK() * bBits)); + int interleavingIdx = + getInnermostNonInternalDimIdx(swizzle.expandShape[1]); interleave(swizzle, 1, interleavingIdx); } if (mma.getUnrollN() > 1) { diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td index 1f2cad748c08..d04e9fefe5b9 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td @@ -333,11 +333,17 @@ def IREEGPU_TargetWgpAttr : AttrDef { // The maximal number of threads per X/Y/Z dimension in one workgroup. "DenseI32ArrayAttr":$max_workgroup_sizes, // The maximal number of threads we can have in one workgroup. - "uint32_t":$max_thread_count_per_workgroup, + "int32_t":$max_thread_count_per_workgroup, // The maximal number of shared memory bytes we can allocate per workgroup. - "uint32_t":$max_workgroup_memory_bytes, - // Tthe maximum number of workgroups per X/Y/Z dimension in a dispatch. + "int32_t":$max_workgroup_memory_bytes, + // The maximum number of workgroups per X/Y/Z dimension in a dispatch. "DenseI32ArrayAttr":$max_workgroup_counts, + // Max load instruction size in bits. TODO(#18849): populate on all GPUs. + OptionalParameter<"std::optional">:$max_load_instruction_bits, + // Number of SIMDs per workgroup processor. TODO(#18849): populate on all GPUs. + OptionalParameter<"std::optional">:$simds_per_wgp, + // VGPR register space size in bits. TODO(#18849): populate on all GPUs. + OptionalParameter<"std::optional">:$vgpr_space_bits, // An optional extra dict // This field allows to inject more features/limits not supported in the diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp index ef04a2282c5e..4fa5074e67a4 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp @@ -43,9 +43,12 @@ struct WgpDetails { // modes. Use duplicated values if the GPU only have one subgroup size. 
std::array subgroupSizeChoices; std::array maxWorkgroupSizes; - uint32_t maxThreadSize; - uint32_t maxWorkgroupMemoryBytes; + int32_t maxThreadSize; + int32_t maxWorkgroupMemoryBytes; std::array maxWorkgroupCounts; + std::optional maxLoadInstructionBits; + std::optional simdsPerWgp; + std::optional vgprSpaceBits; }; // Chip level feature/limit details @@ -109,6 +112,7 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch, DenseI32ArrayAttr::get(context, wgp->maxWorkgroupSizes), wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes, DenseI32ArrayAttr::get(context, wgp->maxWorkgroupCounts), + wgp->maxLoadInstructionBits, wgp->simdsPerWgp, wgp->vgprSpaceBits, DictionaryAttr{}); TargetChipAttr targetChip; @@ -146,7 +150,10 @@ const WgpDetails *getCDNA3WgpDetails() { {1024, 1024, 1024}, 1024, 64 * 1024, - {0x7fffffff, 0x7fffffff, 0x7fffffff}}; + {0x7fffffff, 0x7fffffff, 0x7fffffff}, + /*maxLoadInstructionBits=*/128, + /*simdsPerWgp=*/4, + /*vgprSpaceBits=*/512 * 32}; return &cdna3Wgp; }
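To make the heuristic easy to experiment with outside the compiler, here is a minimal standalone C++ sketch of the unrolling computation that chooseDataTiledMMAAttr performs above. chooseUnrolling and its parameter names are hypothetical; the per-thread tile sizes used for MFMA_I32_16x16x32_I8 are an assumption, and the real pass additionally selects the intrinsic and applies the narrow-M/N fix-up.

// build: c++ -std=c++17 unrolling_sketch.cc && ./a.out
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>

struct Unrolling {
  int unrollM;
  int unrollN;
  int unrollNToSubgroups;
  int unrollK;
};

// aBits/bBits/cBits are the per-thread A/B/C tile sizes, in bits, of one
// intrinsic (i.e. sizeInBits(intrinsicA/B/C) in the pass).
static Unrolling chooseUnrolling(int maxLoadInstructionBits, int simdsPerWgp,
                                 int vgprSpaceBits, int aBits, int bBits,
                                 int cBits) {
  // Unroll K so that several intrinsic-sized loads fuse into one full-width load.
  int intrinsicLoadBits = std::min(aBits, bBits);
  assert(maxLoadInstructionBits % intrinsicLoadBits == 0);
  int unrollK = maxLoadInstructionBits / intrinsicLoadBits;

  // Solve c2*x^2 + c1*x + c0 == 0 for the side length x of a square (M, N)
  // unrolling whose A, B and C tiles exactly fill the VGPR budget.
  float c2 = cBits;
  float c1 = static_cast<float>(unrollK) * (aBits + bBits);
  float c0 = -static_cast<float>(vgprSpaceBits);
  float x = (-c1 + std::sqrt(c1 * c1 - 4.0f * c0 * c2)) / (2.0f * c2);

  // Round to powers of two while keeping totalM * totalN <= x * x and
  // totalN >= totalM (all unroll-to-subgroups goes to N below).
  int totalUnrollN = 1 << static_cast<int>(std::lround(std::log2(x)));
  int totalUnrollM =
      1 << static_cast<int>(std::floor(std::log2(x * x / totalUnrollN)));

  assert(totalUnrollN >= simdsPerWgp && "sketch assumes a large enough N unroll");
  return {totalUnrollM, totalUnrollN / simdsPerWgp, simdsPerWgp, unrollK};
}

int main() {
  // Assumed per-thread sizes for MFMA_I32_16x16x32_I8 on a 64-lane subgroup:
  // A and B are 8 x i8 = 64 bits each, C is 4 x i32 = 128 bits.
  Unrolling u = chooseUnrolling(/*maxLoadInstructionBits=*/128,
                                /*simdsPerWgp=*/4,
                                /*vgprSpaceBits=*/512 * 32,
                                /*aBits=*/64, /*bBits=*/64, /*cBits=*/128);
  // Expected to print: unroll_m=8 unroll_n=2 unroll_n_to_subgroups=4 unroll_k=2
  std::printf("unroll_m=%d unroll_n=%d unroll_n_to_subgroups=%d unroll_k=%d\n",
              u.unrollM, u.unrollN, u.unrollNToSubgroups, u.unrollK);
  return 0;
}

Swapping in the custom parameters exercised by the new tests (for example vgpr_space_bits = 8192 or simds_per_wgp = 1) shows how the chosen unrolling factors shrink or shift between plain unrolling and unroll-to-subgroups, which is what the added gpu_materialize_encoding.mlir cases check end to end.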