Skip to content

Commit

Permalink
[Codegen] Bubble up Transpose attention V and try fuse with others be…
Browse files Browse the repository at this point in the history
…fore attention (iree-org#19250)

Flash Attention transpose_V variant is significantly faster than the
non-transpose_V variant. This is due to many matmul intrinsics being
mmtb by default. Hence, doing FA transpose_V will allow for better/more
contiguous reads from shared memory to register, improving the attention
performance quite a bit.

This PR exposes the attention_transposeV form by generating a
linalg.transpose on the V during bubbling up of transpose S.T we can
give the graph some opportunities to fuse the transpose-V to it's
producer. I have also confirmed that if we do not find any producer, the
transpose will indeed fuse back with the attenionOp. Hence worse case,
we will get same perf as before this PR.

Additionally, we modify elementwise op fusion to try fuse transpose with
other ops before letting it get fused back into attention.

---------

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
  • Loading branch information
raikonenfnu authored Nov 26, 2024
1 parent 5708d42 commit 41115bb
Show file tree
Hide file tree
Showing 8 changed files with 314 additions and 17 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/pkgci_regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ jobs:
--goldentime-rocm-unet-ms 419.0 \
--goldentime-rocm-clip-ms 18.5 \
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-unet 1602 \
--goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-vae 246 \
--goldensize-rocm-unet-bytes 2280000 \
Expand All @@ -238,21 +238,21 @@ jobs:
run: |
source ${VENV_DIR}/bin/activate
pytest ./experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py \
--goldentime-rocm-e2e-ms 372.0 \
--goldentime-rocm-unet-ms 95.0 \
--goldentime-rocm-e2e-ms 330.0 \
--goldentime-rocm-unet-ms 80.0 \
--goldentime-rocm-clip-ms 15.5 \
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-unet 1602 \
--goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-vae 246 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
--goldentime-rocm-punet-int8-fp16-ms 55 \
--goldendispatch-rocm-punet-int8-fp16 1284 \
--goldentime-rocm-punet-int8-fp16-ms 53 \
--goldendispatch-rocm-punet-int8-fp16 1424 \
--goldensize-rocm-punet-int8-fp16-bytes 2560000 \
--goldentime-rocm-punet-int8-fp8-ms 59 \
--goldendispatch-rocm-punet-int8-fp8 1564 \
--goldentime-rocm-punet-int8-fp8-ms 53 \
--goldendispatch-rocm-punet-int8-fp8 1704 \
--goldensize-rocm-punet-int8-fp8-bytes 2800000 \
--rocm-chip gfx942 \
--log-cli-level=info \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,41 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
transform.yield %cont, %config : !transform.any_op, !transform.any_param
}


// Variant of matmul_like_Bx20x1024x64x1280_i8xi8xi32 from Transposed-V.
transform.named_sequence @match_matmul_like_Bx20x64x1024x1280_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
-> (!transform.any_op, !transform.any_param) {
%ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
^bb0(%lhs: tensor<?x1024x1280xi8>, %rhs: tensor<20x64x1280xi8>, %out: tensor<?x20x64x1024xi32>):
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
ins(%lhs, %rhs : tensor<?x1024x1280xi8>, tensor<20x64x1280xi8>)
outs(%out : tensor<?x20x64x1024xi32>) {
^bb0(%in: i8, %in_0: i8, %acc: i32):
%18 = arith.extsi %in : i8 to i32
%19 = arith.extsi %in_0 : i8 to i32
%20 = arith.muli %18, %19 : i32
%21 = arith.addi %acc, %20 : i32
linalg.yield %21 : i32
} -> tensor<?x20x64x1024xi32>
} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
subgroup_m_count = 2, subgroup_n_count = 2,
reduction = [0, 0, 0, 0, 128],
workgroup = [1, 1, 160, 64, 0]}>,
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
workgroup_size = [256, 1, 1] subgroup_size = 64,
{gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true,
reorder_workgroups_strategy = <Transpose>>
}>
> -> !transform.any_param
transform.yield %cont, %config : !transform.any_op, !transform.any_param
}

