diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index 8b0b4e0189d2..7a67778c0585 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -220,7 +220,7 @@ jobs:
             --goldentime-rocm-unet-ms 419.0 \
             --goldentime-rocm-clip-ms 18.5 \
             --goldentime-rocm-vae-ms 337.0 \
-            --goldendispatch-rocm-unet 1531 \
+            --goldendispatch-rocm-unet 1602 \
             --goldendispatch-rocm-clip 1139 \
             --goldendispatch-rocm-vae 246 \
             --goldensize-rocm-unet-bytes 2280000 \
@@ -238,21 +238,21 @@ jobs:
         run: |
           source ${VENV_DIR}/bin/activate
           pytest ./experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py \
-            --goldentime-rocm-e2e-ms 372.0 \
-            --goldentime-rocm-unet-ms 95.0 \
+            --goldentime-rocm-e2e-ms 330.0 \
+            --goldentime-rocm-unet-ms 80.0 \
             --goldentime-rocm-clip-ms 15.5 \
             --goldentime-rocm-vae-ms 80.0 \
-            --goldendispatch-rocm-unet 1531 \
+            --goldendispatch-rocm-unet 1602 \
             --goldendispatch-rocm-clip 1139 \
             --goldendispatch-rocm-vae 246 \
             --goldensize-rocm-unet-bytes 2270000 \
             --goldensize-rocm-clip-bytes 860000 \
             --goldensize-rocm-vae-bytes 840000 \
-            --goldentime-rocm-punet-int8-fp16-ms 55 \
-            --goldendispatch-rocm-punet-int8-fp16 1284 \
+            --goldentime-rocm-punet-int8-fp16-ms 53 \
+            --goldendispatch-rocm-punet-int8-fp16 1424 \
             --goldensize-rocm-punet-int8-fp16-bytes 2560000 \
-            --goldentime-rocm-punet-int8-fp8-ms 59 \
-            --goldendispatch-rocm-punet-int8-fp8 1564 \
+            --goldentime-rocm-punet-int8-fp8-ms 53 \
+            --goldendispatch-rocm-punet-int8-fp8 1704 \
             --goldensize-rocm-punet-int8-fp8-bytes 2800000 \
             --rocm-chip gfx942 \
             --log-cli-level=info \
diff --git a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir
index a566203907e4..7b0944471990 100644
--- a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir
+++ b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir
@@ -208,6 +208,41 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
     transform.yield %cont, %config : !transform.any_op, !transform.any_param
   }
+
+  // Variant of matmul_like_Bx20x1024x64x1280_i8xi8xi32 from Transposed-V.
+  transform.named_sequence @match_matmul_like_Bx20x64x1024x1280_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
+    ^bb0(%lhs: tensor<?x1024x1280xi8>, %rhs: tensor<20x64x1280xi8>, %out: tensor<?x20x64x1024xi32>):
+      %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                                             affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d4)>,
+                                             affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+                            iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
+          ins(%lhs, %rhs : tensor<?x1024x1280xi8>, tensor<20x64x1280xi8>)
+          outs(%out : tensor<?x20x64x1024xi32>) {
+      ^bb0(%in: i8, %in_0: i8, %acc: i32):
+        %18 = arith.extsi %in : i8 to i32
+        %19 = arith.extsi %in_0 : i8 to i32
+        %20 = arith.muli %18, %19 : i32
+        %21 = arith.addi %acc, %20 : i32
+        linalg.yield %21 : i32
+      } -> tensor<?x20x64x1024xi32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                   mma_kind = #iree_gpu.mma_layout,
+                                                   subgroup_m_count = 2, subgroup_n_count = 2,
+                                                   reduction = [0, 0, 0, 0, 128],
+                                                   workgroup = [1, 1, 160, 64, 0]}>,
+      translation_info = #iree_codegen.translation_info>
+      }>
+    > -> !transform.any_param
+    transform.yield %cont, %config : !transform.any_op, !transform.any_param
+  }
+
 transform.named_sequence @match_matmul_like_Bx20x64x64x2048_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
     -> (!transform.any_op, !transform.any_param) {
   %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
@@ -239,6 +274,38 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
     transform.yield %cont, %config : !transform.any_op, !transform.any_param
   }
