forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reapply "[Codegen][GPU] Add range information to GPU dispatch IDs" (i…
…ree-org#19361) (iree-org#19372) This reverts commit cb5be1d. Compaled to the previous revision, this one works around a correctness bug in dataflow analysis that's being fixed by removing the analysis after SCF->CF. --- First, this patch implements InferIntRangeInterface for hal.interface.workgroup.{size,id,count} using a local upper_bound attribute. Then, it adds a -iree-codegen-gpu-propagate-dispatch-size-bounds pass that adds these upper_bounds identifiers to the interface.workgroup operations and to gpu.thread_id based on static information available late in the codegen pipeline. Then, it uses -optimize-int-arithmetic to optimize indexing after -lower-affine, getting rid of a bunch of "if the input's negative" logic that isn't actually needed in many of our kernels. It also ensures that these upper_bound values propagate to LLVM.
- Loading branch information
Showing
16 changed files
with
317 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
103 changes: 103 additions & 0 deletions
103
compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/Common/GPU/Passes.h" | ||
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" | ||
#include "iree/compiler/Codegen/Utils/GPUUtils.h" | ||
#include "iree/compiler/Codegen/Utils/Utils.h" | ||
#include "iree/compiler/Dialect/HAL/IR/HALOps.h" | ||
#include "mlir/Dialect/GPU/IR/GPUDialect.h" | ||
#include "mlir/Interfaces/FunctionInterfaces.h" | ||
#include "mlir/Transforms/Passes.h" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
#define GEN_PASS_DEF_GPUPROPAGATEDISPATCHSIZEBOUNDSPASS | ||
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
static void applyBounds(FunctionOpInterface funcOp, | ||
ArrayRef<int32_t> workgroupSizes, | ||
ArrayRef<int32_t> workgroupCounts) { | ||
Builder b(funcOp->getContext()); | ||
funcOp->walk([&](Operation *op) { | ||
TypeSwitch<Operation *>(op) | ||
.Case([&](gpu::ThreadIdOp tidOp) { | ||
tidOp.setUpperBoundAttr(b.getIndexAttr( | ||
workgroupSizes[static_cast<uint32_t>(tidOp.getDimension())])); | ||
}) | ||
.Case([&](IREE::HAL::InterfaceWorkgroupSizeOp wgSizeOp) { | ||
wgSizeOp.setUpperBoundAttr(b.getIndexAttr( | ||
workgroupSizes[wgSizeOp.getDimension().getZExtValue()])); | ||
}) | ||
.Case([&](IREE::HAL::InterfaceWorkgroupIDOp wgIdOp) { | ||
wgIdOp.setUpperBoundAttr(b.getIndexAttr( | ||
workgroupCounts[wgIdOp.getDimension().getZExtValue()])); | ||
}) | ||
.Case([&](IREE::HAL::InterfaceWorkgroupCountOp wgCountOp) { | ||
wgCountOp.setUpperBoundAttr(b.getIndexAttr( | ||
workgroupCounts[wgCountOp.getDimension().getZExtValue()])); | ||
}) | ||
.Default([](Operation *) {}); | ||
}); | ||
} | ||
|
||
struct GPUPropagateDispatchSizeBoundsPass final | ||
: impl::GPUPropagateDispatchSizeBoundsPassBase< | ||
GPUPropagateDispatchSizeBoundsPass> { | ||
using Base::Base; | ||
|
||
void runOnOperation() override { | ||
FunctionOpInterface funcOp = getOperation(); | ||
IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); | ||
if (!target) { | ||
funcOp.emitWarning("no known target attribute late in GPU codegen"); | ||
return; | ||
} | ||
SmallVector<int32_t, 3> workgroupSizes( | ||
target.getWgp().getMaxWorkgroupSizes().asArrayRef()); | ||
SmallVector<int32_t, 3> workgroupCounts( | ||
target.getWgp().getMaxWorkgroupCounts().asArrayRef()); | ||
|
||
std::optional<SmallVector<int64_t>> staticWorkgroupSize = | ||
getWorkgroupSize(funcOp); | ||
|
||
// Late in codegen, we've reconciled the workgroup size onto the export op. | ||
if (std::optional<IREE::HAL::ExecutableExportOp> exportOp = | ||
getEntryPoint(funcOp)) { | ||
if (std::optional<ArrayAttr> exportWorkgroupSize = | ||
exportOp->getWorkgroupSize()) { | ||
staticWorkgroupSize = | ||
llvm::map_to_vector(exportWorkgroupSize->getAsRange<IntegerAttr>(), | ||
[](IntegerAttr a) { return a.getInt(); }); | ||
} | ||
} | ||
|
||
if (staticWorkgroupSize) { | ||
// Target info with no workgroup sizes gives a 0-length array, hence no | ||
// zip_equal. | ||
for (auto [size, staticSize] : | ||
llvm::zip(workgroupSizes, *staticWorkgroupSize)) { | ||
size = staticSize; | ||
} | ||
} | ||
SmallVector<int64_t> staticWorkgroupCounts = getStaticNumWorkgroups(funcOp); | ||
assert(staticWorkgroupCounts.size() <= 3 && | ||
"workgroup counts are 3D at most"); | ||
for (auto [count, staticCount] : | ||
llvm::zip(workgroupCounts, staticWorkgroupCounts)) { | ||
if (staticCount != ShapedType::kDynamic) { | ||
count = staticCount; | ||
} | ||
} | ||
|
||
applyBounds(funcOp, workgroupSizes, workgroupCounts); | ||
} | ||
}; | ||
} // namespace | ||
|
||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
122 changes: 122 additions & 0 deletions
122
compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
// RUN: iree-opt %s --split-input-file \ | ||
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-gpu-propagate-dispatch-size-bounds)))))" \ | ||
// RUN: | FileCheck %s | ||
|
||
// Note: not the real target definition, missing types | ||
#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", | ||
wgp = <compute = fp32, | ||
storage = b32, | ||
subgroup = arithmetic, | ||
dot = none, mma = [], | ||
subgroup_size_choices = [32, 64], | ||
max_workgroup_sizes = [1024, 1024, 1024], | ||
max_thread_count_per_workgroup = 1024, | ||
max_workgroup_memory_bytes = 65536, | ||
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}> | ||
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]> | ||
|
||
hal.executable private @static { | ||
hal.executable.variant public @rocm_hsaco_fb target(#executable_target) { | ||
hal.executable.export public @static ordinal(0) layout(#pipeline_layout) attributes {workgroup_size = [64 : index, 2 : index, 1 : index]} { | ||
^bb0(%arg0: !hal.device): | ||
%c32 = arith.constant 32 : index | ||
%c8 = arith.constant 8 : index | ||
%c1 = arith.constant 1 : index | ||
hal.return %c32, %c8, %c1 : index, index, index | ||
} | ||
builtin.module { | ||
// CHECK-LABEL: func.func @static | ||
func.func @static() { | ||
// CHECK: gpu.thread_id x upper_bound 64 | ||
// CHECK: gpu.thread_id y upper_bound 2 | ||
// CHECK: gpu.thread_id z upper_bound 1 | ||
%thread_id_x = gpu.thread_id x | ||
%thread_id_y = gpu.thread_id y | ||
%thread_id_z = gpu.thread_id z | ||
|
||
// CHECK: hal.interface.workgroup.size[0] upper_bound 64 | ||
// CHECK: hal.interface.workgroup.size[1] upper_bound 2 | ||
// CHECK: hal.interface.workgroup.size[2] upper_bound 1 | ||
%workgroup_size_x = hal.interface.workgroup.size[0] : index | ||
%workgroup_size_y = hal.interface.workgroup.size[1] : index | ||
%workgroup_size_z = hal.interface.workgroup.size[2] : index | ||
|
||
// CHECK: hal.interface.workgroup.id[0] upper_bound 32 | ||
// CHECK: hal.interface.workgroup.id[1] upper_bound 8 | ||
// CHECK: hal.interface.workgroup.id[2] upper_bound 1 | ||
%workgroup_id_x = hal.interface.workgroup.id[0] : index | ||
%workgroup_id_y = hal.interface.workgroup.id[1] : index | ||
%workgroup_id_z = hal.interface.workgroup.id[2] : index | ||
|
||
// CHECK: hal.interface.workgroup.count[0] upper_bound 32 | ||
// CHECK: hal.interface.workgroup.count[1] upper_bound 8 | ||
// CHECK: hal.interface.workgroup.count[2] upper_bound 1 | ||
%workgroup_conut_x = hal.interface.workgroup.count[0] : index | ||
%workgroup_count_y = hal.interface.workgroup.count[1] : index | ||
%workgroup_count_z = hal.interface.workgroup.count[2] : index | ||
|
||
return | ||
} | ||
} | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", | ||
{iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", | ||
wgp = <compute = fp32, | ||
storage = b32, | ||
subgroup = arithmetic, | ||
dot = none, mma = [], | ||
subgroup_size_choices = [32, 64], | ||
max_workgroup_sizes = [1024, 1024, 1024], | ||
max_thread_count_per_workgroup = 1024, | ||
max_workgroup_memory_bytes = 65536, | ||
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}> | ||
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]> | ||
|
||
hal.executable private @dynamic { | ||
hal.executable.variant public @rocm_hsaco_fb target(#executable_target) { | ||
hal.executable.export public @dynamic ordinal(0) layout(#pipeline_layout) { | ||
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index): | ||
%count_x = affine.apply affine_map<()[s0] -> (s0 ceildiv 32)>()[%arg1] | ||
%count_y = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%arg2] | ||
%count_z = arith.constant 1 : index | ||
hal.return %count_x, %count_y, %count_z : index, index, index | ||
} | ||
builtin.module { | ||
func.func @dynamic() { | ||
// CHECK: gpu.thread_id x upper_bound 1024 | ||
// CHECK: gpu.thread_id y upper_bound 1024 | ||
// CHECK: gpu.thread_id z upper_bound 1024 | ||
%thread_id_x = gpu.thread_id x | ||
%thread_id_y = gpu.thread_id y | ||
%thread_id_z = gpu.thread_id z | ||
|
||
// CHECK: hal.interface.workgroup.size[0] upper_bound 1024 | ||
// CHECK: hal.interface.workgroup.size[1] upper_bound 1024 | ||
// CHECK: hal.interface.workgroup.size[2] upper_bound 1024 | ||
%workgroup_size_x = hal.interface.workgroup.size[0] : index | ||
%workgroup_size_y = hal.interface.workgroup.size[1] : index | ||
%workgroup_size_z = hal.interface.workgroup.size[2] : index | ||
|
||
// CHECK: hal.interface.workgroup.id[0] upper_bound 2147483647 | ||
// CHECK: hal.interface.workgroup.id[1] upper_bound 2147483647 | ||
// CHECK: hal.interface.workgroup.id[2] upper_bound 1 | ||
%workgroup_id_x = hal.interface.workgroup.id[0] : index | ||
%workgroup_id_y = hal.interface.workgroup.id[1] : index | ||
%workgroup_id_z = hal.interface.workgroup.id[2] : index | ||
|
||
// CHECK: hal.interface.workgroup.count[0] upper_bound 2147483647 | ||
// CHECK: hal.interface.workgroup.count[1] upper_bound 2147483647 | ||
// CHECK: hal.interface.workgroup.count[2] upper_bound 1 | ||
%workgroup_conut_x = hal.interface.workgroup.count[0] : index | ||
%workgroup_count_y = hal.interface.workgroup.count[1] : index | ||
%workgroup_count_z = hal.interface.workgroup.count[2] : index | ||
|
||
return | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.