transform.named_sequence @match_matmul_like_Bx20x64x64x2048_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
-> (!transform.any_op, !transform.any_param) {
%ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
Expand Down Expand Up @@ -239,6 +274,38 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
transform.yield %cont, %config : !transform.any_op, !transform.any_param
}

// Variant of matmul_like_Bx20x64x64x2048_i8xi8xi32 from Transposed-V.
transform.named_sequence @match_matmul_like_Bx20x64x64x2048_transposev_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
-> (!transform.any_op, !transform.any_param) {
%ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
^bb0(%lhs: tensor<?x64x2048xi8>, %rhs: tensor<20x64x2048xi8>, %out: tensor<?x20x64x64xi32>):
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
ins(%lhs, %rhs : tensor<?x64x2048xi8>, tensor<20x64x2048xi8>)
outs(%out : tensor<?x20x64x64xi32>) {
^bb0(%in: i8, %in_0: i8, %acc: i32):
%18 = arith.extsi %in : i8 to i32
%19 = arith.extsi %in_0 : i8 to i32
%20 = arith.muli %18, %19 : i32
%21 = arith.addi %acc, %20 : i32
linalg.yield %21 : i32
} -> tensor<?x20x64x64xi32>
} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
subgroup_m_count = 2, subgroup_n_count = 1,
reduction = [0, 0, 0, 0, 128],
workgroup = [1, 1, 320, 32, 0]}>,
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
workgroup_size = [128, 1, 1] subgroup_size = 64,
{gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>}>
> -> !transform.any_param
transform.yield %cont, %config : !transform.any_op, !transform.any_param
}

