[Codegen][GPU] Use bufferization.alloc_tensor for gpu.shuffle_tensor destination (iree-org#17940)

The destination of an iree_gpu.shuffle_tensor op always ends up needing a
shared memory allocation. When the destination is left as a tensor.empty
op, it can be CSEd with the destinations of other shuffle ops, aliasing
allocations that must remain distinct. This PR instead creates a
bufferization.alloc_tensor (in the workgroup address space) as the
destination when generating iree_gpu.shuffle_tensor ops, which CSE will
not fold.
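
For illustration only (not part of the commit; the value names %d0, %d1, %a0,
%a1 are hypothetical), the problem and the fix look roughly like this in
pre-bufferization IR:

  // Before: two shuffle destinations created as identical tensor.empty ops.
  // CSE can fold %d0 and %d1 into one value, so both shuffles would later
  // bufferize onto the same shared memory buffer.
  %d0 = tensor.empty() : tensor<128x128xf32>
  %d1 = tensor.empty() : tensor<128x128xf32>

  // After: each destination is a distinct allocation tagged with the
  // workgroup address space, which CSE does not fold.
  %a0 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
  %a1 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>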

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Max191 authored Jul 19, 2024
1 parent cfc79ea commit 69900ee
Showing 5 changed files with 60 additions and 29 deletions.
@@ -45,6 +45,7 @@ module attributes { transform.with_named_sequence } {
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>

// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<128x128xf32>
// CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
// CHECK-DAG: %[[OUTID0:.+]] = affine.apply #[[$MAP]](%[[IDX]])
// CHECK-DAG: %[[OUTID1:.+]] = affine.apply #[[$MAP]](%[[IDY]])
@@ -54,7 +55,7 @@ module attributes { transform.with_named_sequence } {
// CHECK: %[[INSLICE0:.+]] = tensor.extract_slice %[[ARG0]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
// CHECK: %[[INSLICE1:.+]] = tensor.extract_slice %[[EMPTY]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[INSLICE0]] : tensor<2x128xf32>) outs(%[[INSLICE1]] : tensor<2x128xf32>) -> tensor<2x128xf32>
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.shuffle_tensor %[[COPY]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] to %[[EMPTY]]
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.shuffle_tensor %[[COPY]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] to %[[ALLOC]]
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
// CHECK: iree_gpu.yield %[[SLICE]]
@@ -72,8 +73,9 @@ module attributes { transform.with_named_sequence } {
#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 16)>
module {
func.func @fuse_forall(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
%2 = scf.forall (%arg5, %arg6) in (64, 1) shared_outs(%arg7 = %arg1) -> (tensor<128x128xf32>) {
func.func @fuse_forall(%arg0: tensor<128x128xf32>) -> tensor<128x128xf32> {
%empty = tensor.empty() : tensor<128x128xf32>
%2 = scf.forall (%arg5, %arg6) in (64, 1) shared_outs(%arg7 = %empty) -> (tensor<128x128xf32>) {
%4 = affine.apply #map(%arg5)
%extracted_slice = tensor.extract_slice %arg0[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
%extracted_slice_0 = tensor.extract_slice %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
@@ -82,7 +84,7 @@ module {
tensor.parallel_insert_slice %5 into %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<2x128xf32> into tensor<128x128xf32>
}
} {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
%3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %arg1) -> (tensor<128x128xf32>) {
%3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %empty) -> (tensor<128x128xf32>) {
%6 = affine.apply #map1(%arg5)
%7 = affine.apply #map1(%arg6)
%extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
@@ -111,10 +113,11 @@ module attributes { transform.with_named_sequence } {

// CHECK-LABEL: func @fuse_forall
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<128x128xf32>

// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[ARG1]]) -> (tensor<128x128xf32>) {
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.shuffle_tensor %{{.*}} to %[[ARG1]]
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<128x128xf32>
// CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.shuffle_tensor %{{.*}} to %[[ALLOC]]
// CHECK: } : tensor<2x128xf32> -> tensor<128x128xf32> -> tensor<16x16xf32>
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}

@@ -123,8 +126,9 @@ module attributes { transform.with_named_sequence } {
#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 16)>
module {
func.func @fuse_forall_with_reshape(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
%2 = scf.forall (%arg5, %arg6) in (64, 1) shared_outs(%arg7 = %arg1) -> (tensor<128x128xf32>) {
func.func @fuse_forall_with_reshape(%arg0: tensor<128x128xf32>) -> tensor<128x128xf32> {
%empty = tensor.empty() : tensor<128x128xf32>
%2 = scf.forall (%arg5, %arg6) in (64, 1) shared_outs(%arg7 = %empty) -> (tensor<128x128xf32>) {
%4 = affine.apply #map(%arg5)
%extracted_slice = tensor.extract_slice %arg0[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
%extracted_slice_0 = tensor.extract_slice %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
@@ -134,7 +138,7 @@ module {
}
} {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
%expand = tensor.expand_shape %2 [[0, 1], [2]] output_shape [2, 64, 128] : tensor<128x128xf32> into tensor<2x64x128xf32>
%3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %arg1) -> (tensor<128x128xf32>) {
%3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %empty) -> (tensor<128x128xf32>) {
%6 = affine.apply #map1(%arg5)
%7 = affine.apply #map1(%arg6)
%extracted_slice_0 = tensor.extract_slice %expand[0, %6, %7] [1, 16, 16] [1, 1, 1] : tensor<2x64x128xf32> to tensor<16x16xf32>
@@ -163,10 +167,11 @@ module attributes { transform.with_named_sequence } {

// CHECK-LABEL: func @fuse_forall_with_reshape
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<128x128xf32>

// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[ARG1]]) -> (tensor<128x128xf32>) {
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.shuffle_tensor %{{.*}} to %[[ARG1]]
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<128x128xf32>
// CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.shuffle_tensor %{{.*}} to %[[ALLOC]]
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[INTERMEDIATE]] {{\[}}[0, 1], [2]{{\]}} output_shape [2, 64, 128]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[EXPAND]][0, %{{.*}}, %{{.*}}] [1, 16, 16] [1, 1, 1] : tensor<2x64x128xf32> to tensor<16x16xf32>
@@ -179,8 +184,9 @@ module attributes { transform.with_named_sequence } {
#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0, d1) -> (d1 + d0 * 16)>
module {
func.func @fuse_thread_forall_with_warp_and_lane(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
%2 = scf.forall (%arg5, %arg6) in (64, 1) shared_outs(%arg7 = %arg1) -> (tensor<128x128xf32>) {
func.func @fuse_thread_forall_with_warp_and_lane(%arg0: tensor<128x128xf32>) -> tensor<128x128xf32> {
%empty = tensor.empty() : tensor<128x128xf32>
%2 = scf.forall (%arg5, %arg6) in (64, 1) shared_outs(%arg7 = %empty) -> (tensor<128x128xf32>) {
%4 = affine.apply #map(%arg5)
%extracted_slice = tensor.extract_slice %arg0[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
%extracted_slice_0 = tensor.extract_slice %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
@@ -189,7 +195,7 @@ module {
tensor.parallel_insert_slice %5 into %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<2x128xf32> into tensor<128x128xf32>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
%3 = scf.forall (%arg9, %arg10) in (2, 2) shared_outs(%arg8 = %arg1) -> (tensor<128x128xf32>) {
%3 = scf.forall (%arg9, %arg10) in (2, 2) shared_outs(%arg8 = %empty) -> (tensor<128x128xf32>) {
%extracted_slice_2 = tensor.extract_slice %arg8[%arg9, %arg10] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
%9 = scf.forall (%arg5, %arg6) in (4, 4) shared_outs(%arg7 = %extracted_slice_2) -> (tensor<64x64xf32>) {
%6 = affine.apply #map1(%arg5, %arg9)
@@ -224,16 +230,17 @@ module attributes { transform.with_named_sequence } {

// CHECK-LABEL: func @fuse_thread_forall_with_warp_and_lane
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<128x128xf32>

// CHECK: scf.forall (%[[W_IDX:.+]], %[[W_IDY:.+]]) in (2, 2) shared_outs(%[[INIT:.+]] = %[[ARG1]]) -> (tensor<128x128xf32>) {
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<128x128xf32>
// CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
// CHECK: scf.forall (%[[W_IDX:.+]], %[[W_IDY:.+]]) in (2, 2) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
// CHECK: scf.forall (%[[L_IDX:.+]], %[[L_IDY:.+]]) in (4, 4) {{.*}} -> (tensor<64x64xf32>)
// CHECK-DAG: %[[FLAT_ID:.+]] = affine.apply #[[$MAP3]](%[[L_IDY]], %[[L_IDX]], %[[W_IDX]], %[[W_IDY]])
// CHECK-DAG: %[[IDS:.+]]:2 = affine.delinearize_index %[[FLAT_ID]] into (%c64, %c1) : index, index
// CHECK-DAG: %[[IDX:.+]] = affine.apply #[[$MAP4]](%[[IDS]]#0)
// CHECK: %[[COPY:.+]] = linalg.copy
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.shuffle_tensor
// CHECK-SAME: %[[COPY]][%[[IDX]], %[[IDS]]#1] [2, 128] [1, 1] to %[[ARG1]]
// CHECK-SAME: %[[COPY]][%[[IDX]], %[[IDS]]#1] [2, 128] [1, 1] to %[[ALLOC]]
// CHECK: } : tensor<2x128xf32> -> tensor<128x128xf32> -> tensor<16x16xf32>
// CHECK: } {mapping = [#iree_gpu.lane_id<1>, #iree_gpu.lane_id<0>]}
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
@@ -73,6 +73,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:DestinationStyleOpInterface",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
@@ -58,6 +58,7 @@ iree_cc_library(
MLIRAffineDialect
MLIRAffineUtils
MLIRArithDialect
MLIRBufferizationDialect
MLIRDestinationStyleOpInterface
MLIRFuncDialect
MLIRFunctionInterfaces
@@ -25,7 +25,8 @@ def FuseAndHoistParallelLoopsPass :
let summary = "Greedily fuses and hoists parallel loops.";
let dependentDialects = [
"::mlir::affine::AffineDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect"
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
"::mlir::bufferization::BufferizationDialect"
];
}

@@ -10,9 +10,11 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -107,16 +109,32 @@ static LogicalResult compareWorkerCounts(scf::ForallOp producer,
return success();
}

static void replaceConsumerChain(RewriterBase &rewriter, Location loc,
Value source,
tensor::ParallelInsertSliceOp parallelInsert,
SmallVector<Operation *> consumerChain) {
static LogicalResult
replaceConsumerChain(RewriterBase &rewriter, Location loc, Value source,
tensor::ParallelInsertSliceOp parallelInsert,
SmallVector<Operation *> consumerChain) {
auto extractSlice = cast<tensor::ExtractSliceOp>(consumerChain.back());
OpBuilder::InsertionGuard g(rewriter);
Value shuffleDest = parallelInsert.getDest();
auto empty = shuffleDest.getDefiningOp<tensor::EmptyOp>();
if (!empty) {
return failure();
}

{
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(empty);
Attribute sharedMemoryAddrSpace = gpu::AddressSpaceAttr::get(
rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
auto allocTensor = rewriter.create<bufferization::AllocTensorOp>(
empty->getLoc(), empty->getResultTypes()[0], empty.getDynamicSizes());
allocTensor.setMemorySpaceAttr(sharedMemoryAddrSpace);
shuffleDest = allocTensor.getResult();
}
auto shuffleOp = rewriter.create<IREE::GPU::ShuffleTensorOp>(
loc, extractSlice.getType(), parallelInsert.getSource(),
parallelInsert.getDest(), parallelInsert.getMixedOffsets(),
parallelInsert.getMixedSizes(), parallelInsert.getMixedStrides());
loc, extractSlice.getType(), parallelInsert.getSource(), shuffleDest,
parallelInsert.getMixedOffsets(), parallelInsert.getMixedSizes(),
parallelInsert.getMixedStrides());
rewriter.setInsertionPointToStart(shuffleOp.getBody());
auto terminator =
rewriter.create<IREE::GPU::YieldOp>(loc, extractSlice.getResult());
@@ -127,6 +145,7 @@ static void replaceConsumerChain(RewriterBase &rewriter, Location loc,
->replaceUsesOfWith(source, shuffleOp.getBody()->getArgument(0));
rewriter.replaceAllUsesExcept(extractSlice.getResult(), shuffleOp,
terminator);
return success();
}

LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,
@@ -213,8 +232,10 @@ LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,
auto parallelInsert =
cast<tensor::ParallelInsertSliceOp>(*terminator.getYieldingOps().begin());

replaceConsumerChain(rewriter, loc, producer.getResult(0), parallelInsert,
consumerChain);
if (failed(replaceConsumerChain(rewriter, loc, producer.getResult(0),
parallelInsert, consumerChain))) {
return failure();
}

rewriter.eraseOp(parallelInsert);
rewriter.eraseOp(terminator);