+
+  // Variant of matmul_like_Bx20x64x64x2048_i8xi8xi32 from Transposed-V.
+  transform.named_sequence @match_matmul_like_Bx20x64x64x2048_transposev_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
+    ^bb0(%lhs: tensor<?x64x2048xi8>, %rhs: tensor<20x64x2048xi8>, %out: tensor<?x20x64x64xi32>):
+      %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                                             affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d4)>,
+                                             affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+                            iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
+          ins(%lhs, %rhs : tensor<?x64x2048xi8>, tensor<20x64x2048xi8>)
+          outs(%out : tensor<?x20x64x64xi32>) {
+      ^bb0(%in: i8, %in_0: i8, %acc: i32):
+        %18 = arith.extsi %in : i8 to i32
+        %19 = arith.extsi %in_0 : i8 to i32
+        %20 = arith.muli %18, %19 : i32
+        %21 = arith.addi %acc, %20 : i32
+        linalg.yield %21 : i32
+      } -> tensor<?x20x64x64xi32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                   mma_kind = #iree_gpu.mma_layout,
+                                                   subgroup_m_count = 2, subgroup_n_count = 1,
+                                                   reduction = [0, 0, 0, 0, 128],
+                                                   workgroup = [1, 1, 320, 32, 0]}>,
+      translation_info = #iree_codegen.translation_info}>
+    > -> !transform.any_param
+    transform.yield %cont, %config : !transform.any_op, !transform.any_param
+  }
+
 transform.named_sequence @match_matmul_like_Bx10x4096x64x640_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
     -> (!transform.any_op, !transform.any_param) {
   %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
@@ -302,6 +369,10 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
         , @match_matmul_like_Bx10x4096x64x640_i8xi8xi32 -> @apply_op_config
         , @match_matmul_like_Bx20x64x64x2048_i8xi8xi32 -> @apply_op_config
+        // Transpose-V generated contraction.
+        , @match_matmul_like_Bx20x64x1024x1280_i8xi8xi32 -> @apply_op_config
+        , @match_matmul_like_Bx20x64x64x2048_transposev_i8xi8xi32 -> @apply_op_config
+
         // TUNING_MATCH_END DO NOT REMOVE
         : (!transform.any_op) -> (!transform.any_op)
     transform.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
index aec6bd704e5d..8bf84cab2574 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
@@ -19,6 +19,12 @@ void populateFuseLinalgExtOpsWithTransposes(
     RewritePatternSet &patterns,
    const linalg::ControlFusionFn &controlFusionFn);

+/// Bubble up transpose-like ops from LinalgExt ops (only `AttentionOp`
+/// supported).
+void populateBubbleTransposeFromLinalgExtOps(
+    RewritePatternSet &patterns,
+    const linalg::ControlFusionFn &controlFusionFn);
+
 /// Helper struct to hold the results of collapsing an operation.
 struct CollapseResult {
   SmallVector<Value> results;
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp
index 2d158d54014a..bcc94ec951c0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp
@@ -7,6 +7,7 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -101,6 +102,103 @@ struct FuseTransposeWithAttentionOp final
 private:
   linalg::ControlFusionFn controlFn;
 };
+
+// Bubbles transpose-V out of attention to expose the more performant
+// attention-transposeV form.
+struct BubbleTransposeVFromAttentionOp
+    : public OpRewritePattern<LinalgExt::AttentionOp> {
+  BubbleTransposeVFromAttentionOp(MLIRContext *context,
+                                  linalg::ControlFusionFn controlFn,
+                                  PatternBenefit benefit = 1)
+      : OpRewritePattern<LinalgExt::AttentionOp>(context, benefit),
+        controlFn(controlFn) {}
+
+  LogicalResult matchAndRewrite(LinalgExt::AttentionOp attentionOp,
+                                PatternRewriter &rewriter) const override {
+    // Only check V because we are only bubbling transpose-V.
+    OpOperand *valueOpOperand = &attentionOp.getValueMutable();
+    if (controlFn && !controlFn(valueOpOperand)) {
+      return rewriter.notifyMatchFailure(
+          attentionOp, "Expected attentionOp and producer of V to be non-null "
+                       "and outside dispatch.");
+    }
+    // Extract attention indexing information.
+    AffineMap qMap = attentionOp.getQueryMap();
+    AffineMap kMap = attentionOp.getKeyMap();
+    AffineMap vMap = attentionOp.getValueMap();
+    AffineMap oMap = attentionOp.getOutputMap();
+    FailureOr<AttentionOpDetail> maybeOpInfo =
+        AttentionOpDetail::get(qMap, kMap, vMap, oMap);
+    if (failed(maybeOpInfo)) {
+      return failure();
+    }
+
+    // Only handle a single K2 dim and a single N dim for now.
+    if (maybeOpInfo->getK2Dims().size() != 1 ||
+        maybeOpInfo->getNDims().size() != 1) {
+      return failure();
+    }
+    // Check that V has the standard (non-transposed) map.
+    AffineExpr k2Dim =
+        rewriter.getAffineDimExpr(maybeOpInfo->getK2Dims().back());
+    AffineExpr nDim = rewriter.getAffineDimExpr(maybeOpInfo->getNDims().back());
+    int64_t vRank = vMap.getNumResults();
+    // TODO: This check is quite conservative; in the future we should simply
+    // do vMap.getResultPosition(k2Dim) > vMap.getResultPosition(nDim).
+    if (vMap.getResult(vRank - 1) != nDim ||
+        vMap.getResult(vRank - 2) != k2Dim) {
+      return failure();
+    }
+
+    // Get dimension positions to prepare for the transpose.
+    std::optional<int64_t> maybeK2Pos = vMap.getResultPosition(k2Dim);
+    std::optional<int64_t> maybeNPos = vMap.getResultPosition(nDim);
+    assert(maybeK2Pos.has_value() && maybeNPos.has_value() &&
+           "Expected K2 dim and N dim to be in V-map.");
+    int64_t k2Pos = maybeK2Pos.value();
+    int64_t nPos = maybeNPos.value();
+    SmallVector<int64_t> perm = llvm::to_vector(llvm::seq<int64_t>(0, vRank));
+    std::swap(perm[k2Pos], perm[nPos]);
+
+    // Expose a transposeOp for V.
+    Location loc = attentionOp.getLoc();
+    Value value = attentionOp.getValue();
+    auto valueType = dyn_cast<ShapedType>(value.getType());
+    auto valueElType = valueType.getElementType();
+    SmallVector<OpFoldResult> transVShape =
+        tensor::getMixedSizes(rewriter, loc, value);
+    applyPermutationToVector(transVShape, perm);
+    Value initTransV =
+        rewriter.create<tensor::EmptyOp>(loc, transVShape, valueElType)
+            .getResult();
+    Value transposeV =
+        rewriter.create<linalg::TransposeOp>(loc, value, initTransV, perm)
+            ->getResult(0);
+
+    // Generate the transposed V map.
+    SmallVector<AffineExpr> newExprs =
+        applyPermutation(vMap.getResults(), perm);
+    AffineMap transposedVMap =
+        AffineMap::get(vMap.getNumDims(), vMap.getNumSymbols(), newExprs,
+                       rewriter.getContext());
+
+    // Modify attention to take the transposed V input and mapping.
+    int64_t valueIndex = valueOpOperand->getOperandNumber();
+    rewriter.modifyOpInPlace(attentionOp, [&]() {
+      SmallVector<AffineMap> newIndexingMaps =
+          attentionOp.getIndexingMapsArray();
+      newIndexingMaps[valueIndex] = transposedVMap;
+      attentionOp.setIndexingMapsAttr(
+          rewriter.getAffineMapArrayAttr(newIndexingMaps));
+      attentionOp.setOperand(valueIndex, transposeV);
+    });
+    return success();
+  }
+
+private:
+  linalg::ControlFusionFn controlFn;
+};
+
 } // namespace

 void populateFuseLinalgExtOpsWithTransposes(
@@ -110,4 +208,11 @@
                                               controlFusionFn);
 }

