diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index d1da29b4e986..199e24ee4b7d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -129,6 +129,7 @@ iree_compiler_cc_library(
         "PassUtils.cpp",
         "Passes.cpp",
         "PolynomialApproximationPass.cpp",
+        "PropagateReshapesByExpansion.cpp",
         "ReconcileTranslationInfo.cpp",
         "RematerializeParallelOps.cpp",
         "RemoveTrivialLoops.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 69fab8d15e93..8ab009f3ea94 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -120,6 +120,7 @@ iree_cc_library(
     "PassUtils.cpp"
     "Passes.cpp"
     "PolynomialApproximationPass.cpp"
+    "PropagateReshapesByExpansion.cpp"
     "ReconcileTranslationInfo.cpp"
     "RematerializeParallelOps.cpp"
     "RemoveTrivialLoops.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 97bd3c5b867a..d0871c17344c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -89,6 +89,7 @@ iree_compiler_cc_library(
        "//compiler/src/iree/compiler/Codegen/Common:VectorLayoutAnalysis",
        "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+       "//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
        "//compiler/src/iree/compiler/Codegen/Interfaces:PartitionableLoopsInterface",
        "//compiler/src/iree/compiler/Codegen/Transforms",
        "//compiler/src/iree/compiler/Codegen/Utils",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 38dc25f91520..29669e8ead0e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -120,6 +120,7 @@ iree_cc_library(
     iree::compiler::Codegen::Common::VectorLayoutAnalysis
     iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
     iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+    iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
     iree::compiler::Codegen::Interfaces::PartitionableLoopsInterface
     iree::compiler::Codegen::Transforms
     iree::compiler::Codegen::Utils
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
index acd189046c98..1ef74e8f87d6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
@@ -8,7 +8,9 @@
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLForwardCompat.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -65,11 +67,8 @@ collectTiledAndFusedOps(Operation *op,
 static LogicalResult
 applyTileAndFuseToEachRoot(RewriterBase &rewriter,
                            llvm::SmallDenseSet<TilingInterface> &payloadOps,
-                           bool threadTiling) {
+                           IREE::GPU::TilingLevel tilingLevel) {
   MLIRContext *context = rewriter.getContext();
-  unsigned tilingLevel =
-      threadTiling ? static_cast<unsigned>(IREE::GPU::TilingLevel::Thread)
-                   : static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction);
 
   for (TilingInterface tilingInterfaceOp : payloadOps) {
     mlir::DominanceInfo dominanceInfo(tilingInterfaceOp);
@@ -87,7 +86,8 @@ applyTileAndFuseToEachRoot(RewriterBase &rewriter,
     rewriter.setInsertionPoint(tilingInterfaceOp);
     SmallVector<OpFoldResult> tileSizes =
         getLoweringConfig(tilingInterfaceOp)
-            .getTilingLevelSizes(rewriter, tilingLevel, tilingInterfaceOp);
+            .getTilingLevelSizes(rewriter, llvm::to_underlying(tilingLevel),
+                                 tilingInterfaceOp);
 
     // Pad the tile sizes with zero.
     auto zero = rewriter.getIndexAttr(0);
@@ -101,7 +101,8 @@ applyTileAndFuseToEachRoot(RewriterBase &rewriter,
     scf::SCFTilingOptions tilingOptions;
     tilingOptions.setTileSizes(tileSizes);
-    if (threadTiling) {
+    if (tilingLevel == IREE::GPU::TilingLevel::Thread ||
+        tilingLevel == IREE::GPU::TilingLevel::Subgroup) {
       tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
 
       // TODO: Add some helpers to construct this based on the enum type rather
@@ -112,8 +113,14 @@ applyTileAndFuseToEachRoot(RewriterBase &rewriter,
         if (!isConstantIntValue(size, 0)) {
           unsigned mappingId =
               static_cast<unsigned>(gpu::MappingId::LinearDim0) + idx++;
-          mapping.push_back(gpu::GPUThreadMappingAttr::get(
-              context, static_cast<gpu::MappingId>(mappingId)));
+          if (tilingLevel == IREE::GPU::TilingLevel::Thread) {
+            mapping.push_back(gpu::GPUThreadMappingAttr::get(
+                context, static_cast<gpu::MappingId>(mappingId)));
+          } else {
+            // Else it must be subgroup tiling.
+            mapping.push_back(gpu::GPUWarpMappingAttr::get(
+                context, static_cast<gpu::MappingId>(mappingId)));
+          }
         }
       }
       tilingOptions.setMapping(mapping);
@@ -168,14 +175,13 @@ static llvm::SmallDenseSet<TilingInterface>
 getTiledOps(Operation *funcOp, IREE::GPU::TilingLevel tilingLevel) {
   llvm::SmallDenseSet<TilingInterface> targets;
-  unsigned opaqueLevel = static_cast<unsigned>(tilingLevel);
+  unsigned opaqueLevel = llvm::to_underlying(tilingLevel);
   funcOp->walk([&](TilingInterface target) {
     // TODO: This would probably be easier with a lowering config interface
     // method that checks whether a particular level is tiled.
     if (IREE::Codegen::LoweringConfigAttrInterface loweringConfig =
             getLoweringConfig(target)) {
-      if (!loweringConfig.getStaticTilingLevelSizes(opaqueLevel, target)
-               .empty()) {
+      if (loweringConfig.hasTilingLevel(opaqueLevel)) {
         targets.insert(target);
       }
     }
@@ -187,7 +193,8 @@ void GPUApplyTilingLevelPass::runOnOperation() {
   FunctionOpInterface funcOp = getOperation();
 
   if (tilingLevel != IREE::GPU::TilingLevel::Reduction &&
-      tilingLevel != IREE::GPU::TilingLevel::Thread) {
+      tilingLevel != IREE::GPU::TilingLevel::Thread &&
+      tilingLevel != IREE::GPU::TilingLevel::Subgroup) {
     funcOp.emitError() << "unsupported tiling level: "
                        << IREE::GPU::stringifyEnum(tilingLevel) << "\n";
     return signalPassFailure();
@@ -195,10 +202,9 @@ void GPUApplyTilingLevelPass::runOnOperation() {
   llvm::SmallDenseSet<TilingInterface> targetOps =
       getTiledOps(funcOp, tilingLevel);
 
-  bool useThread = tilingLevel == IREE::GPU::TilingLevel::Thread;
   IRRewriter rewriter(funcOp);
-  if (failed(applyTileAndFuseToEachRoot(rewriter, targetOps, useThread))) {
+  if (failed(applyTileAndFuseToEachRoot(rewriter, targetOps, tilingLevel))) {
     funcOp.emitError() << "tiling of level "
                        << IREE::GPU::stringifyEnum(tilingLevel) << " failed\n";
     return signalPassFailure();
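For orientation, a schematic of what the new `subgroup` tiling level produces for an op annotated with `subgroup = [2, 16]`: the op is tiled to an `scf.forall` mapped onto warps instead of threads. Shapes and induction-variable names here are illustrative, mirroring the SUBGROUP FileCheck lines in the test further below:

  scf.forall (%i, %j) = (0, 0) to (64, 256) step (2, 16) {
    // per-subgroup tile of the payload op
  } {mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>]}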
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
index 7d46365ace3e..a4817d995b66 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
@@ -6,6 +6,7 @@
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
@@ -23,6 +24,10 @@ struct GPUDistributePass final
     : impl::GPUDistributePassBase<GPUDistributePass> {
   void runOnOperation() override {
     auto funcOp = getOperation();
+    IRRewriter rewriter(funcOp->getContext());
+
+    // First map all lane level forall loops to lanes.
+    IREE::GPU::mapLaneForalls(rewriter, funcOp, /*insertBarrier=*/false);
 
     std::optional<SmallVector<int64_t>> workgroupSize =
         getWorkgroupSize(funcOp);
@@ -35,7 +40,6 @@
     // TODO: Don't hard code kCudaWarpSize here.
     int64_t subgroupSize = maybeSubgroupSize.value_or(kCudaWarpSize);
 
-    IRRewriter rewriter(funcOp->getContext());
     rewriter.setInsertionPointToStart(&funcOp.front());
     DiagnosedSilenceableFailure result =
         mlir::transform::gpu::mapNestedForallToThreadsImpl(
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 95ef39d9768c..5bb44cd631f1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -28,7 +28,11 @@ def GPUCreateFastSlowPathPass :
 def GPUDistributePass :
     InterfacePass<"iree-codegen-gpu-distribute", "mlir::FunctionOpInterface"> {
   let summary = "Pass to distribute scf.forall ops.";
-  let dependentDialects = ["::mlir::affine::AffineDialect", "::mlir::gpu::GPUDialect"];
+  let dependentDialects = [
+    "::mlir::affine::AffineDialect",
+    "::mlir::gpu::GPUDialect",
+    "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
+  ];
 }
 
 def GPUDistributeSharedMemoryCopyPass :
@@ -152,6 +156,8 @@ def GPUApplyTilingLevelPass :
       clEnumValN(IREE::GPU::TilingLevel::Reduction, "reduction",
                  "Tile and fuse all annotated ops to serial loops"),
       clEnumValN(IREE::GPU::TilingLevel::Thread, "thread",
+                 "Tile and fuse all annotated ops to threads"),
+      clEnumValN(IREE::GPU::TilingLevel::Subgroup, "subgroup",
                  "Tile and fuse all annotated ops to threads")
     )}]>,
   ];
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
index ab39fc2ffa89..cbdaebcd20d5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
@@ -1,7 +1,8 @@
 // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level, canonicalize, cse))" %s | FileCheck %s
 // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=thread}, canonicalize, cse))" %s | FileCheck %s --check-prefix=THREAD
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=subgroup}, canonicalize, cse))" %s | FileCheck %s --check-prefix=SUBGROUP
 
-#config = #iree_gpu.lowering_config<{thread = [2, 16]}>
+#config = #iree_gpu.lowering_config<{thread = [2, 16], subgroup = [2, 16]}>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
   func.func @add_tensor() {
@@ -35,6 +36,12 @@ module {
 // THREAD: scf.forall.in_parallel
 // THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]
 
+// SUBGROUP-LABEL: func.func @add_tensor
+// SUBGROUP: scf.forall ({{.*}}) = (0, 0) to (64, 256) step (2, 16)
+// SUBGROUP: linalg.generic {{.*}} ins(%{{.*}}: tensor<2x16xf32>, tensor<2x16xf32>)
+// SUBGROUP: scf.forall.in_parallel
+// SUBGROUP: mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>]
+
 // -----
 
 #config = #iree_gpu.lowering_config<{thread = [0, 16]}>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute.mlir
index 12ad89b638d9..84f3cbf1b10f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-distribute, cse))" %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-distribute, cse))" %s --split-input-file | FileCheck %s
 
 #map = affine_map<()[s0] -> (s0 * 256)>
 #map1 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
@@ -43,3 +43,49 @@ module {
 // CHECK: %[[B:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[OFF]]], %{{.*}} {in_bounds = [true]} : memref<1x256xf32, #{{.*}}>, vector<4xf32>
 // CHECK: %[[C:.*]] = arith.addf %[[A]], %[[B]] : vector<4xf32>
 // CHECK: vector.transfer_write %[[C]], %[[S]][%[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf32>, memref<1x4xf32, #{{.*}}>
+
+// -----
+
+#map = affine_map<()[s0] -> (s0 * 256)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
+#map2 = affine_map<(d0) -> (d0 * 4)>
+#translation = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
+module {
+  func.func @add_tensor_lane_id() attributes {translation_info = #translation} {
+    %cst = arith.constant 0.000000e+00 : f32
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<233x1024xf32>
+    memref.assume_alignment %0, 64 : memref<233x1024xf32>
+    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<233x1024xf32>
+    memref.assume_alignment %1, 64 : memref<233x1024xf32>
+    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<233x1024xf32>
+    memref.assume_alignment %2, 64 : memref<233x1024xf32>
+    %workgroup_id_x = hal.interface.workgroup.id[0] : index
+    %workgroup_id_y = hal.interface.workgroup.id[1] : index
+    %3 = affine.apply #map()[%workgroup_id_x]
+    %subview = memref.subview %2[%workgroup_id_y, %3] [1, 256] [1, 1] : memref<233x1024xf32> to memref<1x256xf32, #map1>
+    %subview_0 = memref.subview %0[%workgroup_id_y, %3] [1, 256] [1, 1] : memref<233x1024xf32> to memref<1x256xf32, #map1>
+    %subview_1 = memref.subview %1[%workgroup_id_y, %3] [1, 256] [1, 1] : memref<233x1024xf32> to memref<1x256xf32, #map1>
+    scf.forall (%arg0) in (%c64) {
+      %4 = affine.apply #map2(%arg0)
+      %subview_2 = memref.subview %subview[0, %4] [1, 4] [1, 1] : memref<1x256xf32, #map1> to memref<1x4xf32, #map1>
+      %5 = vector.transfer_read %subview_0[%c0, %4], %cst {in_bounds = [true]} : memref<1x256xf32, #map1>, vector<4xf32>
+      %6 = vector.transfer_read %subview_1[%c0, %4], %cst {in_bounds = [true]} : memref<1x256xf32, #map1>, vector<4xf32>
+      %7 = arith.addf %5, %6 : vector<4xf32>
+      vector.transfer_write %7, %subview_2[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<1x4xf32, #map1>
+    } {mapping = [#iree_gpu.lane_id<0>]}
+    return
+  }
+}
+
+// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @add_tensor_lane_id
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[TX:.*]] = gpu.lane_id
+// CHECK: %[[OFF:.*]] = affine.apply #[[$MAP]](%[[TX]])
+// CHECK: %[[S:.*]] = memref.subview %{{.*}}[0, %[[OFF]]] [1, 4] [1, 1] : memref<1x256xf32, #{{.*}}> to memref<1x4xf32, #{{.*}}>
+// CHECK: %[[A:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[OFF]]], %{{.*}} {in_bounds = [true]} : memref<1x256xf32, #{{.*}}>, vector<4xf32>
+// CHECK: %[[B:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[OFF]]], %{{.*}} {in_bounds = [true]} : memref<1x256xf32, #{{.*}}>, vector<4xf32>
+// CHECK: %[[C:.*]] = arith.addf %[[A]], %[[B]] : vector<4xf32>
+// CHECK: vector.transfer_write %[[C]], %[[S]][%[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf32>, memref<1x4xf32, #{{.*}}>
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index 9a457935ee8f..e4ec4eccdecd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -242,6 +242,10 @@ std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>> createPadDynamicAlloc();
 
 /// Pass to convert math operations to their polynomial approximation.
 std::unique_ptr<OperationPass<>> createPolynomialApproximationPass();
 
+/// Pass to propagate reshapes by expansion through all ops without explicit
+/// lowering configurations.
+std::unique_ptr<OperationPass<>> createPropagateReshapesByExpansionPass();
+
 /// Pass to reconcile TranslationInfo across multiple functions in a dispatch
 /// and set the appropriate values on the surrounding HAL ops.
 std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index d92fe10026f1..ed182941c372 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -434,6 +434,12 @@ def PolynomialApproximationPass :
       "mlir::iree_compiler::createPolynomialApproximationPass()";
 }
 
+def PropagateReshapesByExpansionPass :
+    Pass<"iree-codegen-propagate-reshapes-by-expansion", ""> {
+  let summary = "Propagates reshaping operations by expansion.";
+  let constructor = "mlir::iree_compiler::createPropagateReshapesByExpansionPass()";
+}
+
 def RematerializeParallelOps :
     InterfacePass<"iree-codegen-rematerialize-parallel-ops", "mlir::FunctionOpInterface"> {
   let summary = "Pass to rematerialize and merge parallel ops into consumers.";
diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
new file mode 100644
index 000000000000..860f30fa6e10
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
@@ -0,0 +1,78 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/PassDetail.h"
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+struct PropagateReshapesByExpansionPass
+    : public PropagateReshapesByExpansionPassBase<
+          PropagateReshapesByExpansionPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void PropagateReshapesByExpansionPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+
+  {
+    RewritePatternSet patterns(context);
+    // Preemptively attempt to fold any reshapes into interface bindings if
+    // possible to simplify subsequent reshape propagation.
+    populateReshapeToInterfaceTensorPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+
+  RewritePatternSet bubbleExpandShapePatterns(context);
+  linalg::ControlFusionFn bubbleUpExpansionControlFn =
+      [](OpOperand *fusedOperand) {
+        Operation *producer = fusedOperand->get().getDefiningOp();
+        Operation *consumer = fusedOperand->getOwner();
+
+        // Block only if one of the operations has a lowering configuration
+        // which means it likely expects tiling specific to its original shape.
+        if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
+          return false;
+        }
+        return true;
+      };
+  linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
+                                                    bubbleUpExpansionControlFn);
+  // Add patterns to do some additional cleanup (on top of canonicalizations
+  // that can be done later) of reshape ops.
+  tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
+  linalg::FillOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
+                                              context);
+  tensor::CollapseShapeOp::getCanonicalizationPatterns(
+      bubbleExpandShapePatterns, context);
+  tensor::EmptyOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
+                                               context);
+  tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
+                                                     context);
+  populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
+
+  if (failed(applyPatternsAndFoldGreedily(
+          getOperation(), std::move(bubbleExpandShapePatterns)))) {
+    getOperation()->emitOpError("Failed to propagate reshapes");
+    return signalPassFailure();
+  }
+}
+
+std::unique_ptr<OperationPass<>> createPropagateReshapesByExpansionPass() {
+  return std::make_unique<PropagateReshapesByExpansionPass>();
+}
+
+} // namespace mlir::iree_compiler
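To make the effect concrete, a hedged before/after sketch (hypothetical IR, not taken from the test suite): a `tensor.collapse_shape` feeding an op without a lowering config is bubbled upward, so the op computes in the expanded shape and the reshape sinks toward the interface binding where it can fold away. In practice the named op is generalized to a `linalg.generic` by the fusion patterns, as the new lit test checks:

  // Before: the copy consumes the collapsed shape.
  %c = tensor.collapse_shape %src [[0, 1]] : tensor<3x4xf16> into tensor<12xf16>
  %r = linalg.copy ins(%c : tensor<12xf16>) outs(%dest : tensor<12xf16>) -> tensor<12xf16>
  // After (schematic): the computation happens in the expanded shape and the
  // collapse_shape now follows it.
  %r2 = linalg.generic ... ins(%src : tensor<3x4xf16>) ... -> tensor<3x4xf16>
  %c2 = tensor.collapse_shape %r2 [[0, 1]] : tensor<3x4xf16> into tensor<12xf16>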
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
index 6b5830fdc5db..295ec03d279c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
@@ -66,6 +66,7 @@ iree_compiler_cc_library(
        "//compiler/src/iree/compiler/Codegen/Common:VectorLayoutAnalysis",
        "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+       "//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
        "//compiler/src/iree/compiler/Codegen/Interfaces:BufferizationInterfaces",
        "//compiler/src/iree/compiler/Codegen/Transforms",
        "//compiler/src/iree/compiler/Codegen/Utils",
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
index 0b80f3f76caf..69f2ab2df75d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
@@ -70,6 +70,7 @@ iree_cc_library(
     iree::compiler::Codegen::Common::GPU::CommonGPUPasses
     iree::compiler::Codegen::Common::VectorLayoutAnalysis
     iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+    iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
     iree::compiler::Codegen::Interfaces::BufferizationInterfaces
     iree::compiler::Codegen::Transforms
     iree::compiler::Codegen::Utils
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 450b6499d8ac..57f339109197 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -16,6 +16,7 @@
 #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
@@ -440,76 +441,6 @@ void transform_dialect::FlattenForallMappingOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
-//===---------------------------------------------------------------------===//
-// ForallToLanesOp
-//===---------------------------------------------------------------------===//
-
-static bool isLaneMappableForall(scf::ForallOp forallOp) {
-  if (forallOp.getNumResults() > 0)
-    return false;
-  if (forallOp.getRank() != 1)
-    return false;
-  if (!forallOp.getMapping().has_value())
-    return false;
-  Attribute mapping = *forallOp.getMapping()->getValue().begin();
-  if (mapping != IREE::GPU::LaneIdAttr::get(forallOp.getContext(), 0)) {
-    return false;
-  }
-  return true;
-}
-
-static void rewriteForallToLanes(RewriterBase &rewriter,
-                                 scf::ForallOp forallOp) {
-  Location loc = forallOp->getLoc();
-  assert(isLaneMappableForall(forallOp) &&
-         "mapping non-lane mappable forall op");
-
-  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
-
-  // Step 4. Predicate omitted given unique topLevel scf::ForallOp.
-
-  // Step 5. Move the body of forallOp.
-  // Erase the terminator first, it will not be used since we are on buffers.
-  rewriter.eraseOp(forallOp.getTerminator());
-  rewriter.setInsertionPoint(forallOp);
-  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp, {laneId});
-  rewriter.create<gpu::BarrierOp>(loc);
-
-  // Step 7. Erase old op.
-  rewriter.eraseOp(forallOp);
-}
-
-DiagnosedSilenceableFailure transform_dialect::ForallToLanesOp::applyToOne(
-    transform::TransformRewriter &rewriter, mlir::FunctionOpInterface target,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-
-  SmallVector<scf::ForallOp> foralls;
-  target->walk([&](scf::ForallOp forallOp) {
-    if (isLaneMappableForall(forallOp)) {
-      foralls.push_back(forallOp);
-    }
-  });
-
-  if (foralls.empty()) {
-    return mlir::emitSilenceableFailure(
-        target, "could not find a lane mappable scf.forall");
-  }
-
-  for (auto forall : foralls) {
-    rewriter.setInsertionPoint(forall);
-    rewriteForallToLanes(rewriter, forall);
-  }
-
-  return DiagnosedSilenceableFailure::success();
-}
-
-void transform_dialect::ForallToLanesOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getTarget(), effects);
-  transform::modifiesPayload(effects);
-}
-
 //===---------------------------------------------------------------------===//
 // ForallToWorkgroupOp
 //===---------------------------------------------------------------------===//
@@ -1280,5 +1211,59 @@ void transform_dialect::WorkgroupSwizzleOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// FuseConsumerOp
+//===----------------------------------------------------------------------===//
+
+/// Apply fusing of consumer transformation to all payload ops and store both
+/// the original consumer operation as well as the fused consumer operation.
+template <typename Range>
+static LogicalResult
+applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
+                  Range &&payloadOps,
+                  transform::TransformResults &transformResults) {
+  SmallVector<Operation *> originalConsumerOps;
+  SmallVector<Operation *> fusedConsumerOps;
+
+  for (Operation *target : payloadOps) {
+    rewriter.setInsertionPoint(target);
+
+    FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+        scf::tileAndFuseConsumerOfSlice(rewriter, target);
+
+    if (failed(fuseConsumerResults))
+      return failure();
+
+    // Report back the relevant handles to the transform op.
+    originalConsumerOps.push_back(
+        fuseConsumerResults->origConsumerOperand->getOwner());
+    fusedConsumerOps.push_back(
+        fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+  }
+
+  transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
+  transformResults.set(transformOp->getOpResult(1), fusedConsumerOps);
+  return success();
+}
+
+DiagnosedSilenceableFailure transform_dialect::FuseConsumerOp::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &transformResults,
+    transform::TransformState &state) {
+  LogicalResult result =
+      applyFuseConsumer(rewriter, getOperation(),
+                        state.getPayloadOps(getTarget()), transformResults);
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
+}
+
+void transform_dialect::FuseConsumerOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  producesHandle(getConsumer(), effects);
+  producesHandle(getFusedConsumer(), effects);
+  transform::modifiesPayload(effects);
+}
+
 #define GET_OP_CLASSES
 #include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index 2ad0423ed952..e3a596074fbb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -299,35 +299,6 @@ def FlattenForallMappingOp :
     Op<Transform_Dialect, "iree.flatten_forall_mapping", [
 
-def ForallToLanesOp : Op<Transform_Dialect, "iree.forall_to_lanes",
-    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-     TransformEachOpTrait,
-     TransformOpInterface,
-     ReportTrackingListenerFailuresOpTrait]> {
-  let description = [{
-    Collect all of the scf.forall ops in the target that are distributed to
-    lanes.
-
-    Only scf.forall distributed to exactly a single lane id are currently
-    supported.
-  }];
-
-  let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs);
-
-  let assemblyFormat = "$target attr-dict `:` functional-type($target, results)";
-  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::FunctionOpInterface target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
-}
-
 def ForallToWorkgroupOp :
     Op<Transform_Dialect, "iree.forall_to_workgroup", [
@@ -755,4 +726,24 @@ def WorkgroupSwizzleOp :
     Op<Transform_Dialect, "iree.workgroup_swizzle", [
 
+def FuseConsumerOp : Op<Transform_Dialect, "iree.fuse_consumer",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Fuses the consumer of the operation pointed to by the target handle
+    using the options provided as attributes.
+  }];
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+
+  let arguments =
+      (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$consumer,
+                      TransformHandleTypeInterface:$fused_consumer);
+
+  let assemblyFormat = [{
+    $target attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 #endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_COMMONEXTENSIONS
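A hedged usage sketch in a transform script (handle names are hypothetical; the `transform.iree.fuse_consumer` spelling follows the `iree.` mnemonic convention of the other ops in this file). The target handle points at a candidate slice op produced by a tiled loop, and the op yields handles to both the original and the tiled-and-fused consumer:

  %consumer, %fused_consumer = transform.iree.fuse_consumer %slice
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)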
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 8dc1bdeb641c..f0e3a8e9f8ad 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -55,6 +55,7 @@ iree_lit_test_suite(
         "optimize_tensor_insert_extract_slices.mlir",
         "pad_dynamic_alloc.mlir",
         "polynomial_approximation.mlir",
+        "propagate_reshapes_by_expansion.mlir",
         "reconcile_translation_info.mlir",
         "reductions.mlir",
         "rematerialize_parallel_ops.mlir",
@@ -66,7 +67,6 @@ iree_lit_test_suite(
         "tile_and_distribute_to_workgroups.mlir",
         "transform_buffer_opt.mlir",
         "transform_copy_operand.mlir",
-        "transform_distribute_lane_forall.mlir",
         "transform_flatten_forall.mlir",
         "transform_hoist_forall.mlir",
         "transform_match_partial_reduction.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 859948043b99..6f1dd785049a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -51,6 +51,7 @@ iree_lit_test_suite(
     "optimize_tensor_insert_extract_slices.mlir"
     "pad_dynamic_alloc.mlir"
     "polynomial_approximation.mlir"
+    "propagate_reshapes_by_expansion.mlir"
     "reconcile_translation_info.mlir"
     "reductions.mlir"
     "rematerialize_parallel_ops.mlir"
@@ -62,7 +63,6 @@ iree_lit_test_suite(
     "tile_and_distribute_to_workgroups.mlir"
    "transform_buffer_opt.mlir"
    "transform_copy_operand.mlir"
-    "transform_distribute_lane_forall.mlir"
    "transform_flatten_forall.mlir"
    "transform_hoist_forall.mlir"
    "transform_match_partial_reduction.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir b/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir
index bd66929b01d5..1af2c7d92fe8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir
@@ -79,3 +79,27 @@ func.func @dont_fold_dynamic_reshape() {
   flow.dispatch.tensor.store %5, %2, offsets = [%c0, %c0, %c0], sizes = [%c1, 12, 8], strides = [%c1, %c1, %c1] : tensor<?x12x8xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x12x8xf32>>{%dim2}
   return
 }
+
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 288)>
+// CHECK-LABEL: func.func @fold_reshape_slice_store
+func.func @fold_reshape_slice_store(%x: index) {
+  // CHECK-SAME: %[[X:[A-Za-z0-9]+]]: index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<3x3x1x96xf32>>
+  %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<writeonly:tensor<3456xf32>>
+  // CHECK: %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<writeonly:tensor<12x3x1x96xf32>>
+  // CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %{{.*}}, {{.*}}
+  %3 = flow.dispatch.tensor.load %1, offsets=[0, 0, 0, 0], sizes =[3, 3, 1, 96], strides=[1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<3x3x1x96xf32>> -> tensor<3x3x1x96xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[LOAD]] : tensor<3x3x1x96xf32>)
+  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<3x3x1x96xf32>) -> tensor<3x3x1x96xf32>
+  %5 = tensor.collapse_shape %4 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
+  // CHECK: %[[XDIV:.+]] = affine.apply #[[$MAP]]()[%[[X]]]
+  // CHECK: flow.dispatch.tensor.store %[[FILL]], %[[OUT]], offsets = [%[[XDIV]], 0, 0, 0], sizes = [3, 3, 1, 96]
+  // CHECK-SAME: tensor<3x3x1x96xf32> -> !flow.dispatch.tensor<writeonly:tensor<12x3x1x96xf32>>
+  flow.dispatch.tensor.store %5, %2, offsets = [%x], sizes = [864], strides = [1] : tensor<864xf32> -> !flow.dispatch.tensor<writeonly:tensor<3456xf32>>
+  return
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
new file mode 100644
index 000000000000..7dd745e5a7c3
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
@@ -0,0 +1,16 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" --split-input-file %s | FileCheck %s
+
+func.func @reshape_and_lowering_config(%src: tensor<3x4xf16>, %dest: tensor<12xf16>, %dest2: tensor<12xf16>) -> tensor<12xf16> {
+  %collapse = tensor.collapse_shape %src [[0, 1]] : tensor<3x4xf16> into tensor<12xf16>
+  %copy = linalg.copy ins(%collapse : tensor<12xf16>) outs(%dest: tensor<12xf16>) -> tensor<12xf16>
+  %copy2 = linalg.copy {lowering_config = #iree_gpu.derived_thread_config} ins(%copy : tensor<12xf16>) outs(%dest2: tensor<12xf16>) -> tensor<12xf16>
+  return %copy2: tensor<12xf16>
+}
+
+// CHECK-LABEL: func @reshape_and_lowering_config
+// CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: tensor<3x4xf16>
+// CHECK: %[[COPY1:.+]] = linalg.generic {{.*}} ins(%[[SRC]]
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[COPY1]]
+// CHECK: linalg.copy
+// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config
+// CHECK-SAME: ins(%[[COLLAPSE]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
index fe2e29b084f6..118a781570c6 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
@@ -318,6 +318,10 @@ LoweringConfigAttr::getTilingLevelSizes(OpBuilder &builder, unsigned level,
       [&](int64_t t) -> OpFoldResult { return builder.getIndexAttr(t); });
 }
 
+bool LoweringConfigAttr::hasTilingLevel(unsigned level) const {
+  return !getTileSizeVals(level).empty();
+}
+
 LogicalResult
 LoweringConfigAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                            LoweringConfigTilingLevelsAttr levels,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index 1202780c8f69..dad9349ee251 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -211,6 +211,7 @@ def IREECodegen_LoweringConfigAttr :
       "getWorkgroupInterchange",
       "getStaticTilingLevelSizes",
       "getTilingLevelSizes",
+      "hasTilingLevel",
     ]> ]> {
   let mnemonic = "lowering_config";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
index 498e1cdbf279..ee67fab6b3ac 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
@@ -47,6 +47,19 @@ def IREECodegen_LoweringConfigAttrInterface :
         return ::llvm::SmallVector<int64_t>();
       }]
    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if the lowering config specifies tile sizes for the given
+        tiling level.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasTilingLevel",
+      /*args=*/(ins "unsigned":$level),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return false;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Returns the tile sizes for the specified tiling level. The
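For intuition about the new `hasTilingLevel` query: given a hypothetical GPU dialect config such as

  #config = #iree_gpu.lowering_config<{reduction = [0, 0, 8], thread = [1, 4]}>

the implementations below report true for the reduction and thread levels and false for any level the dictionary does not name (workgroup, subgroup, lane).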
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index d88dc84389cd..65e7fbacc53d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -10,6 +10,7 @@
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
 #include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
 #include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
 #include "llvm/ADT/STLExtras.h"
@@ -1225,6 +1226,20 @@ LoweringConfigAttr::getTilingLevelSizes(OpBuilder &b, unsigned level,
       sizes, [&](int64_t s) -> OpFoldResult { return b.getIndexAttr(s); });
 }
 
+bool LoweringConfigAttr::hasTilingLevel(unsigned level) const {
+  if (level > llvm::to_underlying(GPU::TilingLevel::Lane)) {
+    return false;
+  }
+  return !getTileSizes(getAttributes(), static_cast<GPU::TilingLevel>(level))
+              .empty();
+}
+
+constexpr StringLiteral kMmaKindName = "mma_kind";
+
+IREE::GPU::MmaInterfaceAttr LoweringConfigAttr::getMmaKind() const {
+  return getAttributes().getAs<IREE::GPU::MmaInterfaceAttr>(kMmaKindName);
+}
+
 //===----------------------------------------------------------------------===//
 // DerivedThreadConfigAttr
 //===----------------------------------------------------------------------===//
@@ -1249,6 +1264,10 @@ DerivedThreadConfigAttr::getTilingLevelSizes(OpBuilder &b, unsigned level,
       sizes, [&](int64_t s) -> OpFoldResult { return b.getIndexAttr(s); });
 }
 
+bool DerivedThreadConfigAttr::hasTilingLevel(unsigned level) const {
+  return level == llvm::to_underlying(GPU::TilingLevel::Thread);
+}
+
 //===----------------------------------------------------------------------===//
 // LaneIdAttr
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index 1ab463c5792a..5a974d5100c9 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -38,6 +38,7 @@ def IREEGPU_LoweringConfigAttr :
       "getWorkgroupTileSizes",
      "getStaticTilingLevelSizes",
      "getTilingLevelSizes",
+      "hasTilingLevel",
    ]> ]> {
   let mnemonic = "lowering_config";
@@ -55,7 +56,8 @@ def IREEGPU_LoweringConfigAttr :
     "The configured fields, including tiling levels">:$attributes
   );
   let extraClassDeclaration = [{
-    SmallVector<int64_t> getThreadTileSizes() const;
+    /// Helper to retrieve a target mma intrinsic if present.
+    ::mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr getMmaKind() const;
   }];
 }
 
@@ -64,6 +66,7 @@ def IREEGPU_DerivedThreadConfig :
    DeclareAttrInterfaceMethods<IREEGPU_LoweringConfigAttrInterface, [
      "getStaticTilingLevelSizes",
      "getTilingLevelSizes",
+      "hasTilingLevel",
    ]> ]> {
   let mnemonic = "derived_thread_config";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
index 403702599046..2058d83d2a19 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
@@ -188,6 +188,24 @@ void transform_dialect::DistributeMultiMmaOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
+//===---------------------------------------------------------------------===//
+// ForallToLanesOp
+//===---------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform_dialect::ForallToLanesOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  IREE::GPU::mapLaneForalls(rewriter, target, /*insertBarrier=*/true);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform_dialect::ForallToLanesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===---------------------------------------------------------------------===//
 // FuseForallOp
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
index 6fdbd20f6ca5..0c62c4c6a5a7 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
@@ -157,6 +157,35 @@ def DistributeMultiMmaOp :
 
+def ForallToLanesOp : Op<Transform_Dialect, "iree.forall_to_lanes",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformEachOpTrait,
+     TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Collect all of the scf.forall ops in the target that are distributed to
+    lanes.
+
+    Only scf.forall distributed to exactly a single lane id are currently
+    supported.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type($target, results)";
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def FuseForallOp : Op<Transform_Dialect, "iree.fuse_forall",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
index 36b25098d4a6..9e03938e0764 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
@@ -19,6 +19,7 @@ iree_lit_test_suite(
     srcs = enforce_glob(
         [
             "convert_to_multi_mma.mlir",
+            "distribute_lane_forall.mlir",
            "distribute_multi_mma.mlir",
            "drop_multi_mma_unit_dims.mlir",
            "lower_multi_mma.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
index 4bd3e5d9fe64..7979adb12415 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
@@ -15,6 +15,7 @@ iree_lit_test_suite(
     lit
   SRCS
     "convert_to_multi_mma.mlir"
+    "distribute_lane_forall.mlir"
    "distribute_multi_mma.mlir"
    "drop_multi_mma_unit_dims.mlir"
    "lower_multi_mma.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/transform_distribute_lane_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_lane_forall.mlir
similarity index 100%
rename from compiler/src/iree/compiler/Codegen/Common/test/transform_distribute_lane_forall.mlir
rename to compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_lane_forall.mlir
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
index 3f40256d76fa..f20779e9e000 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
@@ -51,8 +51,10 @@ iree_gentbl_cc_library(
 iree_compiler_cc_library(
     name = "GPUTransforms",
     srcs = [
+        "DistributeMmaToLanes.cpp",
         "FuseAndHoistParallelLoops.cpp",
         "LowerIREEGPUOps.cpp",
+        "PackToIntrinsics.cpp",
         "Passes.cpp",
         "Transforms.cpp",
         "VectorizeIREEGPUOps.cpp",
@@ -64,20 +66,25 @@ iree_compiler_cc_library(
     ],
     deps = [
         ":PassesIncGen",
+        "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
         "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
         "//compiler/src/iree/compiler/Codegen/Transforms",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AffineDialect",
         "@llvm-project//mlir:AffineUtils",
         "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:DestinationStyleOpInterface",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FunctionInterfaces",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:LinalgTransforms",
+        "@llvm-project//mlir:LoopLikeInterface",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:SCFTransforms",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TransformUtils",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
index 0cc85f41a97d..fd0b6e65d8bb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
@@ -45,8 +45,10 @@ iree_cc_library(
     "Passes.h.inc"
     "Transforms.h"
   SRCS
+    "DistributeMmaToLanes.cpp"
     "FuseAndHoistParallelLoops.cpp"
     "LowerIREEGPUOps.cpp"
+    "PackToIntrinsics.cpp"
     "Passes.cpp"
     "Transforms.cpp"
     "VectorizeIREEGPUOps.cpp"
@@ -56,13 +58,17 @@ iree_cc_library(
     MLIRAffineDialect
     MLIRAffineUtils
     MLIRArithDialect
+    MLIRDestinationStyleOpInterface
     MLIRFuncDialect
     MLIRFunctionInterfaces
     MLIRGPUDialect
     MLIRIR
     MLIRLinalgDialect
+    MLIRLinalgTransforms
+    MLIRLoopLikeInterface
     MLIRPass
     MLIRSCFDialect
+    MLIRSCFTransforms
     MLIRSupport
     MLIRTensorDialect
     MLIRTransformUtils
@@ -70,6 +76,7 @@ iree_cc_library(
     MLIRVectorDialect
     MLIRVectorTransforms
     MLIRVectorUtils
+    iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
     iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
     iree::compiler::Codegen::Transforms
   PUBLIC
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeMmaToLanes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeMmaToLanes.cpp
new file mode 100644
index 000000000000..ef6d5cc7459a
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeMmaToLanes.cpp
@@ -0,0 +1,150 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+#define GEN_PASS_DEF_DISTRIBUTEMMATOLANESPASS
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
+
+namespace {
+struct DistributeMmaToLanesPass final
+    : impl::DistributeMmaToLanesPassBase<DistributeMmaToLanesPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+struct ConvertToMultiMma final : OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    auto loweringConfig =
+        getLoweringConfig<IREE::GPU::LoweringConfigAttr>(linalgOp);
+    if (!loweringConfig) {
+      return failure();
+    }
+    IREE::GPU::MmaInterfaceAttr kind = loweringConfig.getMmaKind();
+    if (!kind) {
+      return failure();
+    }
+    if (failed(convertContractionToMultiMma(rewriter, linalgOp, kind))) {
+      return failure();
+    }
+    return success();
+  }
+};
+
+LogicalResult fuseProducersGreedily(RewriterBase &rewriter,
+                                    scf::ForallOp laneForall) {
+
+  std::deque<tensor::ExtractSliceOp> candidates;
+  laneForall->walk([&](tensor::ExtractSliceOp extractSliceOp) {
+    auto producer = extractSliceOp.getSource().getDefiningOp();
+    if (producer && producer->getBlock() != laneForall.getBody()) {
+      candidates.push_back(extractSliceOp);
+    }
+  });
+
+  SmallVector<LoopLikeOpInterface> loops = {laneForall};
+
+  OpBuilder::InsertionGuard g(rewriter);
+  while (!candidates.empty()) {
+    // Traverse the slices in BFS fashion.
+    tensor::ExtractSliceOp candidateSliceOp = candidates.front();
+    candidates.pop_front();
+
+    // Materialize the slice of the producer in place.
+    std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
+        scf::tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
+    if (!fusedProducer)
+      continue;
+
+    // We have no way to know whether a multi-use value can be yielded from the
+    // parallel loop so never yield a replacement.
+
+    // Add more fusion candidates to the worklist.
+    for (auto tiledOp : fusedProducer->tiledOps) {
+      for (OpOperand &operand : tiledOp->getOpOperands()) {
+        auto sliceOp = operand.get().getDefiningOp<tensor::ExtractSliceOp>();
+        if (!sliceOp)
+          continue;
+        candidates.push_back(sliceOp);
+      }
+    }
+  }
+  return success();
+}
+
+void DistributeMmaToLanesPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  auto funcOp = getOperation();
+
+  // Step 1. Convert configured linalg ops to multi_mma.
+  {
+    RewritePatternSet patterns(context);
+    patterns.add<ConvertToMultiMma>(context);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      funcOp.emitError() << "failed to convert linalg to multi_mma";
+      return signalPassFailure();
+    }
+  }
+
+  // Step 2. Distribute multi_mma ops to lanes and greedily fuse producers.
+  SmallVector<IREE::GPU::MultiMmaOp> mmaOps;
+  funcOp.walk([&](IREE::GPU::MultiMmaOp mmaOp) { mmaOps.push_back(mmaOp); });
+  IRRewriter rewriter(funcOp);
+  for (auto mmaOp : mmaOps) {
+    rewriter.setInsertionPoint(mmaOp);
+    FailureOr<scf::ForallOp> maybeLaneForall =
+        distributeMultiMmaOp(rewriter, mmaOp);
+    if (failed(maybeLaneForall)) {
+      funcOp.emitError() << "failed to distribute multi_mma ops to lanes";
+      return signalPassFailure();
+    }
+
+    rewriter.setInsertionPointToStart(maybeLaneForall->getBody());
+    if (failed(fuseProducersGreedily(rewriter, *maybeLaneForall))) {
+      funcOp.emitError() << "failed to fuse producers into lane forall";
+      return signalPassFailure();
+    }
+  }
+
+  // Post distribution cleanup patterns.
+  {
+    RewritePatternSet patterns(context);
+    // Merge consecutive insert/extract slice ops to simplify later loop
+    // hoisting patterns.
+    tensor::populateFoldTensorEmptyPatterns(patterns);
+    tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+    tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, context);
+    tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
+    scf::ForOp::getCanonicalizationPatterns(patterns, context);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      funcOp.emitError() << "cleanup failed\n";
+      return signalPassFailure();
+    }
+  }
+}
+
+} // namespace mlir::iree_compiler::IREE::GPU
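A rough sketch of the result (hypothetical shapes; operand slicing and the terminator are elided as comments): after distribution, each multi_mma executes inside a single-dimension forall mapped to lanes, with producers greedily fused into the loop body:

  scf.forall (%lane_id) in (64) shared_outs(%out = %acc) -> (tensor<16x16xf32>) {
    // per-lane tensor.extract_slice of lhs/rhs/out, iree_gpu.multi_mma on the
    // slices, then tensor.parallel_insert_slice back into %out
  } {mapping = [#iree_gpu.lane_id<0>]}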
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
index 27686572b845..895e9ee0e461 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
@@ -14,7 +14,9 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir::iree_compiler::IREE::GPU {
@@ -73,8 +75,7 @@ struct FuseForalls final : OpRewritePattern<tensor::ExtractSliceOp> {
   }
 };
 
-struct FuseTileableDestinationProducers final
-    : OpRewritePattern<scf::ForallOp> {
+struct FuseTilableDestinationProducers final : OpRewritePattern<scf::ForallOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(scf::ForallOp forallOp,
                                 PatternRewriter &rewriter) const override {
@@ -113,19 +114,83 @@
   }
 };
 
+struct FuseTilableForallConsumers final
+    : OpInterfaceRewritePattern<TilingInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+  LogicalResult matchAndRewrite(TilingInterface tilableOp,
+                                PatternRewriter &rewriter) const override {
+    // Currently consumer fusion requires DPS, and we don't want to fuse
+    // through inits anyway.
+    auto dpsOp = dyn_cast<DestinationStyleOpInterface>(*tilableOp);
+    if (!dpsOp) {
+      return failure();
+    }
+
+    tensor::ParallelInsertSliceOp producerSlice;
+    for (auto operand : dpsOp.getDpsInputs()) {
+      auto forallProducer = operand.getDefiningOp<scf::ForallOp>();
+      if (!forallProducer) {
+        continue;
+      }
+      Value iterArg = forallProducer.getTiedBlockArgument(
+          forallProducer.getTiedOpOperand(cast<OpResult>(operand)));
+
+      for (auto user : iterArg.getUsers()) {
+        auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(user);
+        if (sliceOp && sliceOp.getDest() == iterArg) {
+          producerSlice = sliceOp;
+          break;
+        }
+      }
+      if (producerSlice) {
+        break;
+      }
+    }
+
+    if (!producerSlice) {
+      return failure();
+    }
+
+    FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+        scf::tileAndFuseConsumerOfSlice(rewriter, producerSlice);
+    if (failed(fuseConsumerResults)) {
+      return failure();
+    }
+    return success();
+  }
+};
+
 void FuseAndHoistParallelLoopsPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  RewritePatternSet patterns(context);
-
-  // These two patterns are run to a fixed point, allowing fusion within
-  // potentially nested loops, hoisting from said loops, and continued fusion.
-  patterns.add<FuseForalls>(context);
-  patterns.add<FuseTileableDestinationProducers>(context);
-  tensor::populateFoldTensorEmptyPatterns(patterns);
-  populateForallLoopHoistingPattern(patterns);
-  if (failed(
-          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
-    return signalPassFailure();
+
+  // First run the hoisting and fusion patterns.
+  {
+    RewritePatternSet patterns(context);
+    // These two patterns are run to a fixed point, allowing fusion within
+    // potentially nested loops, hoisting from said loops, and continued fusion.
+    patterns.add<FuseForalls>(context);
+    patterns.add<FuseTilableDestinationProducers>(context);
+    populateForallLoopHoistingPattern(patterns);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+
+  // After hoisting parallel loops, try to fuse in any newly revealed consumers
+  // and destinations.
+  // TODO: Move the consumer fusion pattern to an explicit worklist rather than
+  // using the GreedyPatternRewriter.
+  {
+    RewritePatternSet patterns(context);
+    patterns.add<FuseTilableDestinationProducers>(context);
+    patterns.add<FuseTilableForallConsumers>(context);
+    tensor::populateFoldTensorEmptyPatterns(patterns);
+    scf::ForallOp::getCanonicalizationPatterns(patterns, context);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
   }
 }
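Schematically, FuseTilableForallConsumers targets IR shaped like the hypothetical snippet below: a forall whose result feeds a trailing tilable consumer. The consumer is tiled and sunk next to the parallel_insert_slice that produces its input, leaving a single loop nest (names and shapes illustrative):

  %0 = scf.forall (%id) in (64) shared_outs(%out = %dest) -> (tensor<64x64xf32>) {
    // producer tile written via tensor.parallel_insert_slice into %out
  } {mapping = [#iree_gpu.lane_id<0>]}
  // Before fusion this consumer runs as a separate loop nest after the forall.
  %1 = linalg.copy ins(%0 : tensor<64x64xf32>) outs(%init : tensor<64x64xf32>) -> tensor<64x64xf32>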
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/PackToIntrinsics.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/PackToIntrinsics.cpp
new file mode 100644
index 000000000000..91977fbe1d28
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/PackToIntrinsics.cpp
@@ -0,0 +1,111 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+#define GEN_PASS_DEF_PACKTOINTRINSICSPASS
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
+
+namespace {
+struct PackToIntrinsicsPass final
+    : impl::PackToIntrinsicsPassBase<PackToIntrinsicsPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+LogicalResult packToIntrinsic(linalg::LinalgOp linalgOp,
+                              RewriterBase &rewriter) {
+  auto loweringConfig =
+      getLoweringConfig<IREE::GPU::LoweringConfigAttr>(linalgOp);
+  assert(loweringConfig && "Packing unconfigured op");
+
+  IREE::GPU::MmaInterfaceAttr kind = loweringConfig.getMmaKind();
+  assert(kind && "Packing op without mma kind");
+
+  FailureOr<linalg::ContractionDimensions> contractionDims =
+      linalg::inferContractionDims(linalgOp);
+  if (failed(contractionDims)) {
+    return rewriter.notifyMatchFailure(linalgOp,
+                                       "failed to infer contraction dims");
+  }
+
+  if (contractionDims->m.empty() || contractionDims->n.empty() ||
+      contractionDims->k.empty()) {
+    return rewriter.notifyMatchFailure(
+        linalgOp, "contraction like operation missing critical dimension");
+  }
+
+  auto zero = rewriter.getIndexAttr(0);
+  SmallVector<OpFoldResult> packedSizes(linalgOp.getNumLoops(), zero);
+
+  auto [m, n, k] = kind.getMNKShape();
+  packedSizes[contractionDims->m.back()] = rewriter.getIndexAttr(m);
+  packedSizes[contractionDims->n.back()] = rewriter.getIndexAttr(n);
+  packedSizes[contractionDims->k.back()] = rewriter.getIndexAttr(k);
+  FailureOr<linalg::PackResult> maybeResult =
+      linalg::pack(rewriter, linalgOp, packedSizes);
+  if (failed(maybeResult)) {
+    return rewriter.notifyMatchFailure(linalgOp, "packing failed");
+  }
+  setLoweringConfig(maybeResult->packedLinalgOp, loweringConfig);
+  return success();
+}
+
+void PackToIntrinsicsPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  auto funcOp = getOperation();
+  IRRewriter rewriter(funcOp);
+  SmallVector<linalg::LinalgOp> packingCandidates;
+  funcOp->walk([&](linalg::LinalgOp linalgOp) {
+    auto loweringConfig =
+        getLoweringConfig<IREE::GPU::LoweringConfigAttr>(linalgOp);
+    if (!loweringConfig) {
+      return;
+    }
+    if (!loweringConfig.getMmaKind()) {
+      return;
+    }
+    packingCandidates.push_back(linalgOp);
+  });
+
+  for (auto candidate : packingCandidates) {
+    rewriter.setInsertionPoint(candidate);
+    if (failed(packToIntrinsic(candidate, rewriter))) {
+      funcOp.emitError() << "failed to pack operation marked with intrinsic\n";
+      return signalPassFailure();
+    }
+  }
+
+  // Run layout propagation patterns to pull in adjacent un-configured ops.
+  RewritePatternSet patterns(context);
+  linalg::ControlPropagationFn control = [](Operation *op) -> bool {
+    return !getLoweringConfig(op);
+  };
+
+  linalg::populateDataLayoutPropagationPatterns(patterns, control);
+  if (failed(
+          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+    return signalPassFailure();
+  }
+}
+
+} // namespace mlir::iree_compiler::IREE::GPU
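To picture the packing step: for a 16x16x16 intrinsic, each contraction operand is rewritten to carry 16x16 inner tiles along its innermost m/n/k dimensions, e.g. (illustrative shapes only):

  %packed = tensor.pack %lhs inner_dims_pos = [0, 1] inner_tiles = [16, 16]
      into %lhs_dest : tensor<64x128xf16> -> tensor<4x8x16x16xf16>

The packed linalg op keeps its lowering config, and the data-layout propagation patterns then pull neighboring un-configured ops into the same layout.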
+  FailureOr<linalg::PackResult> maybeResult =
+      linalg::pack(rewriter, linalgOp, packedSizes);
+  if (failed(maybeResult)) {
+    return rewriter.notifyMatchFailure(linalgOp, "packing failed");
+  }
+  setLoweringConfig(maybeResult->packedLinalgOp, loweringConfig);
+  return success();
+}
+
+void PackToIntrinsicsPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  auto funcOp = getOperation();
+  IRRewriter rewriter(funcOp);
+  SmallVector<linalg::LinalgOp> packingCandidates;
+  funcOp->walk([&](linalg::LinalgOp linalgOp) {
+    auto loweringConfig =
+        getLoweringConfig<IREE::GPU::LoweringConfigAttr>(linalgOp);
+    if (!loweringConfig) {
+      return;
+    }
+    if (!loweringConfig.getMmaKind()) {
+      return;
+    }
+    packingCandidates.push_back(linalgOp);
+  });
+
+  for (auto candidate : packingCandidates) {
+    rewriter.setInsertionPoint(candidate);
+    if (failed(packToIntrinsic(candidate, rewriter))) {
+      funcOp.emitError() << "failed to pack operation marked with intrinsic";
+      return signalPassFailure();
+    }
+  }
+
+  // Run layout propagation patterns to pull in adjacent un-configured ops.
+  RewritePatternSet patterns(context);
+  linalg::ControlPropagationFn control = [](Operation *op) -> bool {
+    return !getLoweringConfig(op);
+  };
+
+  linalg::populateDataLayoutPropagationPatterns(patterns, control);
+  if (failed(
+          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+    return signalPassFailure();
+  }
+}
+
+} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
index 7537adf928a2..d6f6c24ccde8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
@@ -9,6 +9,17 @@
 
 include "mlir/Pass/PassBase.td"
 
+def DistributeMmaToLanesPass :
+    InterfacePass<"iree-gpu-distribute-mma-to-lanes", "mlir::FunctionOpInterface"> {
+  let summary = "Converts and distributes linalg ops with mma kinds to lanes";
+  let dependentDialects = [
+    "::mlir::arith::ArithDialect",
+    "::mlir::affine::AffineDialect",
+    "::mlir::scf::SCFDialect",
+    "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
+  ];
+}
+
 def FuseAndHoistParallelLoopsPass :
     InterfacePass<"iree-gpu-fuse-and-hoist-parallel-loops", "mlir::FunctionOpInterface"> {
   let summary = "Greedily fuses and hoists parallel loops.";
@@ -18,6 +29,16 @@ def FuseAndHoistParallelLoopsPass :
   ];
 }
 
+
+def PackToIntrinsicsPass :
+    InterfacePass<"iree-gpu-pack-to-intrinsics", "mlir::FunctionOpInterface"> {
+  let summary = "Packs matmul-like operations to specified intrinsic shapes";
+  let dependentDialects = [
+    "::mlir::tensor::TensorDialect",
+    "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect"
+  ];
+}
+
 def VectorizeIREEGPUOpsPass :
     InterfacePass<"iree-gpu-vectorize-ops", "mlir::FunctionOpInterface"> {
   let summary = "Vectorizes then lowers a few iree_gpu ops before vectorization.";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 3afa4ea43b48..ceb11bbb964b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -776,6 +776,54 @@ void populateIREEGPUVectorUnrollPatterns(
   patterns.add<UnrollMultiMmaPattern>(patterns.getContext(), options);
 }
 
+//===---------------------------------------------------------------------===//
+// Resolving lane mapped forall ops
+//===---------------------------------------------------------------------===//
+
+static bool isLaneMappableForall(scf::ForallOp forallOp) {
+  if (forallOp.getNumResults() > 0)
+    return false;
+  if (forallOp.getRank() != 1)
+    return false;
+  if (!forallOp.getMapping().has_value())
+    return false;
+  Attribute mapping = *forallOp.getMapping()->getValue().begin();
+  if (mapping != IREE::GPU::LaneIdAttr::get(forallOp.getContext(), 0)) {
+    return false;
+  }
+  return true;
+}
+
+static void rewriteForallToLanes(RewriterBase &rewriter, scf::ForallOp forallOp,
+                                 bool insertBarrier) {
+  Location loc = forallOp->getLoc();
+  assert(isLaneMappableForall(forallOp) && "mapping non-lane forall op");
+
+  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
+  rewriter.eraseOp(forallOp.getTerminator());
+  rewriter.setInsertionPoint(forallOp);
+  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp, {laneId});
+  if (insertBarrier) {
+    rewriter.create<gpu::BarrierOp>(loc);
+  }
+  rewriter.eraseOp(forallOp);
+}
+
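+// Note: candidate foralls are collected in a pre-pass walk because
+// rewriteForallToLanes erases ops; rewriting while walking would invalidate
+// the walk.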
+void mapLaneForalls(RewriterBase &rewriter, Operation *funcOp,
+                    bool insertBarrier) {
+  SmallVector<scf::ForallOp> foralls;
+  OpBuilder::InsertionGuard g(rewriter);
+  funcOp->walk([&](scf::ForallOp forallOp) {
+    if (isLaneMappableForall(forallOp)) {
+      foralls.push_back(forallOp);
+    }
+  });
+  for (auto forall : foralls) {
+    rewriter.setInsertionPoint(forall);
+    rewriteForallToLanes(rewriter, forall, insertBarrier);
+  }
+}
+
 //===---------------------------------------------------------------------===//
 // ShuffleTensor Lowering
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
index 0bcc88745dd7..3f502156e3b8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
@@ -52,6 +52,10 @@ convertContractionToMultiMma(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
 FailureOr<IREE::GPU::MultiMmaOp> distributeMultiMmaOp(RewriterBase &rewriter,
                                                       IREE::GPU::MultiMmaOp mmaOp);
 
+// Helper to map all lane-mappable scf.forall ops onto lanes.
+void mapLaneForalls(RewriterBase &rewriter, Operation *funcOp,
+                    bool insertBarrier);
+
 // Various populate pattern methods.
 void populateIREEGPUDropUnitDimsPatterns(RewritePatternSet &patterns);
 void populateIREEGPULowerMultiMmaPatterns(RewritePatternSet &patterns);
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
index cdfb9691078f..5e9c0e76f2d6 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
@@ -18,7 +18,9 @@ iree_lit_test_suite(
     name = "lit",
     srcs = enforce_glob(
         [
+            "distribute_mma_to_lanes.mlir",
             "fuse_and_hoist_forall.mlir",
+            "pack_to_intrinsics.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
index bf0669c711db..ef55e3d7faea 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
@@ -14,7 +14,9 @@ iree_lit_test_suite(
   NAME
     lit
   SRCS
+    "distribute_mma_to_lanes.mlir"
    "fuse_and_hoist_forall.mlir"
+    "pack_to_intrinsics.mlir"
   TOOLS
     FileCheck
     iree-opt
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
new file mode 100644
index 000000000000..dc94e64fc26b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
@@ -0,0 +1,38 @@
+// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-distribute-mma-to-lanes, canonicalize, cse))' --split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+module {
+  func.func @matmul_16x16x16(%arg0: tensor<8x2x16x16xf16>, %arg1: tensor<8x2x16x16xf16>, %arg2: tensor<2x2x16x16xf32>) -> tensor<2x2x16x16xf32> {
+    %empty = tensor.empty() : tensor<2x8x16x16xf16>
+    %lhs_transpose = linalg.transpose ins(%arg0: tensor<8x2x16x16xf16>) outs(%empty: tensor<2x8x16x16xf16>) permutation = [1, 0, 2, 3]
+    %mm = linalg.generic {
+        indexing_maps = [#map, #map1, #map2],
+        iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+        ins(%lhs_transpose, %arg1 : tensor<2x8x16x16xf16>, tensor<8x2x16x16xf16>)
+        outs(%arg2 : tensor<2x2x16x16xf32>)
+        attrs = {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>}>} {
+    ^bb0(%in: f16, %in_2: f16, %out: f32):
+      %4 = arith.extf %in : f16 to f32
+      %5 = arith.extf %in_2 : f16 to f32
+      %6 = arith.mulf %4, %5 : f32
+      %7 = arith.addf %out, %6 : f32
+      linalg.yield %7 : f32
+    } -> tensor<2x2x16x16xf32>
+    return %mm : tensor<2x2x16x16xf32>
+  }
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @matmul_16x16x16
+// CHECK: scf.forall
+// CHECK: %[[LHS_T:.+]] = linalg.transpose ins({{.*}}: tensor<2x8x1x4xf16>)
+// CHECK: iree_gpu.multi_mma %[[LHS_T]]
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+// CHECK-SAME: : tensor<2x8x1x4xf16>, tensor<8x2x1x4xf16> into tensor<2x2x4x1xf32>
+// CHECK: mapping = [#iree_gpu.lane_id<0>]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
index c22dd406c75d..2ee4d2c00a28 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
@@ -200,3 +200,49 @@
 // CHECK:       scf.forall.in_parallel
 // CHECK-NEXT:    tensor.parallel_insert_slice %[[LOOP]]
 // CHECK:     flow.dispatch.tensor.store %[[OUTER_PARALLEL]]
+
+// -----
+
+module {
+  func.func @multi_hoist_and_fuse_trailing_stuff() {
+    %c4 = arith.constant 4 : index
+    %c128 = arith.constant 128 : index
+    %c0 = arith.constant 0 : index
+    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
+    %1 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<128x128xf16>>
+    %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
+    %empty = tensor.empty() : tensor<128x128xf16>
+    %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %empty) -> (tensor<128x128xf16>) {
+      %9 = scf.forall (%arg2, %arg3) in (2, 2) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf16>) {
+        %extracted_slice = tensor.extract_slice %arg4[%arg2, %arg3] [64, 64] [1, 1] : tensor<128x128xf16> to tensor<64x64xf16>
+        %10 = scf.forall (%arg5, %arg6) in (32, 16) shared_outs(%arg7 = %extracted_slice) -> (tensor<64x64xf16>) {
+          %extracted_slice_1 = tensor.extract_slice %2[%arg5, %arg6] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
+          %extracted_slice_2 = tensor.extract_slice %arg7[%arg5, %arg6] [2, 4] [1, 1] : tensor<64x64xf16> to tensor<2x4xf16>
+          %16 = linalg.copy ins(%extracted_slice_1 : tensor<2x4xf16>) outs(%extracted_slice_2 : tensor<2x4xf16>) -> tensor<2x4xf16>
+          scf.forall.in_parallel {
+            tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<64x64xf16>
+          }
+        } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+        scf.forall.in_parallel {
+          tensor.parallel_insert_slice %10 into %arg4[%arg2, %arg3] [64, 64] [1, 1] : tensor<64x64xf16> into tensor<128x128xf16>
+        }
+      } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
+      scf.yield %9 : tensor<128x128xf16>
+    }
+    %transpose = linalg.transpose ins(%8: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) permutation = [1, 0]
+    %ceil = linalg.ceil ins(%transpose: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) -> tensor<128x128xf16>
+    flow.dispatch.tensor.store %ceil, %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<128x128xf16>>
+    return
+  }
+}
+
+// CHECK-LABEL: func @multi_hoist_and_fuse_trailing_stuff
+// CHECK: scf.forall
+// CHECK: scf.forall
+// CHECK: %[[LOOP:.+]] = scf.for {{.*}} -> (tensor<2x4xf16>)
+// CHECK: linalg.copy
+// CHECK: %[[T:.+]] = linalg.transpose ins(%[[LOOP]] : tensor<2x4xf16>)
+// CHECK: linalg.ceil ins(%[[T]] : tensor<4x2xf16>) {{.*}} -> tensor<4x2xf16>
+// CHECK: scf.forall.in_parallel
+// CHECK: scf.forall.in_parallel
+// CHECK: flow.dispatch.tensor.store
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/pack_to_intrinsics.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/pack_to_intrinsics.mlir
new file mode 100644
index 000000000000..1aaae4eac18c
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/pack_to_intrinsics.mlir
@@ -0,0 +1,57 @@
+// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-pack-to-intrinsics, canonicalize, cse))' --split-input-file | FileCheck %s
+
+#config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>}>
+module {
+  func.func @matmul_32x32x8(%a: tensor<64x64xf16>, %b: tensor<64x64xf16>, %c: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %mm = linalg.matmul {lowering_config = #config} ins(%a, %b : tensor<64x64xf16>, tensor<64x64xf16>) outs(%c : tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %mm : tensor<64x64xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @matmul_32x32x8
+// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<64x64xf16>
+// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<64x64xf16>
+// CHECK-SAME: %[[C:[A-Za-z0-9]+]]: tensor<64x64xf32>
+// CHECK-DAG: %[[A_PACK:.+]] = tensor.pack %[[A]] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
+// CHECK-DAG: %[[B_PACK:.+]] = tensor.pack %[[B]] inner_dims_pos = [1, 0] inner_tiles = [32, 8]
+// CHECK-DAG: %[[C_PACK:.+]] = tensor.pack %[[C]] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+// CHECK: %[[PACKED_MM:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[A_PACK]], %[[B_PACK]] : tensor<2x8x32x8xf16>, tensor<8x2x32x8xf16>)
+// CHECK-SAME: outs(%[[C_PACK]] : tensor<2x2x32x32xf32>)
+// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>}>
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d0, d3, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+module {
+  func.func @matmul_16x16x16(%a: tensor<?x?x?xf16>, %b: tensor<?x?x?x?xf16>, %c: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+    %mm = linalg.generic {
+        indexing_maps = [#map, #map1, #map2],
+        iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
+      } ins(%a, %b : tensor<?x?x?xf16>, tensor<?x?x?x?xf16>)
+        outs(%c : tensor<?x?x?xf32>) attrs = {
+          lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>}>
+        } {
+    ^bb0(%in: f16, %in_2: f16, %out: f32):
+      %4 = arith.extf %in : f16 to f32
+      %5 = arith.extf %in_2 : f16 to f32
+      %6 = arith.mulf %4, %5 : f32
+      %7 = arith.addf %out, %6 : f32
+      linalg.yield %7 : f32
+    } -> tensor<?x?x?xf32>
+    return %mm : tensor<?x?x?xf32>
+  }
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d3, d4, d5, d7)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3, d4, d6, d7)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d6)>
+
+// CHECK-LABEL: func.func @matmul_16x16x16
+// CHECK: %[[PACKED_MM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: ins({{.*}} : tensor<?x?x?x16x16xf16>, tensor<?x?x?x?x16x16xf16>)
+// CHECK-SAME: outs({{.*}} : tensor<?x?x?x16x16xf32>)
+// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>}>
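[Reviewer note, not part of the patch: the two lit tests above mirror the
pipeline ordering introduced below; packing and lane distribution can also be
exercised together, e.g. with
--pass-pipeline='builtin.module(func.func(iree-gpu-pack-to-intrinsics, iree-gpu-distribute-mma-to-lanes))'
on a lowering_config-annotated matmul, following the RUN-line convention of
these tests.]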
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 99c9c52cdff2..a417e310a909 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -293,13 +293,14 @@ void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) {
 //===---------------------------------------------------------------------===//
 
 void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager) {
-  tileAndDistributeToWorkgroup(funcPassManager);
-  funcPassManager.addPass(createCanonicalizerPass());
-  funcPassManager.addPass(createCSEPass());
+  // Step 1. Promote matmul operands and pack to intrinsic shapes.
   funcPassManager.addPass(createGPUPromoteMatmulOperandsPass());
+  funcPassManager.addPass(IREE::GPU::createPackToIntrinsicsPass());
+
+  tileAndDistributeToWorkgroup(funcPassManager);
 
-  // Step 1. Tile and fuse tileable ops to reduction loops.
+  // Step 2. Tile and fuse tileable ops to reduction loops.
   {
     GPUApplyTilingLevelPassOptions options;
     options.tilingLevel = IREE::GPU::TilingLevel::Reduction;
@@ -308,7 +309,16 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager) {
     funcPassManager.addPass(createCSEPass());
   }
 
-  // Step 2. Tile and fuse tileable ops to threads.
+  // Decompose pack and unpack ops and propagate the resulting reshapes.
+  funcPassManager.addPass(
+      createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/false));
+  funcPassManager.addPass(createPropagateReshapesByExpansionPass());
+  funcPassManager.addPass(createCanonicalizerPass());
+  funcPassManager.addPass(createCSEPass());
+  funcPassManager.addPass(createConvertToDestinationPassingStylePass(
+      /*useWARForCooperativeMatrixCodegen=*/false));
+
+  // Step 3. Tile and fuse tileable ops to subgroups/threads.
   {
     GPUApplyTilingLevelPassOptions options;
     options.tilingLevel = IREE::GPU::TilingLevel::Thread;
@@ -316,6 +326,12 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager) {
     funcPassManager.addPass(createCanonicalizerPass());
     funcPassManager.addPass(createCSEPass());
   }
+  {
+    GPUApplyTilingLevelPassOptions options;
+    options.tilingLevel = IREE::GPU::TilingLevel::Subgroup;
+    funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
+  }
+  funcPassManager.addPass(IREE::GPU::createDistributeMmaToLanesPass());
 
   // Normalize loop bounds for later lowerings.
   funcPassManager.addPass(iree_compiler::createNormalizeLoopBoundsPass(
@@ -324,30 +340,30 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager) {
   funcPassManager.addPass(createCSEPass());
   funcPassManager.addPass(createLoopInvariantCodeMotionPass());
 
-  // Step 3. Greedily fuse parallel loops and hoist from serial loops.
-  // TODO: Tileable consumer fusion needs to happen here as well.
+  // Step 4. Greedily fuse parallel loops and hoist from serial loops.
   funcPassManager.addPass(IREE::GPU::createFuseAndHoistParallelLoopsPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
   funcPassManager.addPass(createLoopInvariantCodeMotionPass());
 
-  // Step 4. Lower special ops and vectorize.
+  // Step 5. Lower special ops and vectorize.
   funcPassManager.addPass(IREE::GPU::createVectorizeIREEGPUOpsPass());
   addGPUVectorizationPasses(funcPassManager);
+  funcPassManager.addPass(createCleanupBufferAllocViewPass());
 
-  // Step 5. Bufferize.
+  // Step 6. Bufferize.
   // TODO: This is a workaround for a bug in the lowering of
   // `iree_gpu.shuffle_tensor` which does not properly represent the concurrent
   // nature of the write to the intermediate tensor.
   addBufferizePasses(funcPassManager, /*allowPrivateAllocations=*/false);
 
-  // Step 6. Resolve remaining parallel loops.
+  // Step 7. Resolve remaining parallel loops.
   funcPassManager.addPass(createGPUDistributePass());
 
   // Vectorize copies that came out of vectorization.
   funcPassManager.addPass(createVectorizeMemrefCopyPass());
 
-  // Step 7. Remaining post-bufferization optimizations/lowerings.
+  // Step 8. Remaining post-bufferization optimizations/lowerings.
   funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass());
   funcPassManager.addPass(createLoopInvariantCodeMotionPass());
   funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 6178a9054a8a..95463a872aa2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -62,3 +62,67 @@ hal.executable public @main {
 // CHECK:         %[[MM:.+]] = vector.contract {{.*}} %[[LHS_MM]], %[[RHS_MM]]
 // CHECK:         scf.yield %[[MM]]
 // CHECK:       vector.transfer_write %[[LOOP]], %[[B2]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 2], subgroup = [2, 2], mma_kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>}>
+hal.executable public @main {
+  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @matmul_transpose_b_mfma ordinal(0) layout(#pipeline_layout)
+        attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
+    ^bb0(%arg0: !hal.device):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @matmul_transpose_b_mfma()
+          attributes {translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64>} {
+        %cst = arith.constant 0.000000e+00 : f16
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2048x1280xf16>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<10240x1280xf16>>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x1280xf16>> -> tensor<2048x1280xf16>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<10240x1280xf16>> -> tensor<10240x1280xf16>
+        %5 = tensor.empty() : tensor<2048x10240xf32>
+        %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
+        %7 = linalg.matmul_transpose_b {lowering_config = #config}
+            ins(%3, %4 : tensor<2048x1280xf16>, tensor<10240x1280xf16>)
+            outs(%6 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1] : tensor<2048x10240xf32> -> !flow.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
+        return
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: func @matmul_transpose_b_mfma
+// CHECK-DAG: %[[B0:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[B1:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-DAG: memref.alloc() : memref<64x32xf16, #gpu.address_space<workgroup>>
+// CHECK-DAG: memref.alloc() : memref<64x32xf16, #gpu.address_space<workgroup>>
+// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x4x1xf32>)
+// CHECK: gpu.barrier
+// CHECK: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
+// CHECK: vector.transfer_write %[[LHS_RD]]
+// CHECK: gpu.barrier
+// CHECK: %[[LHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
+// CHECK: gpu.barrier
+// CHECK: %[[LHS_T:.+]] = vector.transpose %[[LHS_MM]], [0, 2, 1, 3] : vector<2x1x2x4xf16>
+// CHECK: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
+// CHECK: vector.transfer_write %[[RHS_RD]]
+// CHECK: gpu.barrier
+// CHECK: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
+// CHECK: gpu.barrier
+// CHECK: %[[RHS_T:.+]] = vector.transpose %[[RHS_MM]], [0, 2, 1, 3] : vector<2x1x2x4xf16>
+// CHECK: %[[MM:.+]] = iree_gpu.multi_mma %[[LHS_T]], %[[RHS_T]]
+// CHECK: scf.yield %[[MM]]
+// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 1, 3] : vector<2x2x4x1xf32> to vector<2x4x2x1xf32>
+// CHECK: vector.transfer_write %[[LOOP_T]], %[[B2]]
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Transforms/BUILD.bazel
index 581f6e17d06e..069b5d7e3078 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Transforms/BUILD.bazel
@@ -33,6 +33,7 @@ iree_compiler_cc_library(
         "@llvm-project//mlir:AffineUtils",
        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FunctionInterfaces",
         "@llvm-project//mlir:IR",
@@ -42,6 +43,7 @@ iree_compiler_cc_library(
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFTransforms",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TransformUtils",
         "@llvm-project//mlir:Transforms",
         "@llvm-project//mlir:ValueBoundsOpInterface",
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Transforms/CMakeLists.txt
index 13a05317d80e..67f5f10a823b 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Transforms/CMakeLists.txt
@@ -35,6 +35,7 @@ iree_cc_library(
     MLIRPass
     MLIRSCFTransforms
     MLIRSupport
+    MLIRTensorDialect
     MLIRTransformUtils
     MLIRTransforms
     MLIRValueBoundsOpInterface
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index b7a0bdb43a68..fc15464f0f9b 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -13,6 +13,7 @@
 
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/Liveness.h"
 #include "mlir/Analysis/Presburger/IntegerRelation.h"
@@ -22,7 +23,12 @@
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 #define DEBUG_TYPE "iree-codegen-transforms"
@@ -551,16 +557,14 @@ struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern<TensorReshapeOp> {
   }
 };
 
-/// Folds tensor.expand/collapse_shape into the source
-/// hal.interface.binding.subspan.
+/// Folds tensor.expand_shape into the source hal.interface.binding.subspan.
 ///
 /// For example, this matches the following pattern:
 ///
 /// %subspan = hal.interface.binding.subspan ... :
 ///     !flow.dispatch.tensor<writeonly:tensor<3x3x1x96xf32>>
-/// %0 = linalg.tensor_reshape %tensor [
-///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-/// ] : tensor<864xf32> into tensor<3x3x1x96xf32>
+/// %0 = tensor.expand_shape %tensor [[0, 1, 2, 3]]
+///     : tensor<864xf32> into tensor<3x3x1x96xf32>
 /// %tensor = flow.dispatch.tensor.store %0, %subspan :
 ///     !flow.dispatch.tensor<writeonly:tensor<3x3x1x96xf32>> ->
 ///     tensor<3x3x1x96xf32>
@@ -571,7 +575,7 @@ struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern<TensorReshapeOp> {
 ///     !flow.dispatch.tensor<writeonly:tensor<864xf32>>
 /// %0 = flow.dispatch.tensor.store %tensor, %subspan :
 ///     !flow.dispatch.tensor<writeonly:tensor<864xf32>> -> tensor<864xf32>
-struct FoldReshapeIntoInterfaceTensorStore
+struct FoldExpandShapeIntoInterfaceTensorStore
     : OpRewritePattern<IREE::Flow::DispatchTensorStoreOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -583,17 +587,14 @@ struct FoldReshapeIntoInterfaceTensorStore
         !storeOp.strides().empty())
       return failure();
 
-    auto reshapeOp = storeOp.getValue().getDefiningOp();
-    if (!isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(reshapeOp))
+    auto reshapeOp = storeOp.getValue().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!reshapeOp) {
       return failure();
+    }
 
     // Dynamic shapes are currently unsupported.
     std::optional<Value> reshapeSrc =
-        isa<tensor::ExpandShapeOp>(reshapeOp)
-            ? getStaticReshapeOpSrc<tensor::ExpandShapeOp>(
-                  cast<tensor::ExpandShapeOp>(reshapeOp))
-            : getStaticReshapeOpSrc<tensor::CollapseShapeOp>(
-                  cast<tensor::CollapseShapeOp>(reshapeOp));
+        getStaticReshapeOpSrc(reshapeOp);
     if (!reshapeSrc)
       return failure();
 
@@ -627,13 +628,155 @@ struct FoldReshapeIntoInterfaceTensorStore
     return success();
   }
 };
+
+/// Folds tensor.collapse_shape with static shape into the source
+/// hal.interface.binding.subspan. The binding is currently required to be
+/// static as well, however it is impossible to generate a dispatch where
+/// this would not be true today.
+///
+/// For example, this matches the following pattern:
+///
+/// %subspan = hal.interface.binding.subspan ... :
+///     !flow.dispatch.tensor<writeonly:tensor<864xf32>>
+/// %0 = tensor.collapse_shape %tensor [[0, 1, 2, 3]]
+///     : tensor<3x3x1x96xf32> into tensor<864xf32>
+/// %tensor = flow.dispatch.tensor.store %0, %subspan,
+///     offsets = [%x], sizes = [864], strides = [1]
+///     : tensor<864xf32> -> !flow.dispatch.tensor<writeonly:tensor<864xf32>>
+///
+/// And turns it into:
+///
+/// %subspan = hal.interface.binding.subspan ... :
+///     !flow.dispatch.tensor<writeonly:tensor<3x3x1x96xf32>>
+/// %0 = flow.dispatch.tensor.store %tensor, %subspan :
+///     offsets = [%x ceildiv 288, 0, 0, 0], sizes = [3, 3, 1, 96]
+///     strides = [1, 1, 1, 1] : tensor<3x3x1x96xf32> ->
+///     !flow.dispatch.tensor<writeonly:tensor<3x3x1x96xf32>>
+struct FoldCollapseShapeIntoInterfaceTensorStore
+    : OpRewritePattern<IREE::Flow::DispatchTensorStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IREE::Flow::DispatchTensorStoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out if the strides aren't unit.
+    if (!llvm::all_of(storeOp.getMixedStrides(), [](OpFoldResult s) {
+          return isConstantIntValue(s, 1);
+        })) {
+      return failure();
+    }
+
+    auto collapseShape =
+        storeOp.getValue().getDefiningOp<tensor::CollapseShapeOp>();
+    // TODO: Support dynamic shapes.
+    if (!collapseShape || !collapseShape.getSrcType().hasStaticShape()) {
+      return failure();
+    }
+
+    auto subspanOp =
+        storeOp.getTarget()
+            .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+    // TODO: Support dynamic dims.
+    if (!subspanOp || !subspanOp.getDynamicDims().empty()) {
+      return failure();
+    }
+
+    auto subspanType =
+        llvm::cast<IREE::Flow::DispatchTensorType>(subspanOp.getType());
+
+    ArrayRef<int64_t> reshapeSrcShape = collapseShape.getSrcType().getShape();
+
+    // Verify the subspan shape against the shape of the slice being inserted.
+    for (auto [size, group] : llvm::zip_equal(
+             subspanType.getShape(), collapseShape.getReassociationIndices())) {
+      if (group.size() == 1) {
+        continue;
+      }
+
+      int64_t innerDimSize = 1;
+      for (auto i : llvm::drop_begin(group)) {
+        innerDimSize *= reshapeSrcShape[i];
+      }
+      if (size % innerDimSize != 0) {
+        return rewriter.notifyMatchFailure(
+            storeOp, "Subspan type indivisible by expanded shape");
+      }
+    }
+
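+    // Each collapsed offset maps to the outermost dimension of its group;
+    // e.g. with the [[0, 1, 2, 3]] collapse of tensor<3x3x1x96xf32> above,
+    // an offset of 576 becomes offsets [2, 0, 0, 0] (576 ceildiv 288 = 2).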
+    AffineExpr d0, d1;
+    bindDims(rewriter.getContext(), d0, d1);
+    AffineExpr div = d0.ceilDiv(d1);
+
+    Location loc = collapseShape.getLoc();
+    SmallVector<int64_t> expandedSubspanShape;
+    SmallVector<OpFoldResult> expandedOffsets;
+    SmallVector<OpFoldResult> expandedSizes;
+    OpFoldResult zero = rewriter.getIndexAttr(0);
+    for (auto [size, group, offset] : llvm::zip_equal(
+             subspanType.getShape(), collapseShape.getReassociationIndices(),
+             storeOp.getMixedOffsets())) {
+      expandedSizes.push_back(rewriter.getIndexAttr(reshapeSrcShape[group[0]]));
+
+      // Special case for 1 to avoid going through arith folders.
+      if (group.size() == 1) {
+        expandedOffsets.push_back(offset);
+        expandedSubspanShape.push_back(size);
+        continue;
+      }
+
+      int64_t innerDimSize = 1;
+      for (auto i : llvm::drop_begin(group)) {
+        innerDimSize *= reshapeSrcShape[i];
+      }
+      OpFoldResult innerDimSizeAttr = rewriter.getIndexAttr(innerDimSize);
+      expandedOffsets.push_back(affine::makeComposedFoldedAffineApply(
+          rewriter, loc, div, {offset, innerDimSizeAttr}));
+      assert(size % innerDimSize == 0);
+      expandedSubspanShape.push_back(size / innerDimSize);
+      for (auto i : llvm::drop_begin(group)) {
+        expandedOffsets.push_back(zero);
+        int64_t dimSize = reshapeSrcShape[i];
+        expandedSubspanShape.push_back(dimSize);
+        expandedSizes.push_back(rewriter.getIndexAttr(dimSize));
+      }
+    }
+
+    auto newSubspanTensorType = RankedTensorType::get(
+        expandedSubspanShape, collapseShape.getSrcType().getElementType());
+    auto newSubspanType = IREE::Flow::DispatchTensorType::get(
+        subspanType.getAccess(), newSubspanTensorType);
+
+    Value newSubspanOp;
+    {
+      // NOTE: If there were any dynamic dims, they would need to be updated
+      // based on the newly introduced static sizes as well.
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointAfter(subspanOp);
+      newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
+          subspanOp.getLoc(), newSubspanType, subspanOp.getSet(),
+          subspanOp.getBinding(), subspanOp.getDescriptorType(),
+          subspanOp.getByteOffset(), subspanOp.getDynamicDims(),
+          subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr());
+    }
+
+    SmallVector<OpFoldResult> expandedStrides(reshapeSrcShape.size(),
+                                              rewriter.getIndexAttr(1));
+    rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
+        storeOp, collapseShape.getSrc(), newSubspanOp, storeOp.getTargetDims(),
+        expandedOffsets, expandedSizes, expandedStrides);
+    return success();
+  }
+};
+
 } // namespace
 
 void populateReshapeToInterfaceTensorPatterns(RewritePatternSet &patterns) {
   patterns.insert<FoldReshapeIntoInterfaceTensorLoad<tensor::CollapseShapeOp>,
                   FoldReshapeIntoInterfaceTensorLoad<tensor::ExpandShapeOp>>(
       patterns.getContext());
-  patterns.insert<FoldReshapeIntoInterfaceTensorStore>(patterns.getContext());
+  patterns.insert<FoldExpandShapeIntoInterfaceTensorStore>(
+      patterns.getContext());
+  patterns.insert<FoldCollapseShapeIntoInterfaceTensorStore>(
+      patterns.getContext());
 }
 
 //===--------------------------------------------------------------------====//
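[Reviewer note, illustrative and not part of the patch: a minimal sketch of
driving the reshape-folding patterns above, assuming the populate function is
declared in Codegen/Transforms/Transforms.h as this file suggests:

  #include "iree/compiler/Codegen/Transforms/Transforms.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  static mlir::LogicalResult foldReshapesIntoBindings(mlir::Operation *funcOp) {
    mlir::RewritePatternSet patterns(funcOp->getContext());
    mlir::iree_compiler::populateReshapeToInterfaceTensorPatterns(patterns);
    // Greedy application suffices here: each successful fold strictly reduces
    // the number of reshapes between stores/loads and their bindings.
    return mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  }
]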