Skip to content

Commit

Permalink
[Codegen][GPU] Allow iree_gpu.tensor_barrier to take vectors (iree-or…
Browse files Browse the repository at this point in the history
…g#17479)

This allows synchronizing on vectors as well as tensors with similar
semantics. In a typical lowering flow, this will represent the
read equivalent to a tensor barrier, in that a tensor barrier represents
a wait until all writes to a shared allocation has finished, while this
represents a wait until all threads have read the value they need from
that shared allocation.

Renames the operation to iree_gpu.value_barrier for clarity.
  • Loading branch information
qedawkins authored May 27, 2024
1 parent 1750e2b commit 6c75aa1
Show file tree
Hide file tree
Showing 12 changed files with 132 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2580,7 +2580,7 @@ func.func @tensor_barrier() -> vector<2xf32> {
%c0 = arith.constant 0 : index
%alloc = bufferization.alloc_tensor() : tensor<2xf32>
%tmp = vector.transfer_write %cst, %alloc[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32>
%barrier = iree_gpu.tensor_barrier %tmp : tensor<2xf32>
%barrier = iree_gpu.value_barrier %tmp : tensor<2xf32>
%res = vector.transfer_read %barrier[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32>
return %res : vector<2xf32>
}
Expand All @@ -2601,7 +2601,7 @@ func.func @tensor_barrier_in_loop() -> vector<2xf32> {
%alloc = bufferization.alloc_tensor() : tensor<2xf32>
%loop = scf.for %arg0 = %c0 to %c10 step %c1 iter_args(%init = %alloc) -> tensor<2xf32> {
%tmp = vector.transfer_write %cst, %init[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32>
%barrier = iree_gpu.tensor_barrier %tmp : tensor<2xf32>
%barrier = iree_gpu.value_barrier %tmp : tensor<2xf32>
scf.yield %barrier : tensor<2xf32>
}
%res = vector.transfer_read %loop[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32>
Expand All @@ -2614,3 +2614,23 @@ func.func @tensor_barrier_in_loop() -> vector<2xf32> {
// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: }
// CHECK: vector.transfer_read %[[ALLOC]]

// -----

func.func @vector_barrier() -> vector<2xf32> {
%cst = arith.constant dense<0.0> : vector<2xf32>
%cst0 = arith.constant 0.0 : f32
%c0 = arith.constant 0 : index
%alloc = bufferization.alloc_tensor() : tensor<2xf32>
%tmp = vector.transfer_write %cst, %alloc[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32>
%read = vector.transfer_read %tmp[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32>
%barrier = iree_gpu.value_barrier %read : vector<2xf32>
return %barrier : vector<2xf32>
}

// Verify that the dual-modes of `value_barrier` are adhered to.
// CHECK-LABEL: func @vector_barrier()
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<2xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[ALLOC]]
// CHECK-NEXT: %[[RD:.+]] = vector.transfer_read %[[ALLOC]]
// CHECK-NEXT: iree_gpu.value_barrier %[[RD]]
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,4 @@ LogicalResult ShuffleTensorOp::verifyRegions() {
return success();
}

//===----------------------------------------------------------------------===//
// TensorBarrierOp
//===----------------------------------------------------------------------===//

MutableOperandRange TensorBarrierOp::getDpsInitsMutable() {
return getInputMutable();
}

} // namespace mlir::iree_compiler::IREE::GPU
25 changes: 14 additions & 11 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -352,36 +352,39 @@ def IREEGPU_ShuffleTensorOp : Op<IREEGPU_Dialect, "shuffle_tensor", [
}

//===----------------------------------------------------------------------===//
// TensorBarrierOp
// ValueBarrierOp
//===----------------------------------------------------------------------===//

def IREEGPU_TensorBarrierOp : Op<IREEGPU_Dialect, "tensor_barrier", [
def IREEGPU_ValueBarrierOp : Op<IREEGPU_Dialect, "value_barrier", [
Pure,
DeclareOpInterfaceMethods<DestinationStyleOpInterface>,
AllTypesMatch<["input", "result"]>,
]> {
let summary = "Shuffles a private tensor across a shared allocation";
let description = [{
This operation acts as a barrier on a tensor value. It takes a single
tensor operand and produces an equivalent tensor. This does not have copy
and/or data movement semantics and simply represents a barrier on all writes
to the input tensor.
This operation acts as a barrier on a value semantic SSA value (tensor or
vector). It takes a single operand and produces a value equivalent to the
input. This does not have copy and/or data movement semantics and simply
represents a barrier on all writes in the tensor case, and a barrier until
all threads acquire the input vector in the vector case.

This operation is a no-op when not present in a parallel context. This
operation is pure as it only requires synchronization for the value it
produces.
}];

let arguments = (ins AnyRankedTensor:$input);
let results = (outs AnyRankedTensor:$result);
let arguments = (ins AnyRankedTensorOrVector:$input);
let results = (outs AnyRankedTensorOrVector:$result);

let assemblyFormat = [{
$input attr-dict `:` type($result)
}];

let extraClassDeclaration = [{
RankedTensorType getInputType() {
return getInput().getType();
bool hasTensorSemantics() {
return isa<::mlir::RankedTensorType>(getInput().getType());
}
::mlir::ShapedType getInputType() {
return ::llvm::cast<::mlir::ShapedType>(getInput().getType());
}
}];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,21 @@ func.func @single_multi_mma(%lhs: vector<4xf16>, %rhs: vector<4xf16>, %acc: vect
// -----

func.func @tensor_barrier(%input: tensor<?xf16>) -> tensor<?xf16> {
%out = iree_gpu.tensor_barrier %input : tensor<?xf16>
%out = iree_gpu.value_barrier %input : tensor<?xf16>
return %out : tensor<?xf16>
}

// CHECK-LABEL: func @tensor_barrier
// CHECK-SAME: %[[INPUT:[A-Za-z0-9]+]]: tensor<?xf16>
// CHECK: iree_gpu.tensor_barrier %[[INPUT]] : tensor<?xf16>
// CHECK: iree_gpu.value_barrier %[[INPUT]] : tensor<?xf16>

// -----

func.func @vector_barrier(%input: vector<8xf16>) -> vector<8xf16> {
%out = iree_gpu.value_barrier %input : vector<8xf16>
return %out : vector<8xf16>
}

// CHECK-LABEL: func @vector_barrier
// CHECK-SAME: %[[INPUT:[A-Za-z0-9]+]]: vector<8xf16>
// CHECK: iree_gpu.value_barrier %[[INPUT]] : vector<8xf16>
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ transform_dialect::IREEGPUExtensions::IREEGPUExtensions() {
>();
}

//===---------------------------------------------------------------------===//
// ApplyLowerValueBarrierOp
//===---------------------------------------------------------------------===//

void transform_dialect::ApplyLowerValueBarrierOp::populatePatterns(
RewritePatternSet &patterns) {
IREE::GPU::populateIREEGPULowerValueBarrierPatterns(patterns);
}

//===---------------------------------------------------------------------===//
// ApplyUnrollMultiMmaOp
//===---------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def ApplyLowerValueBarrierOp : Op<Transform_Dialect,
"apply_patterns.iree.lower_value_barrier",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populate patterns to convert value barriers on vectors into gpu.barrier ops.
Barriers on tensors are ignored.
}];

let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
let assemblyFormat = "attr-dict";
}

def ApplyUnrollMultiMmaOp : Op<Transform_Dialect,
"apply_patterns.iree.unroll_multi_mma",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ iree_lit_test_suite(
name = "lit",
srcs = enforce_glob(
[
"lower_vector_barrier.mlir",
"transform_fuse_forall.mlir",
"vectorize_multi_mma.mlir",
"unroll_multi_mma.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ iree_lit_test_suite(
NAME
lit
SRCS
"lower_vector_barrier.mlir"
"transform_fuse_forall.mlir"
"unroll_multi_mma.mlir"
"vectorize_multi_mma.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s

func.func @lower_value_barrier(%input: vector<4xf32>) -> vector<4xf32> {
%0 = iree_gpu.value_barrier %input : vector<4xf32>
return %0 : vector<4xf32>
}

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.iree.lower_value_barrier
} : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @lower_value_barrier
// CHECK-SAME: %[[INPUT:[A-Za-z0-9]+]]: vector<4xf32>
// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: return %[[INPUT]]
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ namespace {

/// Bufferization of iree_gpu.tensor_barrier. Always just bufferizes in place
/// and replaces with a barrier.
struct TensorBarrierOpBufferizationInterface
struct ValueBarrierOpBufferizationInterface
: public BufferizableOpInterface::ExternalModel<
TensorBarrierOpBufferizationInterface, IREE::GPU::TensorBarrierOp> {
ValueBarrierOpBufferizationInterface, IREE::GPU::ValueBarrierOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// This op never needs to bufferize to a copy.
Expand All @@ -47,8 +47,11 @@ struct TensorBarrierOpBufferizationInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto barrierOp = cast<IREE::GPU::TensorBarrierOp>(op);
auto barrierOp = cast<IREE::GPU::ValueBarrierOp>(op);
assert(value == barrierOp.getResult() && "invalid value");
if (!barrierOp.hasTensorSemantics()) {
return failure();
}
auto srcMemrefType = bufferization::getBufferType(barrierOp.getInput(),
options, invocationStack);
if (failed(srcMemrefType))
Expand All @@ -58,7 +61,10 @@ struct TensorBarrierOpBufferizationInterface

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto barrierOp = cast<IREE::GPU::TensorBarrierOp>(op);
auto barrierOp = cast<IREE::GPU::ValueBarrierOp>(op);
if (!barrierOp.hasTensorSemantics()) {
return failure();
}
FailureOr<Value> buffer =
getBuffer(rewriter, barrierOp.getInput(), options);
if (failed(buffer)) {
Expand All @@ -78,8 +84,8 @@ struct TensorBarrierOpBufferizationInterface
void registerIREEGPUBufferizationInterfaces(DialectRegistry &registry) {
registry.addExtension(
+[](MLIRContext *context, IREE::GPU::IREEGPUDialect *dialect) {
IREE::GPU::TensorBarrierOp::attachInterface<
TensorBarrierOpBufferizationInterface>(*context);
IREE::GPU::ValueBarrierOp::attachInterface<
ValueBarrierOpBufferizationInterface>(*context);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
Expand Down Expand Up @@ -406,4 +404,28 @@ void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns) {
patterns.add<VectorizeStaticMultiMmaOpPattern>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// VectorBarrierOp Lowering
//===----------------------------------------------------------------------===//

namespace {
struct LowerValueBarrierPattern
: public OpRewritePattern<IREE::GPU::ValueBarrierOp> {
using OpRewritePattern<IREE::GPU::ValueBarrierOp>::OpRewritePattern;
LogicalResult matchAndRewrite(IREE::GPU::ValueBarrierOp barrier,
PatternRewriter &rewriter) const override {
if (barrier.hasTensorSemantics()) {
return failure();
}
rewriter.create<gpu::BarrierOp>(barrier.getLoc());
rewriter.replaceOp(barrier, barrier.getInput());
return success();
}
};
} // namespace

void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns) {
patterns.add<LowerValueBarrierPattern>(patterns.getContext());
}

} // namespace mlir::iree_compiler::IREE::GPU
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ void populateIREEGPUVectorUnrollPatterns(

void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns);

void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns);

} // namespace mlir::iree_compiler::IREE::GPU

#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H_

0 comments on commit 6c75aa1

Please sign in to comment.