Skip to content

Commit

Permalink
[Codegen][GPU] Add pattern to lower iree_gpu.multi_mma to intrinsics (i…
Browse files Browse the repository at this point in the history
…ree-org#17457)

This takes multi_mma ops that have already been unrolled to a single
intrinsic and simply lowers them to the intrinsic using the interfaced
attribute carried by the op. Inserts shape casts at the boundary to
match the expected type of the intrinsic.
  • Loading branch information
qedawkins authored May 28, 2024
1 parent ab8f668 commit 3d1364e
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ void transform_dialect::ApplyDropMultiMmaOpUnitDims::populatePatterns(
IREE::GPU::populateIREEGPUDropUnitDimsPatterns(patterns);
}

//===---------------------------------------------------------------------===//
// ApplyLowerMultiMmaOp
//===---------------------------------------------------------------------===//

void transform_dialect::ApplyLowerMultiMmaOp::populatePatterns(
RewritePatternSet &patterns) {
IREE::GPU::populateIREEGPULowerMultiMmaPatterns(patterns);
}

//===---------------------------------------------------------------------===//
// ApplyLowerValueBarrierOp
//===---------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ def ApplyDropMultiMmaOpUnitDims : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplyLowerMultiMmaOp : Op<Transform_Dialect,
"apply_patterns.iree.lower_multi_mma",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populate patterns to lowering multi_mma ops to the intrinsic specified by
the |kind| attribute.
}];

let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
let assemblyFormat = "attr-dict";
}

def ApplyLowerValueBarrierOp : Op<Transform_Dialect,
"apply_patterns.iree.lower_value_barrier",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ iree_lit_test_suite(
srcs = enforce_glob(
[
"drop_multi_mma_unit_dims.mlir",
"lower_multi_mma.mlir",
"lower_vector_barrier.mlir",
"transform_fuse_forall.mlir",
"vectorize_multi_mma.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_lit_test_suite(
lit
SRCS
"drop_multi_mma_unit_dims.mlir"
"lower_multi_mma.mlir"
"lower_vector_barrier.mlir"
"transform_fuse_forall.mlir"
"unroll_multi_mma.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s

#contraction_accesses = [
affine_map<() -> ()>,
affine_map<() -> ()>,
affine_map<() -> ()>
]
func.func @lower_multi_mma_mfma_16x16x16(%lhs: vector<4xf16>, %rhs: vector<4xf16>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
indexing_maps = #contraction_accesses,
iterator_types = [],
kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
} : vector<4xf16>, vector<4xf16> into vector<4xf32>
return %0 : vector<4xf32>
}

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.iree.lower_multi_mma
} : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @lower_multi_mma_mfma_16x16x16
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: vector<4xf16>
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: vector<4xf16>
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: vector<4xf32>
// CHECK: amdgpu.mfma %[[LHS]] * %[[RHS]] + %[[ACC]]
// CHECK-SAME: blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
// CHECK-SAME: blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>

// -----

#contraction_accesses = [
affine_map<() -> ()>,
affine_map<() -> ()>,
affine_map<() -> ()>
]
func.func @lower_multi_mma_mfma_32x32x8(%lhs: vector<4xf16>, %rhs: vector<4xf16>, %acc: vector<16xf32>) -> vector<16xf32> {
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
indexing_maps = #contraction_accesses,
iterator_types = [],
kind = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
} : vector<4xf16>, vector<4xf16> into vector<16xf32>
return %0 : vector<16xf32>
}

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.iree.lower_multi_mma
} : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @lower_multi_mma_mfma_32x32x8
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: vector<4xf16>
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: vector<4xf16>
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: vector<16xf32>
// CHECK: amdgpu.mfma %[[LHS]] * %[[RHS]] + %[[ACC]]
// CHECK-SAME: blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32
// CHECK-SAME: blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>

// -----

#contraction_accesses = [
affine_map<() -> ()>,
affine_map<() -> ()>,
affine_map<() -> ()>
]
func.func @lower_multi_mma_wmma_16x16x16(%lhs: vector<16xf16>, %rhs: vector<16xf16>, %acc: vector<8xf32>) -> vector<8xf32> {
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
indexing_maps = #contraction_accesses,
iterator_types = [],
kind = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>
} : vector<16xf16>, vector<16xf16> into vector<8xf32>
return %0 : vector<8xf32>
}

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.iree.lower_multi_mma
} : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @lower_multi_mma_wmma_16x16x16
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: vector<16xf16>
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: vector<16xf16>
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: vector<8xf32>
// CHECK: amdgpu.wmma %[[LHS]] * %[[RHS]] + %[[ACC]]
// CHECK-SAME: : vector<16xf16>, vector<16xf16>, vector<8xf32>

