From 6c75aa1083d6f9a1fa7f2b1ddd032decc9e87aa7 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 27 May 2024 17:58:06 -0400 Subject: [PATCH] [Codegen][GPU] Allow iree_gpu.tensor_barrier to take vectors (#17479) This allows synchronizing on vectors as well as tensors with similar semantics. In a typical lowering flow, this will represent the read equivalent to a tensor barrier, in that a tensor barrier represents a wait until all writes to a shared allocation has finished, while this represents a wait until all threads have read the value they need from that shared allocation. Renames the operation to iree_gpu.value_barrier for clarity. --- .../test/iree_comprehensive_bufferize.mlir | 24 +++++++++++++++-- .../Codegen/Dialect/GPU/IR/IREEGPUOps.cpp | 8 ------ .../Codegen/Dialect/GPU/IR/IREEGPUOps.td | 25 ++++++++++-------- .../Dialect/GPU/IR/test/iree_gpu_ops.mlir | 15 +++++++++-- .../TransformExtensions/IREEGPUExtensions.cpp | 9 +++++++ .../IREEGPUExtensionsOps.td | 13 ++++++++++ .../GPU/TransformExtensions/test/BUILD.bazel | 1 + .../TransformExtensions/test/CMakeLists.txt | 1 + .../test/lower_vector_barrier.mlir | 21 +++++++++++++++ .../Transforms/BufferizationInterfaces.cpp | 18 ++++++++----- .../Dialect/GPU/Transforms/Transforms.cpp | 26 +++++++++++++++++-- .../Dialect/GPU/Transforms/Transforms.h | 2 ++ 12 files changed, 132 insertions(+), 31 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_vector_barrier.mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir index b6dde8ffc982..911a44df4f7f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir @@ -2580,7 +2580,7 @@ func.func @tensor_barrier() -> vector<2xf32> { %c0 = arith.constant 0 : index %alloc = bufferization.alloc_tensor() : tensor<2xf32> %tmp = vector.transfer_write %cst, %alloc[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32> - %barrier = iree_gpu.tensor_barrier %tmp : tensor<2xf32> + %barrier = iree_gpu.value_barrier %tmp : tensor<2xf32> %res = vector.transfer_read %barrier[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32> return %res : vector<2xf32> } @@ -2601,7 +2601,7 @@ func.func @tensor_barrier_in_loop() -> vector<2xf32> { %alloc = bufferization.alloc_tensor() : tensor<2xf32> %loop = scf.for %arg0 = %c0 to %c10 step %c1 iter_args(%init = %alloc) -> tensor<2xf32> { %tmp = vector.transfer_write %cst, %init[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32> - %barrier = iree_gpu.tensor_barrier %tmp : tensor<2xf32> + %barrier = iree_gpu.value_barrier %tmp : tensor<2xf32> scf.yield %barrier : tensor<2xf32> } %res = vector.transfer_read %loop[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32> @@ -2614,3 +2614,23 @@ func.func @tensor_barrier_in_loop() -> vector<2xf32> { // CHECK-NEXT: gpu.barrier // CHECK-NEXT: } // CHECK: vector.transfer_read %[[ALLOC]] + +// ----- + +func.func @vector_barrier() -> vector<2xf32> { + %cst = arith.constant dense<0.0> : vector<2xf32> + %cst0 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %alloc = bufferization.alloc_tensor() : tensor<2xf32> + %tmp = vector.transfer_write %cst, %alloc[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32> + %read = vector.transfer_read %tmp[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32> + %barrier = iree_gpu.value_barrier %read : vector<2xf32> + return %barrier : vector<2xf32> +} + +// Verify that the dual-modes of `value_barrier` are adhered to. +// CHECK-LABEL: func @vector_barrier() +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<2xf32> +// CHECK: vector.transfer_write %{{.*}}, %[[ALLOC]] +// CHECK-NEXT: %[[RD:.+]] = vector.transfer_read %[[ALLOC]] +// CHECK-NEXT: iree_gpu.value_barrier %[[RD]] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp index 3950cccaf58e..2e106e120706 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp @@ -237,12 +237,4 @@ LogicalResult ShuffleTensorOp::verifyRegions() { return success(); } -//===----------------------------------------------------------------------===// -// TensorBarrierOp -//===----------------------------------------------------------------------===// - -MutableOperandRange TensorBarrierOp::getDpsInitsMutable() { - return getInputMutable(); -} - } // namespace mlir::iree_compiler::IREE::GPU diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td index 4c360a78f649..6b18e5cb4e6e 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td @@ -352,36 +352,39 @@ def IREEGPU_ShuffleTensorOp : Op, AllTypesMatch<["input", "result"]>, ]> { let summary = "Shuffles a private tensor across a shared allocation"; let description = [{ - This operation acts as a barrier on a tensor value. It takes a single - tensor operand and produces an equivalent tensor. This does not have copy - and/or data movement semantics and simply represents a barrier on all writes - to the input tensor. + This operation acts as a barrier on a value semantic SSA value (tensor or + vector). It takes a single operand and produces a value equivalent to the + input. This does not have copy and/or data movement semantics and simply + represents a barrier on all writes in the tensor case, and a barrier until + all threads acquire the input vector in the vector case. This operation is a no-op when not present in a parallel context. This operation is pure as it only requires synchronization for the value it produces. }]; - let arguments = (ins AnyRankedTensor:$input); - let results = (outs AnyRankedTensor:$result); + let arguments = (ins AnyRankedTensorOrVector:$input); + let results = (outs AnyRankedTensorOrVector:$result); let assemblyFormat = [{ $input attr-dict `:` type($result) }]; let extraClassDeclaration = [{ - RankedTensorType getInputType() { - return getInput().getType(); + bool hasTensorSemantics() { + return isa<::mlir::RankedTensorType>(getInput().getType()); + } + ::mlir::ShapedType getInputType() { + return ::llvm::cast<::mlir::ShapedType>(getInput().getType()); } }]; } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir index 08b2f17eece1..e2744fb530a0 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir @@ -134,10 +134,21 @@ func.func @single_multi_mma(%lhs: vector<4xf16>, %rhs: vector<4xf16>, %acc: vect // ----- func.func @tensor_barrier(%input: tensor) -> tensor { - %out = iree_gpu.tensor_barrier %input : tensor + %out = iree_gpu.value_barrier %input : tensor return %out : tensor } // CHECK-LABEL: func @tensor_barrier // CHECK-SAME: %[[INPUT:[A-Za-z0-9]+]]: tensor -// CHECK: iree_gpu.tensor_barrier %[[INPUT]] : tensor +// CHECK: iree_gpu.value_barrier %[[INPUT]] : tensor + +// ----- + +func.func @vector_barrier(%input: vector<8xf16>) -> vector<8xf16> { + %out = iree_gpu.value_barrier %input : vector<8xf16> + return %out : vector<8xf16> +} + +// CHECK-LABEL: func @vector_barrier +// CHECK-SAME: %[[INPUT:[A-Za-z0-9]+]]: vector<8xf16> +// CHECK: iree_gpu.value_barrier %[[INPUT]] : vector<8xf16> diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp index f8cb914b41ea..8dcabd19d9d0 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp @@ -25,6 +25,15 @@ transform_dialect::IREEGPUExtensions::IREEGPUExtensions() { >(); } +//===---------------------------------------------------------------------===// +// ApplyLowerValueBarrierOp +//===---------------------------------------------------------------------===// + +void transform_dialect::ApplyLowerValueBarrierOp::populatePatterns( + RewritePatternSet &patterns) { + IREE::GPU::populateIREEGPULowerValueBarrierPatterns(patterns); +} + //===---------------------------------------------------------------------===// // ApplyUnrollMultiMmaOp //===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td index 08e13ecbc79f..dc69083136ca 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td @@ -14,6 +14,19 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" +def ApplyLowerValueBarrierOp : Op, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Populate patterns to convert value barriers on vectors into gpu.barrier ops. + Barriers on tensors are ignored. + }]; + + let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect"; + let assemblyFormat = "attr-dict"; +} + def ApplyUnrollMultiMmaOp : Op, diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel index 11c0c8ee9756..f78f6d01f6de 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel @@ -18,6 +18,7 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "lower_vector_barrier.mlir", "transform_fuse_forall.mlir", "vectorize_multi_mma.mlir", "unroll_multi_mma.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt index 686ce93a655a..f3e2e404284f 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt @@ -14,6 +14,7 @@ iree_lit_test_suite( NAME lit SRCS + "lower_vector_barrier.mlir" "transform_fuse_forall.mlir" "unroll_multi_mma.mlir" "vectorize_multi_mma.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_vector_barrier.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_vector_barrier.mlir new file mode 100644 index 000000000000..b8391a3682da --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_vector_barrier.mlir @@ -0,0 +1,21 @@ +// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s + +func.func @lower_value_barrier(%input: vector<4xf32>) -> vector<4xf32> { + %0 = iree_gpu.value_barrier %input : vector<4xf32> + return %0 : vector<4xf32> +} + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.iree.lower_value_barrier + } : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @lower_value_barrier +// CHECK-SAME: %[[INPUT:[A-Za-z0-9]+]]: vector<4xf32> +// CHECK-NEXT: gpu.barrier +// CHECK-NEXT: return %[[INPUT]] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp index 1071d913c083..4dc7621a0b05 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp @@ -24,9 +24,9 @@ namespace { /// Bufferization of iree_gpu.tensor_barrier. Always just bufferizes in place /// and replaces with a barrier. -struct TensorBarrierOpBufferizationInterface +struct ValueBarrierOpBufferizationInterface : public BufferizableOpInterface::ExternalModel< - TensorBarrierOpBufferizationInterface, IREE::GPU::TensorBarrierOp> { + ValueBarrierOpBufferizationInterface, IREE::GPU::ValueBarrierOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { // This op never needs to bufferize to a copy. @@ -47,8 +47,11 @@ struct TensorBarrierOpBufferizationInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { - auto barrierOp = cast(op); + auto barrierOp = cast(op); assert(value == barrierOp.getResult() && "invalid value"); + if (!barrierOp.hasTensorSemantics()) { + return failure(); + } auto srcMemrefType = bufferization::getBufferType(barrierOp.getInput(), options, invocationStack); if (failed(srcMemrefType)) @@ -58,7 +61,10 @@ struct TensorBarrierOpBufferizationInterface LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { - auto barrierOp = cast(op); + auto barrierOp = cast(op); + if (!barrierOp.hasTensorSemantics()) { + return failure(); + } FailureOr buffer = getBuffer(rewriter, barrierOp.getInput(), options); if (failed(buffer)) { @@ -78,8 +84,8 @@ struct TensorBarrierOpBufferizationInterface void registerIREEGPUBufferizationInterfaces(DialectRegistry ®istry) { registry.addExtension( +[](MLIRContext *context, IREE::GPU::IREEGPUDialect *dialect) { - IREE::GPU::TensorBarrierOp::attachInterface< - TensorBarrierOpBufferizationInterface>(*context); + IREE::GPU::ValueBarrierOp::attachInterface< + ValueBarrierOpBufferizationInterface>(*context); }); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp index a01ea4512b8b..5e982a8b01f8 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp @@ -11,8 +11,6 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" @@ -406,4 +404,28 @@ void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } +//===----------------------------------------------------------------------===// +// VectorBarrierOp Lowering +//===----------------------------------------------------------------------===// + +namespace { +struct LowerValueBarrierPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(IREE::GPU::ValueBarrierOp barrier, + PatternRewriter &rewriter) const override { + if (barrier.hasTensorSemantics()) { + return failure(); + } + rewriter.create(barrier.getLoc()); + rewriter.replaceOp(barrier, barrier.getInput()); + return success(); + } +}; +} // namespace + +void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + } // namespace mlir::iree_compiler::IREE::GPU diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h index 56d158954095..0e2afa301baa 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h @@ -39,6 +39,8 @@ void populateIREEGPUVectorUnrollPatterns( void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns); +void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns); + } // namespace mlir::iree_compiler::IREE::GPU #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H_