+void populateBubbleTransposeFromLinalgExtOps(
+    RewritePatternSet &patterns,
+    const linalg::ControlFusionFn &controlFusionFn) {
+  patterns.add<BubbleTransposeVFromAttentionOp>(patterns.getContext(),
+                                                controlFusionFn);
+}
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp
index 41db09f07a16..3c1a783ecba3 100644
--- a/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp
@@ -104,7 +104,6 @@ struct GatherFusionPattern final : public OpRewritePattern {

 void ElementwiseOpFusionPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  RewritePatternSet fusionPatterns(context);
   // Only fuse operations where all uses of the producer are generic
   // operations. If an operation is used in a named op, it will be computed
   // anyway, so the consumers can just use that value.
@@ -135,24 +134,35 @@ void ElementwiseOpFusionPass::runOnOperation() {
     return areFusableAsElementwiseOps(context, fusedOperand,
                                       fuseMultiReduction);
   };
-  linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
+
+  RewritePatternSet linalgFusionPatterns(context);
+  linalg::populateElementwiseOpsFusionPatterns(linalgFusionPatterns,
                                                fuseElementwiseOpsControlFn);
+  GreedyRewriteConfig rewriteConfig;
+  rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
+  if (failed(applyPatternsAndFoldGreedily(
+          getOperation(), std::move(linalgFusionPatterns), rewriteConfig))) {
+    getOperation()->emitOpError(
+        "Failed to fuse elementwise ops with upstream patterns.");
+    return signalPassFailure();
+  }
+
+  // Try to fuse with LinalgExt patterns.
   linalg::ControlFusionFn foldTransposeControlFn = [](OpOperand *fusedOperand) {
     Operation *producer = fusedOperand->get().getDefiningOp();
     Operation *consumer = fusedOperand->getOwner();
     return IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer});
   };
