[Codegen] Add destination fusion to fuse_and_hoist pass (iree-org#17517)
This allows fusing ops like `linalg.fill` into the destination after
loop fusion + hoisting: in typical flows the destination is not fusable
with the outer serial loop, but it is fusable (distributable) with the
distributed loop.
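
A minimal sketch of the shape this targets (illustrative, distilled from the new test below): the `linalg.fill` initializes the accumulator of a serial `scf.for`, so it cannot fuse with that loop, but the distributed `scf.forall` inside only touches a slice of the accumulator, so once the forall is hoisted the fill can be tiled into it:

%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<128x128xf32>) -> tensor<128x128xf32>
%res = scf.for %k = %c0 to %c128 step %c4 iter_args(%acc = %fill) -> (tensor<128x128xf32>) {
  %out = scf.forall (%i, %j) in (8, 8) shared_outs(%arg = %acc) -> (tensor<128x128xf32>) {
    // ... each thread computes on a 16x16 tensor.extract_slice of %arg ...
  } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
  scf.yield %out : tensor<128x128xf32>
}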
qedawkins authored May 29, 2024
1 parent 26e4c6b commit 2c59505
Showing 2 changed files with 115 additions and 0 deletions.
@@ -51,13 +51,54 @@ struct FuseForalls final : OpRewritePattern<tensor::ExtractSliceOp> {
}
};

struct FuseTileableDestinationProducers final
    : OpRewritePattern<scf::ForallOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
                                PatternRewriter &rewriter) const override {
    // Look for a region iter_arg that is read through a tensor.extract_slice
    // and whose tied loop init is produced by a tileable op.
    TilingInterface tileableProducer;
    tensor::ExtractSliceOp sliceOp;
    for (auto iterArg : forallOp.getRegionIterArgs()) {
      for (auto user : iterArg.getUsers()) {
        sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
        if (sliceOp) {
          break;
        }
      }
      if (!sliceOp) {
        continue;
      }
      tileableProducer = forallOp.getTiedLoopInit(iterArg)
                             ->get()
                             .getDefiningOp<TilingInterface>();
      if (tileableProducer) {
        break;
      }
    }
    if (!tileableProducer) {
      return failure();
    }

    // Fuse the producer into the loop by re-tiling it to the extracted slice.
    SmallVector<LoopLikeOpInterface> loops = {forallOp};
    rewriter.startOpModification(forallOp);
    std::optional<scf::SCFFuseProducerOfSliceResult> fusionResult =
        mlir::scf::tileAndFuseProducerOfSlice(rewriter, sliceOp, loops);
    if (!fusionResult) {
      return failure();
    }
    rewriter.finalizeOpModification(forallOp);
    return success();
  }
};

void FuseAndHoistParallelLoopsPass::runOnOperation() {
  MLIRContext *context = &getContext();
  RewritePatternSet patterns(context);

  // These patterns are run to a fixed point, allowing fusion within
  // potentially nested loops, hoisting from said loops, and continued fusion.
  patterns.add<FuseForalls>(context);
  patterns.add<FuseTileableDestinationProducers>(context);
  populateForallLoopHoistingPattern(patterns);
  if (failed(
          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
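In IR terms, the new `FuseTileableDestinationProducers` pattern looks for an `scf.forall` whose `shared_outs` init is produced by a tileable op and whose region iter arg is read through a `tensor.extract_slice`, then recreates the producer on the slice via `tileAndFuseProducerOfSlice`. A hand-written before/after sketch, not actual pass output (`%ii`/`%jj` stand for thread-derived offsets):

// Before: the fill produces the forall's destination.
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<128x128xf32>) -> tensor<128x128xf32>
%r = scf.forall (%i, %j) in (8, 8) shared_outs(%arg = %fill) -> (tensor<128x128xf32>) {
  %slice = tensor.extract_slice %arg[%ii, %jj] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
  // ... compute on %slice ...
}

// After: the fill is re-tiled onto the extracted slice inside the forall.
%r = scf.forall (%i, %j) in (8, 8) shared_outs(%arg = %empty) -> (tensor<128x128xf32>) {
  %slice = tensor.extract_slice %arg[%ii, %jj] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
  %tile = linalg.fill ins(%cst : f32) outs(%slice : tensor<16x16xf32>) -> tensor<16x16xf32>
  // ... compute on %tile ...
}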
@@ -68,3 +68,77 @@ module {
// CHECK: scf.forall.in_parallel
// CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]]
// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]]

// -----

#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 4)>
#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
#map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
#map4 = affine_map<(d0) -> (d0 * 16)>
module {
  func.func @forall_fuse_then_hoist_with_fill() {
    %c4 = arith.constant 4 : index
    %c128 = arith.constant 128 : index
    %c0 = arith.constant 0 : index
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
    %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
    %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
    %empty = tensor.empty() : tensor<128x128xf32>
    %cst = arith.constant 0.0 : f32
    %5 = linalg.fill ins(%cst : f32) outs(%empty : tensor<128x128xf32>) -> tensor<128x128xf32>
    %6 = tensor.empty() : tensor<128x4xf16>
    %7 = tensor.empty() : tensor<4x128xf16>
    %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<128x128xf32>) {
      %9 = scf.forall (%arg2, %arg3) in (64, 1) shared_outs(%arg4 = %6) -> (tensor<128x4xf16>) {
        %12 = affine.apply #map(%arg2)
        %13 = affine.apply #map1(%arg3)
        %14 = affine.apply #map(%arg2)
        %15 = affine.apply #map2(%arg3)[%arg0]
        %extracted_slice = tensor.extract_slice %3[%14, %15] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
        %extracted_slice_0 = tensor.extract_slice %arg4[%12, %13] [2, 4] [1, 1] : tensor<128x4xf16> to tensor<2x4xf16>
        %16 = linalg.copy ins(%extracted_slice : tensor<2x4xf16>) outs(%extracted_slice_0 : tensor<2x4xf16>) -> tensor<2x4xf16>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16>
        }
      } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
      %10 = scf.forall (%arg2, %arg3) in (2, 32) shared_outs(%arg4 = %7) -> (tensor<4x128xf16>) {
        %12 = affine.apply #map(%arg2)
        %13 = affine.apply #map1(%arg3)
        %14 = affine.apply #map3(%arg2)[%arg0]
        %15 = affine.apply #map1(%arg3)
        %extracted_slice = tensor.extract_slice %4[%14, %15] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
        %extracted_slice_0 = tensor.extract_slice %arg4[%12, %13] [2, 4] [1, 1] : tensor<4x128xf16> to tensor<2x4xf16>
        %16 = linalg.copy ins(%extracted_slice : tensor<2x4xf16>) outs(%extracted_slice_0 : tensor<2x4xf16>) -> tensor<2x4xf16>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<4x128xf16>
        }
      } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
      %11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) {
        %12 = affine.apply #map4(%arg2)
        %13 = affine.apply #map4(%arg3)
        %extracted_slice = tensor.extract_slice %9[%12, 0] [16, 4] [1, 1] : tensor<128x4xf16> to tensor<16x4xf16>
        %extracted_slice_0 = tensor.extract_slice %10[0, %13] [4, 16] [1, 1] : tensor<4x128xf16> to tensor<4x16xf16>
        %extracted_slice_1 = tensor.extract_slice %arg4[%12, %13] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
        %14 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<16x4xf16>, tensor<4x16xf16>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
        }
      } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
      scf.yield %11 : tensor<128x128xf32>
    }
    flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
    return
  }
}

// CHECK-LABEL: func @forall_fuse_then_hoist_with_fill
// CHECK: %[[OUTER_PARALLEL:.+]] = scf.forall
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%{{.*}} = %[[FILL]])
// CHECK: scf.yield {{.*}} : tensor<16x16xf32>
// CHECK: scf.forall.in_parallel
// CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]]
// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]]
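
The CHECK lines verify the end state: after hoisting, the `scf.forall` encloses the serial `scf.for`, and the fused `linalg.fill` now materializes inside the forall, initializing the 16x16 accumulator slice carried by the inner loop's iter_args.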
