[Codegen][GPU] Allow serial tiling of online_attention op (iree-org#17702)

This PR adds support for reduction tiling of the online_attention op in the GPUTileToSerialLoops pass.
Groverkss authored Jun 20, 2024
1 parent 90f29a6 commit 12d43e8
Showing 3 changed files with 50 additions and 9 deletions.
compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp (22 changes: 13 additions & 9 deletions)
@@ -51,12 +51,16 @@ class TileConsumerAndFuseInputProducer final
     // Make sure we have a PartitionableLoopInterface op here and query the tile
     // sizes from the partitionable loops.
     auto plOp = dyn_cast<PartitionableLoopsInterface>(*op);
-    if (!plOp)
-      return failure();
+    if (!plOp) {
+      return rewriter.notifyMatchFailure(
+          op, "Op does not implement PartitionableLoopsInterface");
+    }
     auto partitionedLoops = plOp.getPartitionableLoops(kNumMaxParallelDims);
     SmallVector<int64_t> tileSizes = getTileSizes(op, 0);
-    if (tileSizes.empty())
-      return failure();
+    if (tileSizes.empty()) {
+      return rewriter.notifyMatchFailure(
+          op, "Op does not have configuration to get tile_sizes from");
+    }
     // Mask out non reduction dimensions.
     for (unsigned depth : partitionedLoops) {
       if (depth < tileSizes.size())
@@ -73,7 +77,7 @@ class TileConsumerAndFuseInputProducer final
     tileSizes.resize(op.getLoopIteratorTypes().size(), 0);

     if (llvm::all_of(tileSizes, [](int64_t s) { return s == 0; })) {
-      return failure();
+      return rewriter.notifyMatchFailure(op, "No dimensions are tiled");
     }

     // Tile the current op and fuse its immediate input operands.
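To make the masking step above concrete, here is a minimal standalone sketch (not part of this change) of the same arithmetic, using the tile sizes from the new test case added below. It uses plain C++ and std::vector rather than the pass's LLVM types, and the parallel-dimension indices (batch, m, n) are inferred from the test's indexing maps, so treat them as an illustrative assumption rather than the pass's actual query results.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // lowering_config tile sizes from the test, ordered (batch, m, k1, k2, n).
  std::vector<int64_t> tileSizes = {64, 0, 0, 32, 0};
  // Assumed partitionable (parallel) loop depths for online_attention:
  // batch, m, n. The reduction dims k1 and k2 are not partitionable.
  std::vector<unsigned> partitionedLoops = {0, 1, 4};

  // Mask out the partitionable dimensions so only reduction dims keep a
  // non-zero tile size, mirroring the loop in the hunk above.
  for (unsigned depth : partitionedLoops) {
    if (depth < tileSizes.size())
      tileSizes[depth] = 0;
  }

  // Prints "0 0 0 32 0": only the k2 reduction dimension (tile size 32)
  // survives, so the serial scf.for tiles k2 in steps of 32.
  for (int64_t s : tileSizes)
    std::cout << s << ' ';
  std::cout << '\n';
  return 0;
}

With every parallel tile size zeroed out, the llvm::all_of check above is what rejects ops that have no reduction dimension left to tile serially.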
@@ -137,14 +141,14 @@ class TileConsumerAndFuseInputProducer final
       auto sliceOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
       if (!sliceOp)
         continue;
-      auto linalgOp = sliceOp.getSource().getDefiningOp<linalg::LinalgOp>();
-      if (!linalgOp)
+      auto tilingOp = sliceOp.getSource().getDefiningOp<TilingInterface>();
+      if (!tilingOp)
         continue;
-      // Restrict to fully parallel linalg ops for now for simplicity.
+      // Restrict to fully parallel ops for now for simplicity.
       auto isParallel = [](utils::IteratorType it) {
         return linalg::isParallelIterator(it);
       };
-      if (llvm::all_of(linalgOp.getIteratorTypesArray(), isParallel)) {
+      if (llvm::all_of(tilingOp.getLoopIteratorTypes(), isParallel)) {
         candidates.push_back(sliceOp);
       }
     }

@@ -221,3 +221,37 @@ func.func @conv() attributes {translation_info = #iree_codegen.translation_info<
// COALESCE_LOOPS: scf.for
// COALESCE_LOOPS-NOT: scf.for
// COALESCE_LOOPS: return

#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>

#config = #iree_codegen.lowering_config<tile_sizes = [[64, 0, 0, 32, 0]]>

func.func @online_attention(%query: tensor<192x1024x64xf16>,
%key: tensor<192x1024x64xf16>,
%value: tensor<192x1024x64xf16>,
%output: tensor<192x1024x64xf32>,
%max: tensor<192x1024xf32>,
%sum: tensor<192x1024xf32>)
-> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
%scale = arith.constant 1.0 : f16

%out:3 = iree_linalg_ext.online_attention
{ indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR],
lowering_config = #config }
ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
-> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>

return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
}

// Just check if the operation gets tiled. The actual tiling verification tests
// are in online_attention tiling interface tests.
// CHECK-LABEL: func.func @online_attention
// CHECK: scf.for
// CHECK: iree_linalg_ext.online_attention
// CHECK: scf.yield

@@ -256,6 +256,9 @@ void registerPartitionableLoopsInterfaceModels(DialectRegistry &registry) {
             IREE::LinalgExt::WinogradOutputTransformOp>>(*ctx);
     IREE::LinalgExt::AttentionOp::attachInterface<
         AllParallelAsPartitionableLoops<IREE::LinalgExt::AttentionOp>>(*ctx);
+    IREE::LinalgExt::OnlineAttentionOp::attachInterface<
+        AllParallelAsPartitionableLoops<IREE::LinalgExt::OnlineAttentionOp>>(
+        *ctx);
   });
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
     tensor::PackOp::attachInterface<