+  RewritePatternSet linalgExtFusionPatterns(context);
   IREE::LinalgExt::populateFuseLinalgExtOpsWithTransposes(
-      fusionPatterns, foldTransposeControlFn);
-  fusionPatterns.insert<GatherFusionPattern>(context);
-
-  GreedyRewriteConfig rewriteConfig;
-  rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
+      linalgExtFusionPatterns, foldTransposeControlFn);
+  linalgExtFusionPatterns.insert<GatherFusionPattern>(context);
   if (failed(applyPatternsAndFoldGreedily(
-          getOperation(), std::move(fusionPatterns), rewriteConfig))) {
-    getOperation()->emitOpError("Failed to perform elementwise operations");
+          getOperation(), std::move(linalgExtFusionPatterns), rewriteConfig))) {
+    getOperation()->emitOpError(
+        "Failed to fuse elementwise ops with linalgExt patterns.");
     return signalPassFailure();
   }
 }
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
index 8b556a03835d..096c882ab219 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
@@ -207,3 +207,54 @@ util.func public @fuse_generic_gather2(
 // CHECK-NEXT: %[[RES3:[a-zA-Z0-9]+]] = arith.mulf %[[RES]], %[[RES]] : f32
 // CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32
 // CHECK-NEXT: linalg.yield %[[RES4]] : f32
+
+util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf16>, %k: tensor<2x10x4096x64xf16>, %quantized_v: tensor<2x10x4096x64xi32>, %quant_offset: tensor<10x64xi32>, %quant_scale: tensor<10x64xf32>, %scale: f16) -> tensor<2x10x4096x64xf16> {
+  // Dequantize int-quantization of V
+  %init_dequant = tensor.empty() : tensor<2x10x4096x64xf16>
+  %v = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%quantized_v, %quant_offset, %quant_scale : tensor<2x10x4096x64xi32>, tensor<10x64xi32>, tensor<10x64xf32>) outs(%init_dequant : tensor<2x10x4096x64xf16>) {
+  ^bb0(%in: i32, %in_0: i32, %in_1: f32, %out: f16):
+    %19 = arith.addi %in, %in_0 : i32
+    %20 = arith.sitofp %19 : i32 to f32
+    %21 = arith.mulf %20, %in_1 : f32
+    %22 = arith.truncf %21 : f32 to f16
+    linalg.yield %22 : f16
+  } -> tensor<2x10x4096x64xf16>
+
+  // Transpose-V
+  %init_transpose = tensor.empty() : tensor<2x10x64x4096xf16>
+  %transpose_v = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%v : tensor<2x10x4096x64xf16>) outs(%init_transpose : tensor<2x10x64x4096xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<2x10x64x4096xf16>
+
+  // Attention-Transpose-V
+  %init_attention = tensor.empty() : tensor<2x10x4096x64xf16>
+  %attention = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%q, %k, %transpose_v, %scale : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>, f16) outs(%init_attention : tensor<2x10x4096x64xf16>) {
+  ^bb0(%score: f16):
+    iree_linalg_ext.yield %score : f16
+  } -> tensor<2x10x4096x64xf16>
+  util.return %attention : tensor<2x10x4096x64xf16>
+}
+
+// CHECK-LABEL: util.func public @fuse_transpose_attention_to_producer
+// CHECK-SAME:  %[[ARG0:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME:  %[[ARG1:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME:  %[[ARG2:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME:  %[[ARG3:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME:  %[[ARG4:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME:  %[[ARG5:[A-Za-z0-9]+]]: f16
+// CHECK:       %[[DEQUANT_V:.+]] = linalg.generic
+// CHECK-SAME:    indexing_maps =
+// CHECK-SAME:      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-SAME:      affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-SAME:      affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-SAME:      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>]
+// CHECK-SAME:    ins(%[[ARG2]], %[[ARG3]], %[[ARG4]]
+// CHECK:       %[[RESULT:.+]] = iree_linalg_ext.attention
+// CHECK-SAME:    indexing_maps =
+// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
+// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+// CHECK-SAME:    ins(%[[ARG0]], %[[ARG1]], %[[DEQUANT_V]], %[[ARG5]]
diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
index 846233841732..265ddbbc5890 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
@@ -13,6 +13,7 @@
 #include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.h"
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "iree/compiler/GlobalOptimization/Passes.h"
 #include "llvm/Support/Debug.h"
