From aa0bc40235adafabbfc1d73cac54b97dba856841 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Tue, 28 May 2024 18:46:27 -0400 Subject: [PATCH] [Codegen][GPU] Add pass to fuse and hoist scf.forall ops (#17505) This pass greedily fuses parallel loops together and tries to hoist them out of serial loops. It is left as TODO to include greedy fusion of untiled consumers. --- .../src/iree/compiler/Codegen/BUILD.bazel | 1 + .../src/iree/compiler/Codegen/CMakeLists.txt | 1 + .../Dialect/GPU/Transforms/BUILD.bazel | 25 ++++++- .../Dialect/GPU/Transforms/CMakeLists.txt | 17 +++++ .../Transforms/FuseAndHoistParallelLoops.cpp | 68 ++++++++++++++++++ .../Codegen/Dialect/GPU/Transforms/Passes.cpp | 23 ++++++ .../Codegen/Dialect/GPU/Transforms/Passes.h | 24 +++++++ .../Codegen/Dialect/GPU/Transforms/Passes.td | 21 ++++++ .../Dialect/GPU/Transforms/test/BUILD.bazel | 30 ++++++++ .../GPU/Transforms/test/CMakeLists.txt | 23 ++++++ .../test/fuse_and_hoist_forall.mlir | 70 +++++++++++++++++++ compiler/src/iree/compiler/Codegen/Passes.cpp | 2 + 12 files changed, 304 insertions(+), 1 deletion(-) create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir diff --git a/compiler/src/iree/compiler/Codegen/BUILD.bazel b/compiler/src/iree/compiler/Codegen/BUILD.bazel index 31af81b93c1b..a07556c695fc 100644 --- a/compiler/src/iree/compiler/Codegen/BUILD.bazel +++ 
b/compiler/src/iree/compiler/Codegen/BUILD.bazel @@ -25,6 +25,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses", "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", + "//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms", "//compiler/src/iree/compiler/Codegen/LLVMCPU", "//compiler/src/iree/compiler/Codegen/LLVMGPU", "//compiler/src/iree/compiler/Codegen/SPIRV", diff --git a/compiler/src/iree/compiler/Codegen/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/CMakeLists.txt index ae59fc660095..bf8407b1119c 100644 --- a/compiler/src/iree/compiler/Codegen/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/CMakeLists.txt @@ -23,6 +23,7 @@ iree_cc_library( iree::compiler::Codegen::Common::CPU::CommonCPUPasses iree::compiler::Codegen::Common::GPU::CommonGPUPasses iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect + iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms iree::compiler::Codegen::LLVMCPU iree::compiler::Codegen::LLVMGPU iree::compiler::Codegen::SPIRV diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel index b6a7be2afe85..632630cc5866 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library") package( default_visibility = ["//visibility:public"], @@ -33,23 +33,46 @@ iree_compiler_cc_library( ], ) +iree_gentbl_cc_library( + name = "PassesIncGen", + tbl_outs = [ + ( + ["--gen-pass-decls"], + "Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Passes.td", + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + iree_compiler_cc_library( name = "GPUTransforms", srcs = [ + "FuseAndHoistParallelLoops.cpp", + "Passes.cpp", "Transforms.cpp", ], hdrs = [ + "Passes.h", + "Passes.h.inc", "Transforms.h", ], deps = [ + ":PassesIncGen", "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect", + "//compiler/src/iree/compiler/Codegen/Transforms", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineUtils", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt index 58d9364428c1..588ec3f9cbd4 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt @@ -28,21 +28,37 @@ iree_cc_library( PUBLIC ) +iree_tablegen_library( + NAME + PassesIncGen + TD_FILE + "Passes.td" + OUTS + --gen-pass-decls Passes.h.inc +) + iree_cc_library( NAME GPUTransforms HDRS + "Passes.h" + "Passes.h.inc" "Transforms.h" SRCS + "FuseAndHoistParallelLoops.cpp" + 
"Passes.cpp" "Transforms.cpp" DEPS + ::PassesIncGen LLVMSupport MLIRAffineDialect MLIRAffineUtils MLIRArithDialect MLIRFuncDialect + MLIRFunctionInterfaces MLIRGPUDialect MLIRIR + MLIRPass MLIRSCFDialect MLIRSupport MLIRTensorDialect @@ -52,6 +68,7 @@ iree_cc_library( MLIRVectorTransforms MLIRVectorUtils iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect + iree::compiler::Codegen::Transforms PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp new file mode 100644 index 000000000000..46aa775733db --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp @@ -0,0 +1,68 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h" +#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h" +#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h" +#include "iree/compiler/Codegen/Transforms/Transforms.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::iree_compiler::IREE::GPU { + +#define GEN_PASS_DEF_FUSEANDHOISTPARALLELLOOPSPASS +#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc" + +namespace { +struct FuseAndHoistParallelLoopsPass final + : impl::FuseAndHoistParallelLoopsPassBase<FuseAndHoistParallelLoopsPass> { + void runOnOperation() override; +}; +} // namespace + +struct FuseForalls final : OpRewritePattern<tensor::ExtractSliceOp> { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, +
PatternRewriter &rewriter) const override { + auto sliceParent = sliceOp->getParentOfType<scf::ForallOp>(); + if (!sliceParent) { + return failure(); + } + + auto producerForall = sliceOp.getSource().getDefiningOp<scf::ForallOp>(); + if (!producerForall) { + return failure(); + } + + // TODO: Allow extracting multiple uses within the same consumer loop. Still + // single producer single consumer loop, but multiple uses within the + // consumer. + if (!producerForall->hasOneUse()) { + return failure(); + } + + return fuseForallIntoSlice(rewriter, producerForall, sliceParent, sliceOp); + } +}; + +void FuseAndHoistParallelLoopsPass::runOnOperation() { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + // These two patterns are run to a fixed point, allowing fusion within + // potentially nested loops, hoisting from said loops, and continued fusion. + patterns.add<FuseForalls>(context); + populateForallLoopHoistingPattern(patterns); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // namespace mlir::iree_compiler::IREE::GPU diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.cpp new file mode 100644 index 000000000000..45934001b212 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.cpp @@ -0,0 +1,23 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Pass/PassManager.h" + +namespace mlir::iree_compiler { + +namespace IREE::GPU { +namespace { +#define GEN_PASS_REGISTRATION +#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc" +} // namespace +} // namespace IREE::GPU + +void registerIREEGPUPasses() { + // registerPasses() comes from the GEN_PASS_REGISTRATION section of Passes.h.inc. + IREE::GPU::registerPasses(); +} +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h new file mode 100644 index 000000000000..e9fa3268daed --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h @@ -0,0 +1,24 @@ + +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES_H_ +#define IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::GPU { +#define GEN_PASS_DECL +#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc" // IWYU pragma: keep +} // namespace mlir::iree_compiler::IREE::GPU + +namespace mlir::iree_compiler { +/// Registers the IREE GPU transform passes declared in Passes.td.
+void registerIREEGPUPasses(); +} // namespace mlir::iree_compiler + +#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td new file mode 100644 index 000000000000..25487c1396f0 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td @@ -0,0 +1,21 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES +#define IREE_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def FuseAndHoistParallelLoopsPass : + InterfacePass<"iree-gpu-fuse-and-hoist-parallel-loops", "mlir::FunctionOpInterface"> { + let summary = "Greedily fuses scf.forall ops and hoists them out of serial loops"; + let dependentDialects = [ + "::mlir::affine::AffineDialect", + "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect" + ]; +} + +#endif // IREE_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel new file mode 100644 index 000000000000..ec698a32b1cc --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel @@ -0,0 +1,30 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Tests for iree_gpu transforms.
+ +load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") + +package( + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_lit_test_suite( + name = "lit", + srcs = enforce_glob( + [ + "fuse_and_hoist_forall.mlir", + ], + include = ["*.mlir"], + ), + cfg = "//compiler:lit.cfg.py", + tools = [ + "//tools:iree-opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt new file mode 100644 index 000000000000..bf0669c711db --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt @@ -0,0 +1,23 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. 
# +################################################################################ + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "fuse_and_hoist_forall.mlir" + TOOLS + FileCheck + iree-opt +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir new file mode 100644 index 000000000000..f3bd2f75485a --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir @@ -0,0 +1,70 @@ +// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-fuse-and-hoist-parallel-loops))' --split-input-file | FileCheck %s + +#map = affine_map<(d0) -> (d0 * 2)> +#map1 = affine_map<(d0) -> (d0 * 4)> +#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)> +#map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)> +#map4 = affine_map<(d0) -> (d0 * 16)> +module { + func.func @forall_fuse_then_hoist() { + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16> + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> ->
tensor<128x128xf32> + %6 = tensor.empty() : tensor<128x4xf16> + %7 = tensor.empty() : tensor<4x128xf16> + %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<128x128xf32>) { + %9 = scf.forall (%arg2, %arg3) in (64, 1) shared_outs(%arg4 = %6) -> (tensor<128x4xf16>) { + %12 = affine.apply #map(%arg2) + %13 = affine.apply #map1(%arg3) + %14 = affine.apply #map(%arg2) + %15 = affine.apply #map2(%arg3)[%arg0] + %extracted_slice = tensor.extract_slice %3[%14, %15] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16> + %extracted_slice_0 = tensor.extract_slice %arg4[%12, %13] [2, 4] [1, 1] : tensor<128x4xf16> to tensor<2x4xf16> + %16 = linalg.copy ins(%extracted_slice : tensor<2x4xf16>) outs(%extracted_slice_0 : tensor<2x4xf16>) -> tensor<2x4xf16> + scf.forall.in_parallel { + tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16> + } + } {mapping = [#gpu.thread<y>, #gpu.thread<x>]} + %10 = scf.forall (%arg2, %arg3) in (2, 32) shared_outs(%arg4 = %7) -> (tensor<4x128xf16>) { + %12 = affine.apply #map(%arg2) + %13 = affine.apply #map1(%arg3) + %14 = affine.apply #map3(%arg2)[%arg0] + %15 = affine.apply #map1(%arg3) + %extracted_slice = tensor.extract_slice %4[%14, %15] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16> + %extracted_slice_0 = tensor.extract_slice %arg4[%12, %13] [2, 4] [1, 1] : tensor<4x128xf16> to tensor<2x4xf16> + %16 = linalg.copy ins(%extracted_slice : tensor<2x4xf16>) outs(%extracted_slice_0 : tensor<2x4xf16>) -> tensor<2x4xf16> + scf.forall.in_parallel { + tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<4x128xf16> + } + } {mapping = [#gpu.thread<y>, #gpu.thread<x>]} + %11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) { + %12 = affine.apply #map4(%arg2) + %13 = affine.apply #map4(%arg3) + %extracted_slice = tensor.extract_slice %9[%12, 0] [16, 4] [1, 1] : tensor<128x4xf16> to
tensor<16x4xf16> + %extracted_slice_0 = tensor.extract_slice %10[0, %13] [4, 16] [1, 1] : tensor<4x128xf16> to tensor<4x16xf16> + %extracted_slice_1 = tensor.extract_slice %arg4[%12, %13] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32> + %14 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<16x4xf16>, tensor<4x16xf16>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32> + } + } {mapping = [#gpu.thread<y>, #gpu.thread<x>]} + scf.yield %11 : tensor<128x128xf32> + } + flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> + return + } +} + +// CHECK-LABEL: func @forall_fuse_then_hoist +// CHECK: %[[OUTER_PARALLEL:.+]] = scf.forall +// CHECK: %[[LOOP:.+]] = scf.for +// CHECK: scf.yield {{.*}} : tensor<16x16xf32> +// CHECK: scf.forall.in_parallel +// CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]] +// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]] diff --git a/compiler/src/iree/compiler/Codegen/Passes.cpp b/compiler/src/iree/compiler/Codegen/Passes.cpp index 07f62f16da6c..236071717d73 100644 --- a/compiler/src/iree/compiler/Codegen/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/Passes.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Codegen/Common/CPU/Passes.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h" #include "iree/compiler/Codegen/LLVMCPU/Passes.h" #include "iree/compiler/Codegen/LLVMGPU/Passes.h" #include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h" @@ -33,6 +34,7 @@ void registerCodegenPasses() { registerCodegenSPIRVPasses(); registerCodegenVMVXPasses(); registerCodegenWGSLPasses(); + registerIREEGPUPasses(); } } // namespace mlir::iree_compiler