From 16097c1fe724aa494010c175e939be0b5ca73b5e Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Thu, 19 Dec 2024 20:41:31 +0700 Subject: [PATCH] Remove the operand promotion for LHS and RHS. (#19516) Operand promotion for unaligned matmul cases leads to dynamic trip counts, so forall loop fusion is not performed by iree-codegen-gpu-fuse-and-hoist-parallel-loops. --- .../test/gpu_reorder_workgroups_static.mlir | 2 +- .../Dialect/Codegen/IR/IREECodegenAttrs.td | 22 +++--- .../compiler/Codegen/LLVMGPU/KernelConfig.cpp | 68 +++++++++++++++++-- .../LLVMGPU/LLVMGPULowerExecutableTarget.cpp | 3 - .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 66 ------------------ .../iree/compiler/Codegen/LLVMGPU/Passes.h | 4 -- .../compiler/Codegen/LLVMGPU/Verifiers.cpp | 11 +-- .../Codegen/LLVMGPU/test/config_matvec.mlir | 5 +- .../test/config_root_op_attribute.mlir | 2 +- .../LLVMGPU/test/distribute_to_thread.mlir | 8 +-- .../LLVMGPU/test/gpu_set_num_workgroups.mlir | 28 +++----- .../LLVMGPU/test/illegal_configuration.mlir | 38 ----------- .../LLVMGPU/test/nvvm_pipeline_test.mlir | 31 ++++----- .../LLVMGPU/test/rocdl_pipeline_test.mlir | 13 ++-- tests/e2e/matmul/BUILD.bazel | 24 ------- tests/e2e/matmul/CMakeLists.txt | 26 ------- tests/e2e/matmul/generate_e2e_matmul_tests.py | 14 +--- 17 files changed, 106 insertions(+), 259 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_reorder_workgroups_static.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_reorder_workgroups_static.mlir index 1b7a99184dcb..992dc8ec4435 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_reorder_workgroups_static.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_reorder_workgroups_static.mlir @@ -25,7 +25,7 @@ ]> hal.executable private @main_dispatch_0 { hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) { - hal.executable.export public @main_dispatch_0_matmul_transpose_b_32000x32000x4096_f16 
ordinal(0) layout(#pipeline_layout) attributes {subgroup_size = 64 : index, translation_info = #iree_codegen.translation_info, workgroup_size = [64 : index, 16 : index, 1 : index]} { + hal.executable.export public @main_dispatch_0_matmul_transpose_b_32000x32000x4096_f16 ordinal(0) layout(#pipeline_layout) attributes {subgroup_size = 64 : index, translation_info = #iree_codegen.translation_info, workgroup_size = [64 : index, 16 : index, 1 : index]} { ^bb0(%arg0: !hal.device): %c250 = arith.constant 250 : index %c500 = arith.constant 500 : index diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td index 26b37dd07e24..e5c6f6f649cd 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td @@ -40,26 +40,24 @@ def LLVMGPU_SimpleDistribute : I32EnumAttrCase<"LLVMGPUDistribute", 102>; def LLVMGPU_Vectorize : I32EnumAttrCase<"LLVMGPUVectorize", 103>; -def LLVMGPU_MatmulSimt - : I32EnumAttrCase<"LLVMGPUMatmulSimt", 104>; def LLVMGPU_MatmulTensorCore - : I32EnumAttrCase<"LLVMGPUMatmulTensorCore", 105>; + : I32EnumAttrCase<"LLVMGPUMatmulTensorCore", 104>; def LLVMGPU_TransposeSharedMem - : I32EnumAttrCase<"LLVMGPUTransposeSharedMem", 106>; + : I32EnumAttrCase<"LLVMGPUTransposeSharedMem", 105>; def LLVMGPU_WarpReduction - : I32EnumAttrCase<"LLVMGPUWarpReduction", 107>; + : I32EnumAttrCase<"LLVMGPUWarpReduction", 106>; def LLVMGPU_PackUnPack - : I32EnumAttrCase<"LLVMGPUPackUnPack", 108>; + : I32EnumAttrCase<"LLVMGPUPackUnPack", 107>; def LLVMGPU_MatmulTensorCoreMmaSync - : I32EnumAttrCase<"LLVMGPUMatmulTensorCoreMmaSync", 109>; + : I32EnumAttrCase<"LLVMGPUMatmulTensorCoreMmaSync", 108>; def LLVMGPU_VectorDistribute - : I32EnumAttrCase<"LLVMGPUVectorDistribute", 110>; + : I32EnumAttrCase<"LLVMGPUVectorDistribute", 109>; def LLVMGPU_PadAndVectorDistribute - : 
I32EnumAttrCase<"LLVMGPUPadAndVectorDistribute", 111>; + : I32EnumAttrCase<"LLVMGPUPadAndVectorDistribute", 110>; def LLVMGPU_WinogradVectorize - : I32EnumAttrCase<"LLVMGPUWinogradVectorize", 112>; + : I32EnumAttrCase<"LLVMGPUWinogradVectorize", 111>; def LLVMGPU_TileAndFuse - : I32EnumAttrCase<"LLVMGPUTileAndFuse", 113>; + : I32EnumAttrCase<"LLVMGPUTileAndFuse", 112>; def SPIRV_BaseLowering : I32EnumAttrCase<"SPIRVBaseLowering", 200>; @@ -98,7 +96,7 @@ def DispatchLoweringPassPipelineEnum : I32EnumAttr< // LLVMGPU CodeGen pipelines LLVMGPU_Default, LLVMGPU_BaseLowering, LLVMGPU_SimpleDistribute, - LLVMGPU_Vectorize, LLVMGPU_MatmulSimt, LLVMGPU_MatmulTensorCore, + LLVMGPU_Vectorize, LLVMGPU_MatmulTensorCore, LLVMGPU_TransposeSharedMem, LLVMGPU_WarpReduction, LLVMGPU_PackUnPack, LLVMGPU_MatmulTensorCoreMmaSync, LLVMGPU_VectorDistribute, LLVMGPU_PadAndVectorDistribute, LLVMGPU_WinogradVectorize, diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index fc890d1db70d..808d35644baf 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -1295,9 +1295,11 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target, CodeGenPipeline pipeline) { TileSizesListType tileSizes; unsigned numParallelLoops = op.getNumParallelLoops(); - SmallVector workgroupTileSizes(numParallelLoops - 2, 1); - workgroupTileSizes.append({tileX, tileY}); - workgroupTileSizes.append(op.getNumReductionLoops(), tileK); + unsigned numReductionLoops = op.getNumReductionLoops(); + SmallVector workgroupTileSizes( + numParallelLoops + numReductionLoops, 1); + workgroupTileSizes[numParallelLoops - 2] = tileX; + workgroupTileSizes[numParallelLoops - 1] = tileY; SmallVector partitionedLoops = cast(op.getOperation()) @@ -1311,11 +1313,63 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target, } } - 
tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level. std::optional subgroupSize = std::nullopt; if (!subgroupSizes.empty()) subgroupSize = subgroupSizes.front(); + // For the LLVMGPUTileAndFuse pipeline, we need to split tile sizes + // for workgroup, thread, and reduction. + if (pipeline == CodeGenPipeline::LLVMGPUTileAndFuse) { + + auto context = op.getContext(); + Builder b(context); + SmallVector attrs; + + SmallVector threadTileSizes(numParallelLoops + numReductionLoops, + 0); + std::fill(threadTileSizes.begin(), + threadTileSizes.begin() + numParallelLoops, 1); + + threadTileSizes[numParallelLoops - 2] = + (tileX / workgroupSize[0]) < 1 ? 1 : (tileX / workgroupSize[0]); + threadTileSizes[numParallelLoops - 1] = + (tileY / workgroupSize[1]) < 1 ? 1 : (tileY / workgroupSize[1]); + + SmallVector reductionTileSizes( + numParallelLoops + numReductionLoops, 0); + reductionTileSizes[numParallelLoops + numReductionLoops - 1] = tileK; + + attrs.emplace_back(b.getStringAttr("workgroup"), + b.getI64ArrayAttr(workgroupTileSizes)); + attrs.emplace_back(b.getStringAttr("thread"), + b.getI64ArrayAttr(threadTileSizes)); + attrs.emplace_back(b.getStringAttr("reduction"), + b.getI64ArrayAttr(reductionTileSizes)); + + auto configDict = b.getDictionaryAttr(attrs); + auto loweringConfig = + IREE::GPU::LoweringConfigAttr::get(context, configDict); + SmallVector pipelineAttrs; + auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get( + context, /*prefetchSharedMemory=*/false, + /*no_reduce_shared_memory_bank_conflicts=*/true, + /*use_igemm_convolution=*/false, + /*reorder_workgroups_strategy=*/std::nullopt); + pipelineAttrs.emplace_back( + b.getStringAttr(IREE::GPU::GPUPipelineOptionsAttr::getDictKeyName()), + pipelineOptions); + auto pipelineConfig = b.getDictionaryAttr(pipelineAttrs); + + return setOpConfigAndEntryPointFnTranslation( + entryPoint, op, loweringConfig, pipeline, workgroupSize, subgroupSize, + pipelineConfig); + } + + // Other pipeline 
(MatmulTensorCore) expect the reduction tile size to be in + // the same list. + workgroupTileSizes[numParallelLoops + numReductionLoops - 1] = tileK; + tileSizes.emplace_back(std::move(workgroupTileSizes)); + return setOpConfigAndEntryPointFnTranslation( entryPoint, op, tileSizes, pipeline, workgroupSize, subgroupSize, getSoftwarePipeliningAttrDict(op->getContext(), softwarePipelineDepth, @@ -1390,7 +1444,7 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target, return setMatmulConfig( sizeN, sizeM, 4, {sizeM, sizeN, 1}, target.getWgp().getSubgroupSizeChoices().asArrayRef(), - softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUMatmulSimt); + softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUTileAndFuse); } // SIMT matmul case. Query the best configuration. @@ -1404,7 +1458,7 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target, config.tileSize[0], config.tileSize[1], config.tileSize[2], config.workgroupSize, target.getWgp().getSubgroupSizeChoices().asArrayRef(), - softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUMatmulSimt); + softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUTileAndFuse); } } } @@ -1429,7 +1483,7 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target, return setMatmulConfig(tileX, tileY, tileK, workgroupSize, target.getWgp().getSubgroupSizeChoices().asArrayRef(), softwarePipelineDepthSimt, - CodeGenPipeline::LLVMGPUMatmulSimt); + CodeGenPipeline::LLVMGPUTileAndFuse); } //====---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp index 73688d2b92d5..1773e229c284 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp @@ -114,9 +114,6 @@ void LLVMGPULowerExecutableTargetPass::runOnOperation() { case 
IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUWinogradVectorize: addGPUWinogradVectorizePassPipeline(pipeline); break; - case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt: - addGPUMatmulSimtPassPipeline(pipeline, pipelineOptions); - break; case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulTensorCore: { FailureOr maybeDepth = getSoftwarePipelineDepth(translationInfo.getConfiguration()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 1debcf3bc205..d460a1b9f56b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -526,72 +526,6 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) { funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass()); } -//===---------------------------------------------------------------------===// -// MatmulSIMT -//===---------------------------------------------------------------------===// - -void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager, - const GPUPipelineOptions &options) { - tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); - - funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); - funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); - funcPassManager.addPass(createCSEPass()); - - funcPassManager.addPass(createGPUTensorTileToSerialLoopsPass()); - funcPassManager.addPass(createGPUTensorAlloc()); - funcPassManager.addPass(createGPUTensorTilePass()); - - // Linalg -> vector - addGPUVectorizationPasses(funcPassManager); - - // tensor to memref - addBufferizePasses(funcPassManager); - - // distribute foreach threads - funcPassManager.addPass(createGPUDistributePass()); - - funcPassManager.addPass(createMemrefCopyToLinalgPass()); - funcPassManager.addPass(createGPUDistributeSharedMemoryCopyPass()); - funcPassManager.addPass(createCanonicalizerPass()); - 
funcPassManager.addPass(createCSEPass()); - - if (options.enableReduceSharedMemoryBankConflicts) { - funcPassManager.addPass(createGPUReduceBankConflictsPass()); - } - - ReorderWorkgroupsStrategy reorderStrategy = - getReorderWorkgroupsStrategy(options.reorderStrategy); - funcPassManager.addPass( - createReorderWorkgroups(reorderStrategy, canReorderWorkgroups)); - - funcPassManager.addPass(createCanonicalizerPass()); - funcPassManager.addPass(createCSEPass()); - - funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass()); - funcPassManager.addPass(createCSEPass()); - funcPassManager.addPass(createCanonicalizerPass()); - funcPassManager.addPass(createCSEPass()); - - // Even though we vectorize before bufferization we are not able to hoist - // accumulator load/store out of the K loop until distribution. This is - // because we materialize the fill and the matmul in two different scf.forall - // regions, when they should be in the same scf.forall. Newer pipelines - // like TileAndFuse don't have this problem, because they coalesce these - // scf.forall regions into a single scf.forall. - // - // Therefore we still rely on buffer level transformations for transfer ops - // hoisting and store to load forwarding. This relies on shacky alias - // analysis and we need to move this to tensor level once we have better - // abstractions. - funcPassManager.addPass(createOptimizeVectorTransferPass()); - - // Hoist loop invariant code to avoid pipelining it. - funcPassManager.addPass(createIREELoopInvariantCodeMotionPass()); - // Pipeline memory operations. 
- funcPassManager.addPass(createGPUPipeliningPass()); -} - //===---------------------------------------------------------------------===// // Matmul Tensor Core //===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h index caacfb2656e3..17b7b866be11 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h @@ -28,10 +28,6 @@ using IREE::GPU::GPUPipelineOptions; // LLVMGPU Backend Pass Pipelines //----------------------------------------------------------------------------// -/// Lowering using SIMT CUDA core operations. -void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager, - const GPUPipelineOptions &options); - /// Lowering using mma.sync Tensor Core operations. void addGPUMatmulTensorCoreMmaSyncPassPipeline( OpPassManager &funcPassManager, const GPUPipelineOptions &options, diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp index f2e3e2da4e3f..bab5de877eb3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp @@ -38,10 +38,6 @@ getInstructionShape(Operation *op, CodeGenPipeline pipeline, Type inputElementType, SmallVector &instructionShape) { switch (pipeline) { - case CodeGenPipeline::LLVMGPUMatmulSimt: - // SIMT Pipeline / CUDA Cores - instructionShape = {1, 1, 1}; - break; case CodeGenPipeline::LLVMGPUMatmulTensorCore: // Tensor Core Pipeline / WMMA API if (inputElementType.isF16() || inputElementType.isBF16()) { @@ -81,8 +77,7 @@ verifyGPUMatmulPipeline(Operation *op, ArrayRef workgroupSize) { // This verifier only applies to matmul. 
CodeGenPipeline pipeline = translationInfo.getDispatchLoweringPassPipeline(); - if (pipeline != CodeGenPipeline::LLVMGPUMatmulSimt && - pipeline != CodeGenPipeline::LLVMGPUMatmulTensorCore && + if (pipeline != CodeGenPipeline::LLVMGPUMatmulTensorCore && pipeline != CodeGenPipeline::LLVMGPUMatmulTensorCoreMmaSync) { return success(); } @@ -180,10 +175,6 @@ verifyGPUMatmulPipeline(Operation *op, << pipelineName; } - // Return success for SIMT/CUDA cores. - if (pipeline == CodeGenPipeline::LLVMGPUMatmulSimt) - return success(); - // // Additional verification Tensor Core pipelines. // diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir index 1e5dbf63f2f9..3a029f2968d0 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir @@ -267,12 +267,11 @@ func.func @not_vmt() { return } -// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config -// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info}> // CHECK: func.func @not_vmt() // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK: linalg.generic -// CHECK-SAME: lowering_config = #[[$CONFIG]] +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 8], thread = [1, 128, 0], workgroup = [1, 128, 1]}> // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_root_op_attribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_root_op_attribute.mlir index f3e0d81fb961..3c7e52aa475a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_root_op_attribute.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_root_op_attribute.mlir @@ -9,4 +9,4 @@ func.func @matmul(%lhs: tensor<4x4xf32>, %rhs: tensor<4x4xf32>) -> tensor<4x4xf3 return %result : tensor<4x4xf32> } -// CHECK: %2 = 
linalg.matmul {lowering_config = #config, root_op} ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK: %2 = linalg.matmul {lowering_config = #{{.*}}, root_op} ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir index cd69906aec13..cec55cdaf0a5 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir @@ -9,7 +9,7 @@ #map = affine_map<()[s0] -> (s0 * 2)> #map1 = affine_map<()[s0] -> (s0 * 256)> #map2 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)> -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info func.func @dot_dispatch_0() attributes {translation_info = #translation} { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index @@ -79,7 +79,7 @@ func.func @dot_dispatch_0() attributes {translation_info = #translation} { #map2 = affine_map<(d0, d1, d2)[s0] -> (d0 * 32768 + s0 + d1 * 1024 + d2)> #map3 = affine_map<(d0, d1, d2)[s0] -> (d0 * 65536 + s0 + d1 * 64 + d2)> #map4 = affine_map<(d0, d1, d2)[s0] -> (d0 * 2048 + s0 + d1 * 64 + d2)> -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info func.func @batch_matmul_func() attributes {translation_info = #translation} { %c0 = arith.constant 0 : index %cst = arith.constant 0.000000e+00 : f32 @@ -148,7 +148,7 @@ func.func @batch_matmul_func() attributes {translation_info = #translation} { #map = affine_map<()[s0] -> (s0 * 2)> #map1 = affine_map<()[s0] -> (s0 * 32)> #map2 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)> -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info func.func @dot_dispatch_0() attributes 
{translation_info = #translation} { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index @@ -312,7 +312,7 @@ module { #hal.pipeline.binding ]> #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #map = affine_map<()[s0] -> (s0 * 2)> #map1 = affine_map<()[s0] -> (s0 * 256)> #map2 = affine_map<(d0)[s0] -> (-d0 + s0, 2)> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir index 66fc62f2e482..642c6ed1a179 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir @@ -54,14 +54,12 @@ func.func @dot_dispatch_1() { return } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info}> // CHECK: func.func @dot_dispatch_1 // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK: linalg.fill -// CHECK-SAME: lowering_config = #[[CONFIG]] // CHECK: linalg.matmul -// CHECK-SAME: lowering_config = #[[CONFIG]] +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 4], thread = [2, 1, 0], workgroup = [4, 2, 1]}> // ----- @@ -83,14 +81,12 @@ func.func @unaligned_k() { return } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info}> // CHECK: func.func @unaligned_k // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK: linalg.fill -// CHECK-SAME: lowering_config = #[[CONFIG]] // CHECK: linalg.matmul -// CHECK-SAME: lowering_config = #[[CONFIG]] +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 2], thread = [1, 16, 0], workgroup = 
[32, 128, 1]}> // ----- @@ -123,7 +119,6 @@ func.func @predict_dispatch_153() { // CHECK: func.func @predict_dispatch_153() // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK: linalg.fill -// CHECK-SAME: lowering_config = #[[CONFIG]] // CHECK: linalg.generic // CHECK-SAME: lowering_config = #[[CONFIG]] @@ -254,7 +249,7 @@ func.func @static_3d_fft_stage3() { #hal.pipeline.binding ]> #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #compilation = #iree_codegen.compilation_info func.func @_lowering_config_test_dispatch_1() { %cst = arith.constant 0.000000e+00 : f32 @@ -274,11 +269,10 @@ func.func @_lowering_config_test_dispatch_1() { } // CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: func.func @_lowering_config_test_dispatch_1() // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK: linalg.fill -// CHECK-SAME: lowering_config = #[[CONFIG]] // CHECK: linalg.matmul // CHECK-SAME: lowering_config = #[[CONFIG]] @@ -341,7 +335,7 @@ func.func @matmul_config_sm35() { return } -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info}> // CHECK: func.func @matmul_config_sm35() // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -501,7 +495,6 @@ func.func @large_matmul_f16() { // SM80: func.func @large_matmul_f16() // SM80-SAME: translation_info = #[[TRANSLATION]] // SM80: linalg.fill -// SM80-SAME: lowering_config = #[[CONFIG]] // SM80: linalg.matmul // SM80-SAME: lowering_config = #[[CONFIG]] @@ -534,7 +527,6 @@ func.func @large_matmul_f32() { // SM80: func.func @large_matmul_f32() // SM80-SAME: translation_info = #[[TRANSLATION]] // SM80: linalg.fill -// SM80-SAME: lowering_config = #[[CONFIG]] // SM80: linalg.matmul // SM80-SAME: lowering_config = #[[CONFIG]] @@ -659,14 +651,12 @@ func.func 
@_main_dispatch_15_generic_512x4x42x42x64_f32() { return } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info}> // CHECK: func.func @_main_dispatch_15_generic_512x4x42x42x64_f32() // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK: linalg.fill -// CHECK-SAME: lowering_config = #[[CONFIG]] // CHECK: linalg.generic -// CHECK-SAME: lowering_config = #[[CONFIG]] +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 0, 0, 32], thread = [1, 1, 1, 16, 0], workgroup = [1, 1, 32, 128, 1]}> // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/illegal_configuration.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/illegal_configuration.mlir index 2c3df44b325b..8dccac1fb4a6 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/illegal_configuration.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/illegal_configuration.mlir @@ -1,43 +1,5 @@ // RUN: iree-opt --iree-gpu-test-target=sm_60 --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" --verify-diagnostics --split-input-file %s -#pipeline_layout = #hal.pipeline.layout, - #hal.pipeline.binding, - #hal.pipeline.binding -]> -#config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info -func.func @illegal() attributes {translation_info = #translation} { - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<4x8xf32> - %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : memref<8x16xf32> - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref<4x16xf32> - // expected-error @+1 {{Total number of threads in a thread block 2048 exceeds the limit of 1024 with compilation pipeline LLVMGPUMatmulSimt}} - linalg.matmul {lowering_config = #config} ins(%0, %1 : memref<4x8xf32>, memref<8x16xf32>) outs(%2 : memref<4x16xf32>) - return -} - -// ----- 
- -#pipeline_layout = #hal.pipeline.layout, - #hal.pipeline.binding, - #hal.pipeline.binding -]> -#config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info -func.func @illegal() attributes {translation_info = #translation} { - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<4x8xf32> - %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : memref<8x16xf32> - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref<4x16xf32> - // expected-error @+1 {{Expected workgroup size in z-dim = 1, but got 2 with compilation pipeline LLVMGPUMatmulSimt}} - linalg.matmul {lowering_config = #config} ins(%0, %1 : memref<4x8xf32>, memref<8x16xf32>) outs(%2 : memref<4x16xf32>) - return -} - -// ----- - #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir index 9cb3fed6254c..ad6aad32420c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir @@ -83,20 +83,14 @@ hal.executable @dot_dispatch_0 { } } -// CHECK-LABEL: hal.executable public @dot_dispatch_0 -// CHECK: hal.executable.variant public @cuda -// CHECK-NOT: llvm.store -// CHECK-COUNT-3: llvm.load {{.*}} : !llvm.ptr<1> -> vector<4xf32> -// CHECK: llvm.br -// CHECK-COUNT-3: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr<3> -// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf32> -// CHECK-COUNT-128: llvm.intr.fmuladd({{.*}}) : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32> -// CHECK-COUNT-3: llvm.load {{.*}} : !llvm.ptr<1> -> vector<4xf32> -// CHECK: llvm.br -// CHECK-COUNT-3: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr<3> -// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf32> -// 
CHECK-COUNT-128: llvm.intr.fmuladd({{.*}}) : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32> -// CHECK-COUNT-4: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr<1> +// CHECK-LABEL: hal.executable public @dot_dispatch_0 +// CHECK: hal.executable.variant public @cuda +// CHECK-NOT: llvm.store +// CHECK: llvm.br +// CHECK: llvm.load {{.*}} : !llvm.ptr<1> -> vector<32xf32> +// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<1> -> vector<16xf32> +// CHECK-COUNT-32: llvm.intr.fmuladd({{.*}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32> +// CHECK: llvm.store {{.*}} : vector<16xf32>, !llvm.ptr<1> // ----- @@ -158,11 +152,10 @@ hal.executable @dot_dispatch_0 { } // CHECK-LABEL: hal.executable public @dot_dispatch_0 -// CHECK: hal.executable.variant public @cuda -// CHECK: llvm.br -// CHECK-COUNT-8: llvm.intr.fmuladd({{.*}}) : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32> -// CHECK: llvm.br -// CHECK-COUNT-2: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr<1> +// CHECK: hal.executable.variant public @cuda +// CHECK: llvm.br +// CHECK-COUNT-32: llvm.intr.fmuladd({{.*}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32> +// CHECK: llvm.store {{.*}} : vector<16xf32>, !llvm.ptr<1> // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/rocdl_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/rocdl_pipeline_test.mlir index 578d28b027b5..2e7cd879d328 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/rocdl_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/rocdl_pipeline_test.mlir @@ -87,17 +87,12 @@ hal.executable @dot_dispatch_0 { // RDNA3-LABEL: hal.executable public @dot_dispatch_0 // RDNA3: hal.executable.variant public @rocm // RDNA3-NOT: llvm.store -// RDNA3-COUNT-3: llvm.load {{.*}} : !llvm.ptr<1> -> vector<4xf32> // RDNA3: llvm.br -// RDNA3-COUNT-3: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr<3> -// RDNA3-COUNT-32: llvm.load {{.*}} : 
!llvm.ptr<3> -> vector<4xf32> -// RDNA3-COUNT-128: llvm.intr.fmuladd({{.*}}) : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32> -// RDNA3-COUNT-3: llvm.load {{.*}} : !llvm.ptr<1> -> vector<4xf32> +// RDNA3-COUNT-1: llvm.load {{.*}} : !llvm.ptr<1> -> vector<32xf32> +// RDNA3-COUNT-32: llvm.load {{.*}} : !llvm.ptr<1> -> vector<16xf32> +// RDNA3-COUNT-32: llvm.intr.fmuladd({{.*}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32> +// RDNA3-COUNT-1: llvm.store {{.*}} : vector<16xf32>, !llvm.ptr<1> // RDNA3: llvm.br -// RDNA3-COUNT-3: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr<3> -// RDNA3-COUNT-32: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf32> -// RDNA3-COUNT-128: llvm.intr.fmuladd({{.*}}) : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32> -// RDNA3-COUNT-4: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr<1> // ----- diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel index 0bad5e06eef7..8ffe93c0ffac 100644 --- a/tests/e2e/matmul/BUILD.bazel +++ b/tests/e2e/matmul/BUILD.bazel @@ -385,30 +385,6 @@ X86_64_AVX512_BF16 = X86_64_AVX512 + [ ## ########################################################################### -iree_generated_e2e_runner_test( - name = "e2e_matmul_cuda_f32_large_simt", - generator = ":generate_e2e_matmul_tests", - generator_args = [ - "--lhs_rhs_type=f32", - "--acc_type=f32", - "--shapes=easy_large_static", - "--compilation_info=LLVMGPUMatmulSimt", - ], - tags = [ - # CUDA cuInit fails with sanitizer on. - "noasan", - "nomsan", - "notsan", - "noubsan", - "requires-gpu-nvidia", - ], - target_backends_and_drivers = [ - ("cuda", "cuda"), - ], - test_runner = "//tools/testing/e2e:iree-e2e-matmul-test", - test_type = "matmul", -) - # Testing Ampere + TensorCore path. 
# WMMA TensorCore(F32): wmma.161616.f32.tf32 iree_generated_e2e_runner_test( diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt index 9e7ec415b564..b744d346ebef 100644 --- a/tests/e2e/matmul/CMakeLists.txt +++ b/tests/e2e/matmul/CMakeLists.txt @@ -1016,32 +1016,6 @@ iree_generated_e2e_runner_test( "--iree-opt-data-tiling" ) -iree_generated_e2e_runner_test( - NAME - e2e_matmul_cuda_f32_large_simt - TEST_TYPE - matmul - GENERATOR - "generate_e2e_matmul_tests.py" - GENERATOR_ARGS - "--lhs_rhs_type=f32" - "--acc_type=f32" - "--shapes=easy_large_static" - "--compilation_info=LLVMGPUMatmulSimt" - TEST_RUNNER - iree_tools_testing_e2e_iree-e2e-matmul-test - TARGET_BACKENDS - "cuda" - DRIVERS - "cuda" - LABELS - "noasan" - "nomsan" - "notsan" - "noubsan" - "requires-gpu-nvidia" -) - iree_generated_e2e_runner_test( NAME e2e_matmul_cuda_f32_large_tensorcore diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py index a97b5626c069..3061fb620af0 100644 --- a/tests/e2e/matmul/generate_e2e_matmul_tests.py +++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py @@ -50,7 +50,6 @@ class ShapesId(enum.Enum): @enum.unique class CompilationInfoId(enum.Enum): NONE = "" - LLVMGPUMatmulSimt = "LLVMGPUMatmulSimt" LLVMGPUMatmulTensorCore = "LLVMGPUMatmulTensorCore" LLVMGPUMatmulTensorCoreMmaSync = "LLVMGPUMatmulTensorCoreMmaSync" LLVMGPUVectorDistributeMFMA = "LLVMGPUVectorDistributeMFMA" @@ -461,18 +460,7 @@ def get_test_compilation_infos( software_pipeline_depth = 0 tile_workgroup_size_pairs = [] - if compilation_info_id == CompilationInfoId.LLVMGPUMatmulSimt: - tile_workgroup_size_pairs = [ - TileWorkgroupSizePair([[32, 128, 32]], [32, 8, 1]), - TileWorkgroupSizePair([[128, 64, 8]], [16, 8, 1]), - TileWorkgroupSizePair([[16, 256, 32]], [64, 2, 1]), - TileWorkgroupSizePair([[8, 32, 32]], [8, 8, 1]), - TileWorkgroupSizePair([[8, 128, 4]], [32, 1, 1]), - TileWorkgroupSizePair([[16, 64, 4]], [16, 2, 1]), - 
TileWorkgroupSizePair([[1, 128, 8]], [32, 1, 1]), - ] - software_pipeline_depth = 3 - elif compilation_info_id == CompilationInfoId.SPIRVCooperativeMatrixVectorize: + if compilation_info_id == CompilationInfoId.SPIRVCooperativeMatrixVectorize: tile_workgroup_size_pairs = [ TileWorkgroupSizePair( [[64, 128], [32, 64], [0, 0, 32], [16, 16, 16]], [64, 2, 1]