// -----

#contraction_accesses = [
affine_map<() -> ()>,
affine_map<() -> ()>,
affine_map<() -> ()>
]
func.func @lower_multi_mma_mfma_shape_cast_16x16x16(%lhs: vector<1x4xf16>, %rhs: vector<4x1xf16>, %acc: vector<4x1xf32>) -> vector<4x1xf32> {
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
indexing_maps = #contraction_accesses,
iterator_types = [],
kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
} : vector<1x4xf16>, vector<4x1xf16> into vector<4x1xf32>
return %0 : vector<4x1xf32>
}

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.iree.lower_multi_mma
} : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @lower_multi_mma_mfma_shape_cast_16x16x16
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: vector<1x4xf16>
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: vector<4x1xf16>
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: vector<4x1xf32>
// CHECK-DAG: %[[LHSCAST:.+]] = vector.shape_cast %[[LHS]] : vector<1x4xf16> to vector<4xf16>
// CHECK-DAG: %[[RHSCAST:.+]] = vector.shape_cast %[[RHS]] : vector<4x1xf16> to vector<4xf16>
// CHECK-DAG: %[[ACCCAST:.+]] = vector.shape_cast %[[ACC]] : vector<4x1xf32> to vector<4xf32>
// CHECK: %[[MMA:.+]] = amdgpu.mfma %[[LHSCAST]] * %[[RHSCAST]] + %[[ACCCAST]]
// CHECK-SAME: blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
// CHECK-SAME: blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: vector.shape_cast %[[MMA]] : vector<4xf32> to vector<4x1xf32>
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,59 @@ LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,
return success();
}

//===----------------------------------------------------------------------===//
// MultiMmaOp Lowering
//===----------------------------------------------------------------------===//

namespace {
struct LowerMultiMmaPattern : public OpRewritePattern<IREE::GPU::MultiMmaOp> {
using OpRewritePattern<IREE::GPU::MultiMmaOp>::OpRewritePattern;
LogicalResult matchAndRewrite(IREE::GPU::MultiMmaOp mmaOp,
PatternRewriter &rewriter) const override {
if (mmaOp.hasTensorSemantics()) {
return rewriter.notifyMatchFailure(
mmaOp, "lowering to concrete op requires vector semantics");
}
SmallVector<int64_t> bounds;
mmaOp.getIterationBounds(bounds);
if (!bounds.empty()) {
return rewriter.notifyMatchFailure(mmaOp,
"must be a single mma operation");
}

auto [lhsVectorType, rhsVectorType, accVectorType] =
mmaOp.getKind().getABCVectorTypes();

Value aCast = mmaOp.getLhs();
Value bCast = mmaOp.getRhs();
Value cCast = mmaOp.getAcc();
if (aCast.getType() != lhsVectorType) {
aCast = rewriter.create<vector::ShapeCastOp>(mmaOp.getLoc(),
lhsVectorType, aCast);
}
if (bCast.getType() != rhsVectorType) {
bCast = rewriter.create<vector::ShapeCastOp>(mmaOp.getLoc(),
rhsVectorType, bCast);
}
if (cCast.getType() != accVectorType) {
cCast = rewriter.create<vector::ShapeCastOp>(mmaOp.getLoc(),
accVectorType, cCast);
}

FailureOr<Value> concreteMmaOp = mmaOp.getKind().buildMmaOperation(
rewriter, mmaOp.getLoc(), cCast.getType(), aCast, bCast, cCast);
assert(succeeded(concreteMmaOp) && "Failed to create mma op");
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
mmaOp, mmaOp.getAcc().getType(), *concreteMmaOp);
return success();
}
};
} // namespace

void populateIREEGPULowerMultiMmaPatterns(RewritePatternSet &patterns) {
patterns.add<LowerMultiMmaPattern>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// MultiMmaOp Unit Dim Folding
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,
tensor::ExtractSliceOp slice);

void populateIREEGPUDropUnitDimsPatterns(RewritePatternSet &patterns);
void populateIREEGPULowerMultiMmaPatterns(RewritePatternSet &patterns);
void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns);
void populateIREEGPUVectorUnrollPatterns(
RewritePatternSet &patterns, const vector::UnrollVectorOptions &options);
Expand Down

0 comments on commit 3d1364e

Please sign in to comment.