@@ -1087,6 +1088,15 @@ void PropagateLinalgTransposePass::runOnOperation() {
     linalg::populateFoldReshapeOpsByExpansionPatterns(bubblingPatterns,
                                                       reshapePropagationFn);
     linalg::FillOp::getCanonicalizationPatterns(bubblingPatterns, context);
+    linalg::ControlFusionFn bubbleTransposeControlFn =
+        [](OpOperand *fusedOperand) {
+          Operation *producer = fusedOperand->get().getDefiningOp();
+          Operation *consumer = fusedOperand->getOwner();
+
+          return IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer});
+        };
+    IREE::LinalgExt::populateBubbleTransposeFromLinalgExtOps(
+        bubblingPatterns, bubbleTransposeControlFn);
     bubblingPatterns.insert(
         context, enableAggressivePropagation);
     bubblingPatterns.insert(context);
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
index 6b9571666808..16f37473eb47 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
@@ -665,3 +665,47 @@ util.func public @bubble_transpose_to_broadcast_elementwise(%arg0: tensor<2x3x4x
 // BUBBLE-SAME:  ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x4xf32>
 // BUBBLE:       arith.addf
 // BUBBLE:       util.return %[[ELEM]] : tensor<3x4x2xf32>
+
+// -----
+
+util.func public @bubble_transpose_v_from_attention(%q: tensor<2x10x4096x64xf16>, %k: tensor<2x10x4096x64xf16>, %quantized_v: tensor<2x10x4096x64xi32>, %quant_offset: tensor<10x64xi32>, %quant_scale: tensor<10x64xf32>, %scale: f16) -> tensor<2x10x4096x64xf16> {
+  // Dequantize int-quantization of V
+  %init_dequant = tensor.empty() : tensor<2x10x4096x64xf16>
+  %v = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%quantized_v, %quant_offset, %quant_scale : tensor<2x10x4096x64xi32>, tensor<10x64xi32>, tensor<10x64xf32>) outs(%init_dequant : tensor<2x10x4096x64xf16>) {
+  ^bb0(%in: i32, %in_0: i32, %in_1: f32, %out: f16):
+    %19 = arith.addi %in, %in_0 : i32
+    %20 = arith.sitofp %19 : i32 to f32
+    %21 = arith.mulf %20, %in_1 : f32
+    %22 = arith.truncf %21 : f32 to f16
+    linalg.yield %22 : f16
+  } -> tensor<2x10x4096x64xf16>
+
+  // Attention with a standard (non-transposed) V map
+  %init_attention = tensor.empty() : tensor<2x10x4096x64xf16>
+  %attention = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%q, %k, %v, %scale : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, f16) outs(%init_attention : tensor<2x10x4096x64xf16>) {
+  ^bb0(%score: f16):
+    iree_linalg_ext.yield %score : f16
+  } -> tensor<2x10x4096x64xf16>
+  util.return %attention : tensor<2x10x4096x64xf16>
+}
+
+
+// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
+// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+
+// CHECK-LABEL: util.func public @bubble_transpose_v_from_attention(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x10x4096x64xf16>, %[[ARG1:.*]]: tensor<2x10x4096x64xf16>, %[[ARG2:.*]]: tensor<2x10x4096x64xi32>,
+// CHECK-SAME:  %[[ARG3:.*]]: tensor<10x64xi32>, %[[ARG4:.*]]: tensor<10x64xf32>, %[[ARG5:.*]]: f16) -> tensor<2x10x4096x64xf16> {
+// CHECK:       %[[EMPTY:.*]] = tensor.empty() : tensor<2x10x4096x64xf16>
+// CHECK:       %[[DEQUANT_V:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[ARG2]], %[[ARG3]], %[[ARG4]] : tensor<2x10x4096x64xi32>, tensor<10x64xi32>, tensor<10x64xf32>)
+// CHECK-SAME:    outs(%{{.*}} : tensor<2x10x4096x64xf16>)
+// CHECK:       %[[TRANS_V:.*]] = linalg.transpose ins(%[[DEQUANT_V]] : tensor<2x10x4096x64xf16>) outs({{.*}} : tensor<2x10x64x4096xf16>) permutation = [0, 1, 3, 2]
+// CHECK:       %[[ATTN:.*]] = iree_linalg_ext.attention
+// CHECK-SAME:    {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
+// CHECK-SAME:    ins(%[[ARG0]], %[[ARG1]], %[[TRANS_V]], %[[ARG5]] : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>, f16)
+// CHECK-SAME:    outs(%[[EMPTY]] : tensor<2x10x4096x64xf16>)
+// CHECK:       util.return %[[ATTN]] : tensor<2x10x4096x64xf16>