transform.named_sequence @match_matmul_like_Bx10x4096x64x640_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
-> (!transform.any_op, !transform.any_param) {
%ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
Expand Down Expand Up @@ -302,6 +369,10 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
, @match_matmul_like_Bx10x4096x64x640_i8xi8xi32 -> @apply_op_config
, @match_matmul_like_Bx20x64x64x2048_i8xi8xi32 -> @apply_op_config

// Transpose-V generated contraction.
, @match_matmul_like_Bx20x64x1024x1280_i8xi8xi32 -> @apply_op_config
, @match_matmul_like_Bx20x64x64x2048_transposev_i8xi8xi32 -> @apply_op_config

// TUNING_MATCH_END DO NOT REMOVE
: (!transform.any_op) -> (!transform.any_op)
transform.yield
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ void populateFuseLinalgExtOpsWithTransposes(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFusionFn);

/// Bubble up transpose-like ops from LinalgExt ops (only `AttentionOp`
/// supported).
void populateBubbleTransposeFromLinalgExtOps(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFusionFn);

/// Helper struct to hold the results of collapsing an operation.
struct CollapseResult {
SmallVector<Value> results;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -101,6 +102,103 @@ struct FuseTransposeWithAttentionOp final
private:
linalg::ControlFusionFn controlFn;
};

// Bubbles transpose-V out of attention to expose the more performant
// attention-transposeV.
struct BubbleTransposeVFromAttentionOp
: public OpRewritePattern<LinalgExt::AttentionOp> {
BubbleTransposeVFromAttentionOp(MLIRContext *context,
linalg::ControlFusionFn controlFn,
PatternBenefit benefit = 1)
: OpRewritePattern<LinalgExt::AttentionOp>(context, benefit),
controlFn(controlFn) {}

LogicalResult matchAndRewrite(LinalgExt::AttentionOp attentionOp,
PatternRewriter &rewriter) const override {
// Only checking for V because we are only bubbling transpose-V.
OpOperand *valueOpOperand = &attentionOp.getValueMutable();
if (controlFn && !controlFn(valueOpOperand)) {
return rewriter.notifyMatchFailure(
attentionOp, "Expected attentionOp and producer of V to be non-null "
"and outside dispatch.");
}
// Extract Attention indexing information.
AffineMap qMap = attentionOp.getQueryMap();
AffineMap kMap = attentionOp.getKeyMap();
AffineMap vMap = attentionOp.getValueMap();
AffineMap oMap = attentionOp.getOutputMap();
FailureOr<AttentionOpDetail> maybeOpInfo =
AttentionOpDetail::get(qMap, kMap, vMap, oMap);
if (failed(maybeOpInfo)) {
return failure();
}

// Only handle single dim for K2 and N for now.
if (maybeOpInfo->getK2Dims().size() != 1 ||
maybeOpInfo->getNDims().size() != 1) {
return failure();
}
// Check that V has standard map/non transposed V.
AffineExpr k2Dim =
rewriter.getAffineDimExpr(maybeOpInfo->getK2Dims().back());
AffineExpr nDim = rewriter.getAffineDimExpr(maybeOpInfo->getNDims().back());
int64_t vRank = vMap.getNumResults();
// TODO: This check is quite conservative, in the future we should simply
// do vMap.getResultPosition(k2Dim) > vMap.getResultPosition(nDim).
if (vMap.getResult(vRank - 1) != nDim ||
vMap.getResult(vRank - 2) != k2Dim) {
return failure();
}

// Get dimension positions to prepare for transpose.
std::optional<int64_t> maybeK2Pos = vMap.getResultPosition(k2Dim);
std::optional<int64_t> maybeNPos = vMap.getResultPosition(nDim);
assert(maybeK2Pos.has_value() && maybeNPos.has_value() &&
"Expected K2 dim and N dim to be in V-map.");
int64_t k2Pos = maybeK2Pos.value();
int64_t nPos = maybeNPos.value();
SmallVector<int64_t> perm = llvm::to_vector(llvm::seq<int64_t>(0, vRank));
std::swap(perm[k2Pos], perm[nPos]);

// Expose transposeOp for V.
Location loc = attentionOp.getLoc();
Value value = attentionOp.getValue();
auto valueType = dyn_cast<ShapedType>(value.getType());
auto valueElType = valueType.getElementType();
SmallVector<OpFoldResult> transVShape =
tensor::getMixedSizes(rewriter, loc, value);
applyPermutationToVector(transVShape, perm);
Value initTransV =
rewriter.create<tensor::EmptyOp>(loc, transVShape, valueElType)
.getResult();
Value transposeV =
rewriter.create<linalg::TransposeOp>(loc, value, initTransV, perm)
->getResult(0);

// Generate transpose V map.
SmallVector<AffineExpr> newExprs =
applyPermutation(vMap.getResults(), perm);
AffineMap transposedVMap =
AffineMap::get(vMap.getNumDims(), vMap.getNumSymbols(), newExprs,
rewriter.getContext());

// Modify attention to have transposed V inputs and mapping.
int64_t valueIndex = valueOpOperand->getOperandNumber();
rewriter.modifyOpInPlace(attentionOp, [&]() {
SmallVector<AffineMap> newIndexingMaps =
attentionOp.getIndexingMapsArray();
newIndexingMaps[valueIndex] = transposedVMap;
attentionOp.setIndexingMapsAttr(
rewriter.getAffineMapArrayAttr(newIndexingMaps));
attentionOp.setOperand(valueIndex, transposeV);
});
return success();
}

private:
linalg::ControlFusionFn controlFn;
};

} // namespace

void populateFuseLinalgExtOpsWithTransposes(
Expand All @@ -110,4 +208,11 @@ void populateFuseLinalgExtOpsWithTransposes(
controlFusionFn);
}

