diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp index 3f6c727b0a0b..78b3b6ff457a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp @@ -32,34 +32,99 @@ namespace mlir::iree_compiler { namespace { -/// Implementation of tile root and fuse producers and consumers greedily. -static LogicalResult tileRootAndFuseProducerConsumerUsingSCF( - RewriterBase &rewriter, TilingInterface root, - const scf::SCFTileAndFuseOptions &options) { - - // This transformation is only valid for ops that return values (i.e. not - // valid to use with operations that have memref operands). - if (!root->getNumResults()) { - return rewriter.notifyMatchFailure( - root, "invalid pattern for op with no results"); +/// Starting from `op` walk all operands backwards to find all +/// potentially fusable operations, i.e. operations that implement +/// the `TilingInterface`. +static void collectTiledAndFusedOps(Operation *rootOp, + llvm::SmallDenseSet &result) { + SmallVector worklist; + worklist.push_back(rootOp); + result.insert(rootOp); + while (!worklist.empty()) { + Operation *current = worklist.pop_back_val(); + for (OpOperand &operand : current->getOpOperands()) { + Operation *producer = operand.get().getDefiningOp(); + if (!producer || !isa(producer) || + result.count(producer)) + continue; + worklist.push_back(producer); + result.insert(producer); + } } +} - // 1. Tile root op and Fuse Producers. - FailureOr tiledResults = - scf::tileConsumerAndFuseProducersUsingSCF(rewriter, root, options); +/// Tile the root operation and fuse the producers of the root operation. +/// If `onlyFuseProducerInputOperands` is set, only fuse producer input +/// operands. Returns the tiled operation to be used for fusing consumers. +FailureOr +tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp, + int64_t tilingLevel, + bool onlyFuseProducerInputOperands) { + mlir::DominanceInfo dominanceInfo(rootOp); + llvm::SmallDenseSet tiledAndFusedOps; + collectTiledAndFusedOps(rootOp, tiledAndFusedOps); + + llvm::DenseSet yieldReplacementsFor; + for (auto op : tiledAndFusedOps) { + if (llvm::any_of(op->getUsers(), [&](Operation *user) { + return dominanceInfo.properlyDominates(rootOp, user); + })) { + yieldReplacementsFor.insert(op); + } + } + + SmallVector tileSizes = + getLoweringConfig(rootOp).getTilingLevelSizes(rewriter, tilingLevel, + rootOp); + + // Pad the tile sizes with zero. + auto zero = rewriter.getIndexAttr(0); + int64_t numLoops = rootOp.getLoopIteratorTypes().size(); + if (tileSizes.size() > numLoops) { + LLVM_DEBUG(llvm::dbgs() + << "tile sizes size " << tileSizes.size() + << " exceeds the number of loops " << numLoops << "\n"); + return failure(); + } + tileSizes.resize(numLoops, zero); + + scf::SCFTilingOptions tilingOptions; + tilingOptions.setTileSizes(tileSizes); + + scf::SCFTileAndFuseOptions tileAndFuseOptions; + tileAndFuseOptions.setTilingOptions(tilingOptions); + + scf::SCFTileAndFuseOptions::ControlFnTy controlFn = + [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, + bool isDestinationOperand) { + Operation *owner = originalProducer.getOwner(); + bool yieldProducerReplacement = yieldReplacementsFor.contains(owner); + // Do not fuse destination operands if onlyFuseProducerInputOperands is + // true. + bool shouldFuse = + !(onlyFuseProducerInputOperands && isDestinationOperand); + return std::make_tuple(shouldFuse, yieldProducerReplacement); + }; + tileAndFuseOptions.setFusionControlFn(controlFn); + FailureOr tiledResults = + scf::tileConsumerAndFuseProducersUsingSCF(rewriter, rootOp, + tileAndFuseOptions); if (failed(tiledResults)) { - return rewriter.notifyMatchFailure( - root, "failed to tile root and fuse producers"); + return failure(); } - // 2. Replace the producers with the tiled verison. - SmallVector opsToReplace = {root}; + // Perform the replacement of tiled and fused values. + SmallVector opsToReplace{rootOp}; llvm::append_range(opsToReplace, tiledResults->fusedProducers); for (Operation *toReplace : opsToReplace) { for (OpResult res : toReplace->getResults()) if (auto replacement = tiledResults->replacements.lookup(res)) { - rewriter.replaceAllUsesWith(res, replacement); + Operation *replacementOp = replacement.getDefiningOp(); + rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) { + Operation *user = use.getOwner(); + return dominanceInfo.properlyDominates(replacementOp, user); + }); } if (toReplace->use_empty()) { @@ -67,13 +132,18 @@ static LogicalResult tileRootAndFuseProducerConsumerUsingSCF( } } - // 3. Typically, the consumers of the tiled operation are slices of the - // results of the tiled operation. These are expressed in IR using - // `tensor.insert_slice` operations, whose outputs are the operands of the - // untiled operation. Create a worklist of these `tensor.insert_siices` - // operations. If the consumers of the source of the `tensor.insert_slices` - // can be tiled such that the tiled value is generated in-place, that - // effectively tiles + fuses the operations. + return tiledResults->tiledAndFusedOps.front(); +} + +static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) { + + // Typically, the consumers of the tiled operation are slices of the + // results of the tiled operation. These are expressed in IR using + // `tensor.insert_slice` operations, whose outputs are the operands of the + // untiled operation. Create a worklist of these `tensor.insert_siices` + // operations. If the consumers of the source of the `tensor.insert_slices` + // can be tiled such that the tiled value is generated in-place, that + // effectively tiles + fuses the operations. auto addCandidateSlices = [](Operation *fusedOp, std::queue &candidates) { for (auto *userOp : fusedOp->getResults().getUsers()) { @@ -86,7 +156,7 @@ static LogicalResult tileRootAndFuseProducerConsumerUsingSCF( // Collect the candidate slices which can be potential consumers that can be // fused. std::queue candidates; - addCandidateSlices(tiledResults->tiledAndFusedOps.front(), candidates); + addCandidateSlices(tiledOp, candidates); while (!candidates.empty()) { @@ -112,42 +182,44 @@ static LogicalResult tileRootAndFuseProducerConsumerUsingSCF( addCandidateSlices(fusedResult->tiledAndFusedConsumerOperand->getOwner(), candidates); } - return success(); } -static LogicalResult tileRootAndFuseProducerConsumer(IRRewriter &rewriter, - TilingInterface rootOp, - int64_t tilingLevel) { +/// Implementation of tile root and fuse producers and consumers greedily. +/// If `onlyFuseProducerInputOperands` is set, only fuse producer input operands +/// and disable consumer fusion. +static LogicalResult tileRootAndFuse(IRRewriter &rewriter, + TilingInterface rootOp, + int64_t tilingLevel, + bool onlyFuseProducerInputOperands) { - SmallVector tileSizes = - getLoweringConfig(rootOp).getTilingLevelSizes(rewriter, tilingLevel, - rootOp); - int64_t numLoops = rootOp.getLoopIteratorTypes().size(); - if (tileSizes.size() > numLoops) - return failure(); + FailureOr tiledOp = tileRootAndFuseProducers( + rewriter, rootOp, tilingLevel, onlyFuseProducerInputOperands); - scf::SCFTilingOptions tilingOptions; - tilingOptions.setTileSizes(tileSizes); + if (failed(tiledOp)) + return failure(); - scf::SCFTileAndFuseOptions tileAndFuseOptions; - tileAndFuseOptions.setTilingOptions(tilingOptions); + if (!onlyFuseProducerInputOperands) + fuseConsumers(rewriter, tiledOp.value()); - return tileRootAndFuseProducerConsumerUsingSCF(rewriter, rootOp, - tileAndFuseOptions); + return success(); } /// This pass starts with the first TilingInterface operation that has /// lowering_config attribute, tiles the op and fuses its consumers and -/// producers recursively. The `tilingLevel` must be specified. It picks the -/// `tilingLevel`-th list as tiling sizes from lowering_config. +/// producers recursively. If the `onlyFuseProducerInputOperands` is set, it +/// only fuses producer input operands and disables consumer fusion. The +/// `tilingLevel` must be specified. It picks the `tilingLevel`-th list as +/// tiling sizes from lowering_config. struct LLVMCPUTileRootAndFuseProducerConsumer : impl::LLVMCPUTileRootAndFuseProducerConsumerPassBase< LLVMCPUTileRootAndFuseProducerConsumer> { using impl::LLVMCPUTileRootAndFuseProducerConsumerPassBase< LLVMCPUTileRootAndFuseProducerConsumer>:: LLVMCPUTileRootAndFuseProducerConsumerPassBase; - explicit LLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) { + explicit LLVMCPUTileRootAndFuseProducerConsumer( + int64_t tilingLevel, bool onlyFuseProducerInputOperands) { this->tilingLevel = tilingLevel; + this->onlyFuseProducerInputOperands = onlyFuseProducerInputOperands; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(rootOp.value()), - tilingLevel.getValue()))) { + tilingLevel.getValue(), onlyFuseProducerInputOperands.getValue()))) { funcOp.emitError() << "tiling of level " << tilingLevel.getValue() << " failed\n"; return signalPassFailure(); @@ -212,6 +284,12 @@ void LLVMCPUTileRootAndFuseProducerConsumer::runOnOperation() { std::unique_ptr> createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) { - return std::make_unique(tilingLevel); + return std::make_unique( + tilingLevel, /*onlyFuseProducerInputOperands=*/false); +} +std::unique_ptr> +createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel) { + return std::make_unique( + tilingLevel, /*onlyFuseProducerInputOperands=*/true); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index aeb9d7fe9c49..70cbe0d237cb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -451,8 +451,8 @@ void addConvTileAndDecomposeExpertPassPipeline( funcPassManager.addPass(createFuseTensorPadWithConsumerPass()); funcPassManager.addPass(createConcretizePadResultShapePass()); - funcPassManager.addPass( - createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel())); + funcPassManager.addPass(createLLVMCPUTileRootAndFuseInputOperands( + tilingConfig.getVectorReductionLevel())); funcPassManager.addPass( createLLVMCPUTileAndFusePass(tilingConfig.getVectorInnerParallelLevel())); funcPassManager.addPass(createDecomposeConvolutionToLowerDimOpsPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h index bac30de708fa..a8cb91a7608e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h @@ -41,6 +41,9 @@ createLLVMCPUTileAndFusePass(int64_t tilingLevel); std::unique_ptr> createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel); +std::unique_ptr> +createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel); + std::unique_ptr> createLLVMCPUVerifyVectorSizeLegalityPass( int64_t maxAllowedNumberOfNativeVectors); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td index c85666bb1590..81dbdcf9778c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td @@ -140,13 +140,20 @@ def LLVMCPUTileAndFusePass : ]; } -def LLVMCPUTileRootAndFuseProducerConsumerPass : - InterfacePass<"iree-llvmcpu-tile-root-and-fuse-producer-consumer", "mlir::FunctionOpInterface"> { - let summary = "Pass to tile root op and fuse with producer and consumer TilingInterface ops."; - let options = [ - Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1", - "Use default tiling level used to retrieve the configuration from lowering_config"> - ]; +def LLVMCPUTileRootAndFuseProducerConsumerPass + : InterfacePass<"iree-llvmcpu-tile-root-and-fuse-producer-consumer", + "mlir::FunctionOpInterface"> { + let summary = "Pass to tile root op and fuse with producer and consumer " + "TilingInterface ops."; + let options = + [Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1", + "Use default tiling level used to retrieve the configuration " + "from lowering_config">, + Option<"onlyFuseProducerInputOperands", + "only-fuse-producer-input-operands", "bool", + /*default=*/"false", + "Specifies if we only want to fuse producer's input operands. " + "This is helpful to tile&fuse in case of reduction dimensions.">]; } def LLVMCPUVerifyVectorSizeLegalityPass : diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir index 8710da6e4b08..4d8805d513c1 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir @@ -1,4 +1,6 @@ // RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=0}), canonicalize)" --split-input-file %s | FileCheck %s +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=2 only-fuse-producer-input-operands=true}), canonicalize)" --split-input-file %s | FileCheck %s --check-prefix=CHECK-REDUCTION + #config1 = #iree_codegen.lowering_config #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> @@ -30,7 +32,8 @@ func.func @mmt4d_bias_relu(%arg0: tensor, %arg1: tensor + +#config2 = #iree_codegen.lowering_config #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> #map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> func.func @quantized_matmul(%arg0: tensor<2x4x128x16x1xi8>, %arg1: tensor<2x4x16xf32>, %arg2: tensor<2x4x16xf32>, %arg3: tensor<2x688x128x16x1xi8>, %arg4: tensor<2x688x16xf32>, %arg5: tensor<2x688x16xf32>) -> tensor<2x11008x64xf32> { @@ -61,7 +64,7 @@ func.func @quantized_matmul(%arg0: tensor<2x4x128x16x1xi8>, %arg1: tensor<2x4x16 } -> tensor<2x688x128x16x1xf32> %4 = tensor.empty() : tensor<2x4x688x16x16xf32> %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32> - %6 = linalg.batch_mmt4d {lowering_config = #config} ins(%1, %3 : tensor<2x4x128x16x1xf32>, tensor<2x688x128x16x1xf32>) outs(%5 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32> + %6 = linalg.batch_mmt4d {lowering_config = #config2} ins(%1, %3 : tensor<2x4x128x16x1xf32>, tensor<2x688x128x16x1xf32>) outs(%5 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32> %7 = tensor.empty() : tensor<2x11008x64xf32> %unpack = tensor.unpack %6 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 16] into %7 : tensor<2x4x688x16x16xf32> -> tensor<2x11008x64xf32> return %unpack : tensor<2x11008x64xf32> @@ -75,3 +78,42 @@ func.func @quantized_matmul(%arg0: tensor<2x4x128x16x1xi8>, %arg1: tensor<2x4x16 // CHECK: linalg.batch_mmt4d // CHECK: tensor.unpack // CHECK: } + + +// ----- + +#config3 = #iree_codegen.lowering_config +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @dequant_avgpool(%arg0: tensor<1x320x65x65xi8>) -> tensor<1x320x1x1xf32> { + %cst = arith.constant 1.250000e-01 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %c5408000 = arith.constant 5408000 : index + %c0 = arith.constant 0 : index + %0 = tensor.empty() : tensor<1x320x1x1xf32> + %1 = tensor.empty() : tensor<65x65xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<65x65xf32>) -> tensor<65x65xf32> + %3 = tensor.empty() : tensor<1x320x65x65xf32> + %4 = tensor.empty() : tensor<1x320x1x1xf32> + %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x320x65x65xi8>) outs(%3 : tensor<1x320x65x65xf32>) { + ^bb0(%in: i8, %out: f32): + %7 = arith.extsi %in : i8 to i32 + %8 = arith.sitofp %7 : i32 to f32 + %9 = arith.mulf %8, %cst : f32 + linalg.yield %9 : f32 + } -> tensor<1x320x65x65xf32> + %6 = linalg.pooling_nchw_sum {lowering_config = #config3} ins(%5, %2 : tensor<1x320x65x65xf32>, tensor<65x65xf32>) outs(%4 : tensor<1x320x1x1xf32>) -> tensor<1x320x1x1xf32> + return %6 : tensor<1x320x1x1xf32> +} + +// CHECK-REDUCTION-LABEL: func.func @dequant_avgpool( +// CHECK-REDUCTION-SAME: { +// CHECK-REDUCTION: scf.for +// CHECK-REDUCTION-SAME: { +// CHECK-REDUCTION: scf.for +// CHECK-REDUCTION-SAME: { +// CHECK-REDUCTION: linalg.generic +// CHECK-REDUCTION: %[[POOL:.+]] = linalg.pooling_nchw_sum +// CHECK-REDUCTION: scf.yield %[[POOL]] +// CHECK-REDUCTION: } +// CHECK-REDUCTION: } +// CHECK-REDUCTION: }