[Codegen][GPU] Add pass to fuse and hoist scf.forall ops (iree-org#17505)

This pass greedily fuses parallel scf.forall loops together and tries to
hoist them out of serial scf.for loops. Greedy fusion of untiled consumers
is left as a TODO.
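
Schematically, the rewrite turns a serial loop that relaunches thread-mapped foralls on every iteration into a single hoisted forall whose body runs the serial loop per thread. This is illustrative IR only; bounds, mappings, and types are elided placeholders, not taken from the tests in this commit:

// Before: producer/consumer foralls relaunched on every serial iteration.
scf.for %i = %lb to %ub step %s iter_args(%acc = %init) -> (tensor<?x?xf32>) {
  %p = scf.forall (%t) in (%n) ... { ... }             // producer tiles
  %c = scf.forall (%t) in (%n) shared_outs(%o = %acc) {
    %slice = tensor.extract_slice %p ...               // bridges the two loops
    ...
  }
  scf.yield %c
}

// After: one thread-mapped forall; the serial loop runs on per-thread tiles.
scf.forall (%t) in (%n) shared_outs(%o = %init) {
  %r = scf.for %i = %lb to %ub step %s iter_args(%tile = ...) {
    ...                                                // fused producer + consumer
    scf.yield %next_tile
  }
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %r into %o ...
  }
}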
qedawkins authored May 28, 2024
1 parent 29e70ab commit aa0bc40
Showing 12 changed files with 304 additions and 1 deletion.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/BUILD.bazel
@@ -25,6 +25,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses",
"//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
"//compiler/src/iree/compiler/Codegen/LLVMCPU",
"//compiler/src/iree/compiler/Codegen/LLVMGPU",
"//compiler/src/iree/compiler/Codegen/SPIRV",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/CMakeLists.txt
@@ -23,6 +23,7 @@ iree_cc_library(
iree::compiler::Codegen::Common::CPU::CommonCPUPasses
iree::compiler::Codegen::Common::GPU::CommonGPUPasses
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
iree::compiler::Codegen::LLVMCPU
iree::compiler::Codegen::LLVMGPU
iree::compiler::Codegen::SPIRV
25 changes: 24 additions & 1 deletion compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")

package(
default_visibility = ["//visibility:public"],
@@ -33,23 +33,46 @@ iree_compiler_cc_library(
],
)

iree_gentbl_cc_library(
name = "PassesIncGen",
tbl_outs = [
(
["--gen-pass-decls"],
"Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:PassBaseTdFiles",
],
)

iree_compiler_cc_library(
name = "GPUTransforms",
srcs = [
"FuseAndHoistParallelLoops.cpp",
"Passes.cpp",
"Transforms.cpp",
],
hdrs = [
"Passes.h",
"Passes.h.inc",
"Transforms.h",
],
deps = [
":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/Transforms",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
17 changes: 17 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
@@ -28,21 +28,37 @@ iree_cc_library(
PUBLIC
)

iree_tablegen_library(
NAME
PassesIncGen
TD_FILE
"Passes.td"
OUTS
--gen-pass-decls Passes.h.inc
)

iree_cc_library(
NAME
GPUTransforms
HDRS
"Passes.h"
"Passes.h.inc"
"Transforms.h"
SRCS
"FuseAndHoistParallelLoops.cpp"
"Passes.cpp"
"Transforms.cpp"
DEPS
::PassesIncGen
LLVMSupport
MLIRAffineDialect
MLIRAffineUtils
MLIRArithDialect
MLIRFuncDialect
MLIRFunctionInterfaces
MLIRGPUDialect
MLIRIR
MLIRPass
MLIRSCFDialect
MLIRSupport
MLIRTensorDialect
@@ -52,6 +68,7 @@ iree_cc_library(
MLIRVectorTransforms
MLIRVectorUtils
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::Transforms
PUBLIC
)

68 changes: 68 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
@@ -0,0 +1,68 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::GPU {

#define GEN_PASS_DEF_FUSEANDHOISTPARALLELLOOPSPASS
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"

namespace {
struct FuseAndHoistParallelLoopsPass final
: impl::FuseAndHoistParallelLoopsPassBase<FuseAndHoistParallelLoopsPass> {
void runOnOperation() override;
};
} // namespace

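// Fuses a producer scf.forall into the consumer scf.forall that extracts a
// slice of its result. The tensor.extract_slice bridging the two loops is
// the pattern anchor; the actual rewrite is delegated to fuseForallIntoSlice.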
struct FuseForalls final : OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
auto sliceParent = sliceOp->getParentOfType<scf::ForallOp>();
if (!sliceParent) {
return failure();
}

auto producerForall = sliceOp.getSource().getDefiningOp<scf::ForallOp>();
if (!producerForall) {
return failure();
}

// TODO: Allow extracting multiple uses within the same consumer loop. Still
// single producer single consumer loop, but multiple uses within the
// consumer.
if (!producerForall->hasOneUse()) {
return failure();
}

return fuseForallIntoSlice(rewriter, producerForall, sliceParent, sliceOp);
}
};

void FuseAndHoistParallelLoopsPass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);

// These two patterns are run to a fixed point, allowing fusion within
// potentially nested loops, hoisting from said loops, and continued fusion.
patterns.add<FuseForalls>(context);
populateForallLoopHoistingPattern(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}

} // namespace mlir::iree_compiler::IREE::GPU
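
For context, a minimal sketch of wiring the new pass into a pipeline; the factory createFuseAndHoistParallelLoopsPass() is the tablegen-generated default name and an assumption here, not code from this commit:

// Hypothetical pipeline wiring (sketch only; factory name assumed).
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

void addFuseAndHoist(mlir::OpPassManager &pm) {
  pm.nest<mlir::func::FuncOp>().addPass(
      mlir::iree_compiler::IREE::GPU::createFuseAndHoistParallelLoopsPass());
}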
23 changes: 23 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.cpp
@@ -0,0 +1,23 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

namespace mlir::iree_compiler {

namespace IREE::GPU {
namespace {
#define GEN_PASS_REGISTRATION
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace
} // namespace IREE::GPU

void registerIREEGPUPasses() {
// Generated.
IREE::GPU::registerPasses();
}
} // namespace mlir::iree_compiler
24 changes: 24 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h
@@ -0,0 +1,24 @@

// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES_H_
#define IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES_H_

#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"

namespace mlir::iree_compiler::IREE::GPU {
#define GEN_PASS_DECL
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc" // IWYU pragma: keep
} // namespace mlir::iree_compiler::IREE::GPU

namespace mlir::iree_compiler {
/// Register GPU passes.
void registerIREEGPUPasses();
} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES_H_
21 changes: 21 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
@@ -0,0 +1,21 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES
#define IREE_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES

include "mlir/Pass/PassBase.td"

def FuseAndHoistParallelLoopsPass :
InterfacePass<"iree-gpu-fuse-and-hoist-parallel-loops", "mlir::FunctionOpInterface"> {
let summary = "Greedily fuses parallel scf.forall loops and hoists them out of serial loops";
let dependentDialects = [
"::mlir::affine::AffineDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect"
];
}

#endif // IREE_CODEGEN_DIALECT_GPU_TRANSFORMS_PASSES
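
From this definition, mlir-tblgen (--gen-pass-decls) generates the FuseAndHoistParallelLoopsPassBase CRTP base class pulled in through the GEN_PASS_DEF_FUSEANDHOISTPARALLELLOOPSPASS macro in FuseAndHoistParallelLoops.cpp, as well as the registration hook included under GEN_PASS_REGISTRATION in Passes.cpp above.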
30 changes: 30 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
@@ -0,0 +1,30 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Tests for iree_gpu transforms.

load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")

package(
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)

iree_lit_test_suite(
name = "lit",
srcs = enforce_glob(
[
"fuse_and_hoist_forall.mlir",
],
include = ["*.mlir"],
),
cfg = "//compiler:lit.cfg.py",
tools = [
"//tools:iree-opt",
"@llvm-project//llvm:FileCheck",
],
)
23 changes: 23 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
@@ -0,0 +1,23 @@
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
# compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel #
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
# #
# To disable autogeneration for this file entirely, delete this header. #
################################################################################

iree_add_all_subdirs()

iree_lit_test_suite(
NAME
lit
SRCS
"fuse_and_hoist_forall.mlir"
TOOLS
FileCheck
iree-opt
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
70 changes: 70 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
@@ -0,0 +1,70 @@
// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-fuse-and-hoist-parallel-loops))' --split-input-file | FileCheck %s

#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 4)>
#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
#map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
#map4 = affine_map<(d0) -> (d0 * 16)>
module {
func.func @forall_fuse_then_hoist() {
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> -> tensor<128x128xf32>
%6 = tensor.empty() : tensor<128x4xf16>
%7 = tensor.empty() : tensor<4x128xf16>
%8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<128x128xf32>) {
%9 = scf.forall (%arg2, %arg3) in (64, 1) shared_outs(%arg4 = %6) -> (tensor<128x4xf16>) {
%12 = affine.apply #map(%arg2)
%13 = affine.apply #map1(%arg3)
%14 = affine.apply #map(%arg2)
%15 = affine.apply #map2(%arg3)[%arg0]
%extracted_slice = tensor.extract_slice %3[%14, %15] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
%extracted_slice_0 = tensor.extract_slice %arg4[%12, %13] [2, 4] [1, 1] : tensor<128x4xf16> to tensor<2x4xf16>
%16 = linalg.copy ins(%extracted_slice : tensor<2x4xf16>) outs(%extracted_slice_0 : tensor<2x4xf16>) -> tensor<2x4xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16>
}
} {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
%10 = scf.forall (%arg2, %arg3) in (2, 32) shared_outs(%arg4 = %7) -> (tensor<4x128xf16>) {
%12 = affine.apply #map(%arg2)
%13 = affine.apply #map1(%arg3)
%14 = affine.apply #map3(%arg2)[%arg0]
%15 = affine.apply #map1(%arg3)
%extracted_slice = tensor.extract_slice %4[%14, %15] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
%extracted_slice_0 = tensor.extract_slice %arg4[%12, %13] [2, 4] [1, 1] : tensor<4x128xf16> to tensor<2x4xf16>
%16 = linalg.copy ins(%extracted_slice : tensor<2x4xf16>) outs(%extracted_slice_0 : tensor<2x4xf16>) -> tensor<2x4xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<4x128xf16>
}
} {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
%11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) {
%12 = affine.apply #map4(%arg2)
%13 = affine.apply #map4(%arg3)
%extracted_slice = tensor.extract_slice %9[%12, 0] [16, 4] [1, 1] : tensor<128x4xf16> to tensor<16x4xf16>
%extracted_slice_0 = tensor.extract_slice %10[0, %13] [4, 16] [1, 1] : tensor<4x128xf16> to tensor<4x16xf16>
%extracted_slice_1 = tensor.extract_slice %arg4[%12, %13] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
%14 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<16x4xf16>, tensor<4x16xf16>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
}
} {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
scf.yield %11 : tensor<128x128xf32>
}
flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
return
}
}

// CHECK-LABEL: func @forall_fuse_then_hoist
// CHECK: %[[OUTER_PARALLEL:.+]] = scf.forall
// CHECK: %[[LOOP:.+]] = scf.for
// CHECK: scf.yield {{.*}} : tensor<16x16xf32>
// CHECK: scf.forall.in_parallel
// CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]]
// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]]
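
The CHECK lines assert the post-pass structure: the three thread-mapped foralls collapse into a single outer scf.forall, the serial scf.for is sunk inside it and now yields a per-thread tensor<16x16xf32> accumulator tile, and the loop result is written back via tensor.parallel_insert_slice before the final flow.dispatch.tensor.store of the hoisted loop's result.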
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/Passes.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h"
@@ -33,6 +34,7 @@ void registerCodegenPasses() {
registerCodegenSPIRVPasses();
registerCodegenVMVXPasses();
registerCodegenWGSLPasses();
registerIREEGPUPasses();
}

} // namespace mlir::iree_compiler
