[Util][GPU] Add TiedOpInterface implementation for iree_gpu.multi_mma (iree-org#18626)

This PR is part of what was originally iree-org#18608. It implements the
TiedOpInterface for the iree_gpu.multi_mma op, tying the op's result to its
accumulator operand. This is a temporary solution for handling multi_mma ops
before dispatch workgroup creation, and is only needed because we currently
rely on early materialization. It enables e2e matmul tests with GPU data
tiling while that work is still in development, and the change can be dropped
once we switch to late materialization.

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Max191 authored Sep 30, 2024
1 parent b7ac442 commit a9c7ec1
Showing 4 changed files with 67 additions and 0 deletions.
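For context (not part of this commit): the sketch below illustrates how a pass could query the interface this change registers. The helper name inspectTiedAcc and the include paths are assumptions for illustration only; the accessor names mirror the external model added in UtilExternalModels.cpp below.

#include <optional>

#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" // assumed header for TiedOpInterface
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::iree_compiler;

// Hypothetical helper: report which operand result 0 of `op` is tied to.
static void inspectTiedAcc(Operation *op) {
  auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
  if (!tiedOp)
    return;
  // For iree_gpu.multi_mma the external model below returns operand index 2
  // (the accumulator), so dispatch region formation can write the result into
  // the accumulator's storage (bound readwrite) instead of a fresh buffer.
  if (std::optional<unsigned> tiedIdx = tiedOp.getTiedResultOperandIndex(0)) {
    Value tiedOperand = op->getOperand(*tiedIdx);
    (void)tiedOperand;
  }
}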

@@ -47,3 +47,40 @@ util.func public @foo(%argA: tensor<?x?xf32>, %argB: tensor<5x10xf32>, %argC: te
// CHECK: util.return %[[r0]], %[[r1]]
util.return %r0, %r1 : tensor<?x?xf32>, tensor<5x11xf32>
}

// -----

// TODO(Max191): Remove this test once GPU data tiling stops using early
// materialization.
util.func public @multi_mma(
    %arg0: tensor<4x16x8x4x16x2x4xf16>,
    %arg1: tensor<4x16x4x2x4x16x2x4xf16>,
    %arg2: tensor<4x4x8x4x2x4x16x4xf32>) -> (tensor<4x4x8x4x2x4x16x4xf32>) {
  %9 = flow.dispatch.region -> (tensor<4x4x8x4x2x4x16x4xf32>) {
    %13 = iree_gpu.multi_mma %arg0, %arg1, %arg2 {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                         affine_map<(d0, d1, d2) -> (d1, d2)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>],
        iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
        kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>}
        : tensor<4x16x8x4x16x2x4xf16>, tensor<4x16x4x2x4x16x2x4xf16> into tensor<4x4x8x4x2x4x16x4xf32>
    flow.return %13 : tensor<4x4x8x4x2x4x16x4xf32>
  }
  util.return %9 : tensor<4x4x8x4x2x4x16x4xf32>
}

// CHECK-LABEL: util.func public @multi_mma(
// CHECK: %[[arg0:.*]]: tensor<4x16x8x4x16x2x4xf16>, %[[arg1:.*]]: tensor<4x16x4x2x4x16x2x4xf16>, %[[arg2:.*]]: tensor<4x4x8x4x2x4x16x4xf32>
// CHECK: %[[r0:.*]] = flow.dispatch.workgroups(%[[arg0]], %[[arg1]], %[[arg2]])
// CHECK-SAME: : (tensor<4x16x8x4x16x2x4xf16>, tensor<4x16x4x2x4x16x2x4xf16>, tensor<4x4x8x4x2x4x16x4xf32>)
// CHECK-NEXT: (%[[arg3:.*]]: !flow.dispatch.tensor<readonly:tensor<4x16x8x4x16x2x4xf16>>,
// CHECK-SAME: %[[arg4:.*]]: !flow.dispatch.tensor<readonly:tensor<4x16x4x2x4x16x2x4xf16>>,
// CHECK-SAME: %[[arg5:.*]]: !flow.dispatch.tensor<readwrite:tensor<4x4x8x4x2x4x16x4xf32>>)
// CHECK-DAG: %[[loadLHS:.*]] = flow.dispatch.tensor.load %[[arg3]]
// CHECK-DAG: %[[loadRHS:.*]] = flow.dispatch.tensor.load %[[arg4]]
// CHECK-DAG: %[[loadACC:.*]] = flow.dispatch.tensor.load %[[arg5]]
// CHECK: %[[MULTI_MMA:.*]] = iree_gpu.multi_mma %[[loadLHS]], %[[loadRHS]], %[[loadACC]]
// CHECK: flow.dispatch.tensor.store %[[MULTI_MMA]], %[[arg5]]
// CHECK: flow.return
// CHECK: }
// CHECK: util.return %[[r0]]

compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel (1 change: 1 addition & 0 deletions)
@@ -27,6 +27,7 @@ iree_compiler_cc_library(
"UtilExternalModels.h",
],
deps = [
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",

@@ -32,6 +32,7 @@ iree_cc_library(
MLIRMLProgramDialect
MLIRTensorDialect
MLIRValueBoundsOpInterface
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Dialect::Encoding::IR
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR

@@ -6,6 +6,8 @@

#include "iree/compiler/ExternalInterfaces/UtilExternalModels.h"

#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"

@@ -168,6 +170,27 @@ struct LinalgOpTiedOpInterfaceHelper {
}
};

// TODO(Max191): Remove this interface once GPU data tiling stops using early
// materialization. This only exists for handling multi_mma ops before dispatch
// workgroups are created, which only happens with early materialization.
struct MultiMmaOpTiedOpInterface
    : public IREE::Util::TiedOpInterface::ExternalModel<
          MultiMmaOpTiedOpInterface, IREE::GPU::MultiMmaOp> {
  Value getTiedResult(Operation *op, unsigned resultIndex) const {
    auto multiMmaOp = cast<IREE::GPU::MultiMmaOp>(op);
    return IREE::Util::TiedOpInterface::findTiedBaseValue(multiMmaOp.getAcc());
  }

  ::std::optional<unsigned>
  getTiedResultOperandIndex(Operation *op, unsigned resultIndex) const {
    return {2}; // acc
  }

  SmallVector<int64_t> getTiedResultOperandIndices(Operation *op) const {
    return {2}; // acc
  }
};

//===----------------------------------------------------------------------===//
// HoistableOpInterface
//===----------------------------------------------------------------------===//

@@ -289,6 +312,11 @@ void registerUtilExternalModels(DialectRegistry &registry) {
*context);
});

  registry.addExtension(+[](MLIRContext *context,
                            IREE::GPU::IREEGPUDialect *dialect) {
    IREE::GPU::MultiMmaOp::attachInterface<MultiMmaOpTiedOpInterface>(*context);
  });

  registry.addExtension(
      +[](MLIRContext *context, linalg::LinalgDialect *dialect) {
        // Register all Linalg structured ops. `LinalgOp` is an interface and it