void populateBubbleTransposeFromLinalgExtOps(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFusionFn) {
patterns.add<BubbleTransposeVFromAttentionOp>(patterns.getContext(),
controlFusionFn);
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ struct GatherFusionPattern final : public OpRewritePattern<tensor::ExtractOp> {
void ElementwiseOpFusionPass::runOnOperation() {
MLIRContext *context = &getContext();

RewritePatternSet fusionPatterns(context);
// Only fuse operations where all uses of the producer are generic
// operations. If an operation is used in a named op, it will be computed
// anyway, so the consumers can just use that value.
Expand Down Expand Up @@ -135,24 +134,35 @@ void ElementwiseOpFusionPass::runOnOperation() {
return areFusableAsElementwiseOps(context, fusedOperand,
fuseMultiReduction);
};
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,

RewritePatternSet linalgFusionPatterns(context);
linalg::populateElementwiseOpsFusionPatterns(linalgFusionPatterns,
fuseElementwiseOpsControlFn);

GreedyRewriteConfig rewriteConfig;
rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
if (failed(applyPatternsAndFoldGreedily(
getOperation(), std::move(linalgFusionPatterns), rewriteConfig))) {
getOperation()->emitOpError(
"Failed to fuse elementwise ops with upstream patterns.");
return signalPassFailure();
}

// Try fuse with linalgExt patterns.
linalg::ControlFusionFn foldTransposeControlFn = [](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();

return IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer});
};
RewritePatternSet linalgExtFusionPatterns(context);
IREE::LinalgExt::populateFuseLinalgExtOpsWithTransposes(
fusionPatterns, foldTransposeControlFn);
fusionPatterns.insert<GatherFusionPattern>(context);

GreedyRewriteConfig rewriteConfig;
rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
linalgExtFusionPatterns, foldTransposeControlFn);
linalgExtFusionPatterns.insert<GatherFusionPattern>(context);
if (failed(applyPatternsAndFoldGreedily(
getOperation(), std::move(fusionPatterns), rewriteConfig))) {
getOperation()->emitOpError("Failed to perform elementwise operations");
getOperation(), std::move(linalgExtFusionPatterns), rewriteConfig))) {
getOperation()->emitOpError(
"Failed to fuse elementwise ops with linalgExt patterns.");
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,54 @@ util.func public @fuse_generic_gather2(
// CHECK-NEXT: %[[RES3:[a-zA-Z0-9]+]] = arith.mulf %[[RES]], %[[RES]] : f32
// CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32
// CHECK-NEXT: linalg.yield %[[RES4]] : f32

util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf16>, %k: tensor<2x10x4096x64xf16>, %quantized_v: tensor<2x10x4096x64xi32>, %quant_offset: tensor<10x64xi32>, %quant_scale: tensor<10x64xf32>, %scale: f16) -> tensor<2x10x4096x64xf16> {
// Dequantize int-quantization of V
%init_dequant = tensor.empty() : tensor<2x10x4096x64xf16>
%v = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%quantized_v, %quant_offset, %quant_scale : tensor<2x10x4096x64xi32>, tensor<10x64xi32>, tensor<10x64xf32>) outs(%init_dequant : tensor<2x10x4096x64xf16>) {
^bb0(%in: i32, %in_0: i32, %in_1: f32, %out: f16):
%19 = arith.addi %in, %in_0 : i32
%20 = arith.sitofp %19 : i32 to f32
%21 = arith.mulf %20, %in_1 : f32
%22 = arith.truncf %21 : f32 to f16
linalg.yield %22 : f16
} -> tensor<2x10x4096x64xf16>

// Transpose-V
%init_transpose = tensor.empty() : tensor<2x10x64x4096xf16>
%transpose_v = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%v : tensor<2x10x4096x64xf16>) outs(%init_transpose : tensor<2x10x64x4096xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<2x10x64x4096xf16>

// Attention-Transpose-V
%init_attention = tensor.empty() : tensor<2x10x4096x64xf16>
%attention = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%q, %k, %transpose_v, %scale : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>, f16) outs(%init_attention : tensor<2x10x4096x64xf16>) {
^bb0(%score: f16):
iree_linalg_ext.yield %score: f16
} -> tensor<2x10x4096x64xf16>
util.return %attention : tensor<2x10x4096x64xf16>
}

// CHECK-LABEL: util.func public @fuse_transpose_attention_to_producer
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: tensor
// CHECK-SAME: %[[ARG3:[A-Za-z0-9]+]]: tensor
// CHECK-SAME: %[[ARG4:[A-Za-z0-9]+]]: tensor
// CHECK-SAME: %[[ARG5:[A-Za-z0-9]+]]: f16
// CHECK: %[[DEQUANT_V:.+]] = linalg.generic
// CHECK-SAME: indexing_maps =
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>]
// CHECK-SAME: ins(%[[ARG2]], %[[ARG3]], %[[ARG4]]
// CHECK: %[[RESULT:.+]] = iree_linalg_ext.attention
// CHECK-SAME: indexing_maps =
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[DEQUANT_V]], %[[ARG5]]
Loading

0 comments on commit 41115bb

Please sign in to comment.