diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
index a4817d995b66..6a8b50369656 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
@@ -20,6 +20,52 @@ namespace mlir::iree_compiler {
 namespace {
 static constexpr int64_t kCudaWarpSize = 32;
 
+static void
+replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc, Block *parent,
+                            Value replacement,
+                            ArrayRef<int64_t> availableMappingSizes) {
+  parent->walk([&](gpu::ThreadIdOp idOp) {
+    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
+      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
+  });
+}
+
+// This is an upstream method adapted from
+// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp#L846
+// to fix the ASAN error.
+DiagnosedSilenceableFailure static mapNestedForallToThreadsImpl(
+    RewriterBase &rewriter, Operation *target, ArrayRef<int64_t> blockDims,
+    int64_t warpSize, bool syncAfterDistribute) {
+
+  if (blockDims.size() != 3) {
+    return emitDefiniteFailure(target, "requires size-3 thread mapping");
+  }
+
+  Block *parentBlock = target->getBlock();
+
+  // Create an early zero index value for replacements.
+  Location loc = target->getLoc();
+  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
+  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
+    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
+        rewriter, std::nullopt, forallOp, blockDims, warpSize,
+        syncAfterDistribute);
+    if (diag.isDefiniteFailure())
+      return WalkResult::interrupt();
+    if (diag.succeeded())
+      return WalkResult::skip();
+    return WalkResult::advance();
+  });
+  if (walkResult.wasInterrupted())
+    return diag;
+
+  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
+  // Here, the result of mapping determines the available mapping sizes.
+  replaceUnitMappingIdsHelper(rewriter, loc, parentBlock, zero, blockDims);
+  return DiagnosedSilenceableFailure::success();
+}
+
 struct GPUDistributePass final
     : impl::GPUDistributePassBase<GPUDistributePass> {
   void runOnOperation() override {
@@ -41,11 +87,24 @@ struct GPUDistributePass final
     int64_t subgroupSize = maybeSubgroupSize.value_or(kCudaWarpSize);
 
     rewriter.setInsertionPointToStart(&funcOp.front());
-    DiagnosedSilenceableFailure result =
-        mlir::transform::gpu::mapNestedForallToThreadsImpl(
-            rewriter, std::nullopt, funcOp, workgroupSize.value(), subgroupSize,
-            false);
-    if (!result.succeeded())
+
+    DiagnosedSilenceableFailure result = DiagnosedSilenceableFailure::success();
+    WalkResult walkResult = funcOp->walk([&](scf::ForallOp forallOp) {
+      bool hasWorkgroupMapping =
+          llvm::any_of(forallOp.getMapping().value(),
+                       llvm::IsaPred<IREE::Codegen::WorkgroupMappingAttr>);
+      if (!hasWorkgroupMapping) {
+        result = mapNestedForallToThreadsImpl(
+            rewriter, forallOp, workgroupSize.value(), subgroupSize, false);
+        if (result.isDefiniteFailure())
+          return WalkResult::interrupt();
+        if (result.succeeded())
+          return WalkResult::skip();
+      }
+      return WalkResult::advance();
+    });
+
+    if (walkResult.wasInterrupted())
       return signalPassFailure();
   }
 };
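Both walks in this pass follow the same WalkResult protocol: a definite failure interrupts the traversal, a successfully mapped scf.forall is skipped so its freshly rewritten body is not revisited, and everything else advances. Below is a minimal standalone sketch of that protocol, assuming MLIR headers are available and with a hypothetical mapOne standing in for mapOneForallToThreadsImpl:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Visitors.h"

// Hypothetical stand-in for the real per-forall mapping routine.
static bool mapOne(mlir::scf::ForallOp forallOp);

static bool mapAllThreadForalls(mlir::Operation *root) {
  mlir::WalkResult result = root->walk([&](mlir::scf::ForallOp forallOp) {
    if (!mapOne(forallOp))
      return mlir::WalkResult::interrupt(); // definite failure: abort the walk
    // Mapped successfully: skip the region that was just rewritten.
    return mlir::WalkResult::skip();
  });
  return !result.wasInterrupted();
}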
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 3a29a8fff901..35d9f280f4e0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -204,7 +204,7 @@ static void tileAndDistributeToWorkgroup(
 static void tileAndBufferize(OpPassManager &funcPassManager) {
   ConvertToDestinationPassingStylePassOptions options;
   options.useWARForCooperativeMatrixCodegen = true;
-  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false,
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true,
                                /*convertToDpsOptions=*/options);
   addBufferizePasses(funcPassManager);
 }
@@ -487,7 +487,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
 //===---------------------------------------------------------------------===//
 
 void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
-  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true);
 
   funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
   funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
@@ -524,7 +524,7 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
 
 void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager,
                                   const GPUPipelineOptions &options) {
-  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true);
 
   funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
   funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
@@ -725,7 +725,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline(
 
 void addGPUTransposePassPipeline(OpPassManager &funcPassManager,
                                  const GPUPipelineOptions &options) {
-  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true);
 
   funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
   funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
@@ -969,7 +969,7 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
 }
 
 void addGPUPackUnPackPasses(OpPassManager &funcPassManager) {
-  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true);
   funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
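With useForall=true, workgroup-level tile-and-distribute now emits scf.forall ops carrying workgroup mapping attributes, so GPUDistributePass has to tell them apart from the thread-level foralls it is responsible for. A sketch of that predicate, assuming the attribute is IREE's IREE::Codegen::WorkgroupMappingAttr and the usual IREE header layout:

#include <optional>

#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
// Assumed header for WorkgroupMappingAttr; the exact path may differ.
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"

namespace mlir::iree_compiler {

// Foralls created by tileAndDistributeToWorkgroup(/*useForall=*/true) map to
// workgroups; GPUDistribute should only map the remaining thread-level ones.
static bool isWorkgroupForall(scf::ForallOp forallOp) {
  std::optional<ArrayAttr> mapping = forallOp.getMapping();
  return mapping.has_value() &&
         llvm::any_of(mapping.value(),
                      llvm::IsaPred<IREE::Codegen::WorkgroupMappingAttr>);
}

} // namespace mlir::iree_compiler

Unlike the walk in the pass, which calls forallOp.getMapping().value() unconditionally, the sketch guards the optional; a forall with no mapping attribute at all would otherwise assert.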
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir
index 28ae306f1521..c0cd53377863 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir
@@ -79,7 +79,7 @@ hal.executable @mma_fused_fp16 {
 // CHECK: llvm.br
 // CHECK-NOT: nvvm.mma.sync
 // CHECK-COUNT-4: llvm.store {{.*}} : vector<2xf16>, !llvm.ptr<3>
-// CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
+// CHECK: llvm.load {{.*}} : !llvm.ptr<1> -> vector<16xf16>
 // CHECK: llvm.store {{.*}} : vector<8xf16>, !llvm.ptr
 
 // -----
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
index 9bb0e907b8ab..9cb3fed6254c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
@@ -462,14 +462,7 @@ hal.executable @mma_fused {
 // SM80: nvvm.cp.async.commit.group
 // SM80: llvm.br
 // SM80-NOT: nvvm.wmma.mma
-// SM80-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<3>, f32, f32, f32, f32, f32, f32, f32, f32
-// SM80: vvm.barrier0
-// SM80: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf32>
-// SM80: llvm.fadd {{.*}} : vector<4xf32>
-// SM80: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr<1>
-// SM80: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf32>
-// SM80: llvm.fadd {{.*}} : vector<4xf32>
-// SM80: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr<1>
+// SM80-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<1>, f32, f32, f32, f32, f32, f32, f32, f32
@@ -547,12 +540,7 @@ hal.executable @mma_fused_fp16 {
 // SM80: nvvm.cp.async.commit.group
 // SM80: llvm.br
 // SM80-NOT: nvvm.wmma.mma
-// SM80-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
-// SM80: vvm.barrier0
-// SM80: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
-// SM80: llvm.fadd {{.*}} : vector<8xf16>
-// SM80: llvm.store {{.*}} : vector<8xf16>, !llvm.ptr<1>
-// SM80: vvm.barrier0
+// SM80-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<1>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
 
 // -----
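The test churn above is the payoff of the distribution change: the elementwise consumer fused with the matmul no longer round-trips through shared memory, so the store/barrier/load/fadd/store sequence against !llvm.ptr<3> collapses into a single wmma or vector store straight to !llvm.ptr<1>. For reading the CHECK lines, the pointer address spaces follow the NVVM numbering; a quick key, worth re-verifying against MLIR's NVVMMemorySpace enum in NVVMDialect.h:

// NVVM address spaces as they appear in !llvm.ptr<N> (assumed values,
// cf. mlir/Dialect/LLVMIR/NVVMDialect.h).
enum NVVMAddressSpace : unsigned {
  kGeneric = 0,  // !llvm.ptr    : generic
  kGlobal = 1,   // !llvm.ptr<1> : device global memory
  kShared = 3,   // !llvm.ptr<3> : per-CTA shared memory
  kConstant = 4, // !llvm.ptr<4> : constant memory
  kLocal = 5,    // !llvm.ptr<5> : per-thread local memory
};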
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
index 8aa87740b057..23c3977c8389 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
@@ -53,7 +53,7 @@ hal.executable @transpose_dispatch_0 {
 // CHECK: %[[D11:.*]] = affine.apply #{{.*}}()[%[[D0]]]
 // CHECK: %[[D12:.*]] = vector.transfer_read %[[D3]][%[[D11]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
 // CHECK: %[[D13:.*]] = vector.shape_cast %[[D12]] : vector<4x1xf32> to vector<4xf32>
-// CHECK: %[[D15:.*]] = affine.apply #{{.*}}()[%[[D1]], %{{.*}}]
+// CHECK: %[[D15:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D1]]]
 // CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]]
 // CHECK: vector.transfer_write %[[D13]], %[[D5]][%[[D15]], %[[D16]]] {in_bounds = [true]} : vector<4xf32>, memref<4096x4096xf32, #hal.descriptor_type<storage_buffer>>
@@ -116,11 +116,13 @@ hal.executable @transpose_single_operand_dispatch_0_generic_768x2048 {
 // CHECK: gpu.barrier
 // CHECK: %[[D12:.*]] = affine.apply #{{.*}}()[%[[D0]]]
 // CHECK: %[[D13:.*]] = vector.transfer_read %[[D3]][%[[D12]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
-// CHECK: %[[D15:.*]] = affine.apply #{{.*}}()[%[[D1]], %{{.*}}]
-// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]]
-// CHECK: %[[D17:.*]] = vector.transfer_read %[[D5]][%[[D15]], %[[D16]]], %[[CST]] {in_bounds = [true]} : memref<768x2048xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
+// CHECK: %[[DUP_D15:.*]] = arith.addi %[[D1]], %{{.*}} : index
+// CHECK: %[[DUP_D16:.*]] = arith.addi %[[D12]], %{{.*}} : index
+// CHECK: %[[D17:.*]] = vector.transfer_read %[[D5]][%[[DUP_D15]], %[[DUP_D16]]], %[[CST]] {in_bounds = [true]} : memref<768x2048xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
 // CHECK: %[[D14:.*]] = vector.shape_cast %[[D13]] : vector<4x1xf32> to vector<4xf32>
 // CHECK: %[[D19:.*]] = arith.addf %[[D14]], %[[D17]] : vector<4xf32>
+// CHECK: %[[D15:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D1]]]
+// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]]
 // CHECK: vector.transfer_write %[[D19]], %[[D6]][%[[D15]], %[[D16]]] {in_bounds = [true]} : vector<4xf32>, memref<768x2048xf32, #hal.descriptor_type<storage_buffer>>
 
 // -----
@@ -223,11 +225,13 @@ hal.executable @transpose_3d_yes_dispatch_0_generic_10x768x2048 {
 // CHECK: gpu.barrier
 // CHECK: %[[D12:.*]] = affine.apply #{{.*}}()[%[[D0]]]
 // CHECK: %[[D13:.*]] = vector.transfer_read %[[D3]][%[[C0]], %[[D12]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
-// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%[[D1]], %{{.*}}]
-// CHECK: %[[D17:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]]
-// CHECK: %[[D18:.*]] = vector.transfer_read %[[D5]][%{{.*}}, %[[D16]], %[[D17]]], %[[CST]] {in_bounds = [true]} : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
+// CHECK: %[[DUP_D16:.*]] = arith.addi %[[D1]], %{{.*}} : index
+// CHECK: %[[DUP_D17:.*]] = arith.addi %[[D12]], %{{.*}} : index
+// CHECK: %[[D18:.*]] = vector.transfer_read %[[D5]][%{{.*}}, %[[DUP_D16]], %[[DUP_D17]]], %[[CST]] {in_bounds = [true]} : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
 // CHECK: %[[D15:.*]] = vector.shape_cast %[[D13]] : vector<4x1xf32> to vector<4xf32>
 // CHECK: %[[D20:.*]] = arith.addf %[[D15]], %[[D18]] : vector<4xf32>
+// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D1]]]
+// CHECK: %[[D17:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]]
 // CHECK: vector.transfer_write %[[D20]], %[[D6]][%{{.*}}, %[[D16]], %[[D17]]] {in_bounds = [true]} : vector<4xf32>, memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>
 
 // -----
@@ -295,7 +299,7 @@ hal.executable @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 {
 // CHECK: %[[D16:.*]] = vector.transfer_read %[[D3]][%[[C0]], %[[D14]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
 // CHECK: %[[D17:.*]] = arith.addf %[[D15]], %[[D16]] : vector<4x1xf32>
 // CHECK: %[[D19:.*]] = vector.shape_cast %[[D17]] : vector<4x1xf32> to vector<4xf32>
-// CHECK: %[[D21:.*]] = affine.apply #{{.*}}()[%[[D1]], %{{.*}}]
+// CHECK: %[[D21:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D1]]]
 // CHECK: %[[D22:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]]
 // CHECK: vector.transfer_write %[[D19]], %[[D7]][%{{.*}}, %[[D21]], %[[D22]]] {in_bounds = [true]} : vector<4xf32>, memref<10x2048x768xf32, #hal.descriptor_type<storage_buffer>>
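The transpose CHECK updates track a change of IR shape, not of semantics: the forall-based distribution feeds the workgroup id into the affine maps in the opposite symbol order (hence the ()[%{{.*}}, %[[D1]]] patterns), and the indices of the fused read are now materialized as plain arith.addi instead of affine.apply. Both spellings presumably compute the same offset; a hypothetical scalar rendering of the index math (names below are illustrative, not from the test):

// affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%off, %tid]  and
// %i = arith.addi %off, %tid : index  yield the same index value.
int64_t fusedReadIndex(int64_t workgroupOffset, int64_t threadOffset) {
  return workgroupOffset + threadOffset; // matches the arith.addi form
}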