Skip to content

Commit

Permalink
[LLVMCPU] Add option onlyFuseProducerInputOperands to tileRootFuseC…
Browse files Browse the repository at this point in the history
…onsumerProducer Pass (iree-org#18114)

Previously, we only tilled the reduction tile sizes and did not fuse
them with the producers from the input operands. It led to transfer
read/write with large vector sizes since the dequant operation
materialised its own tensor and wasn't fused inside the reduction loop.
Adds a `onlyFuseProducerInputOperands` option to the
tile-root-and-fuse-consumer-producer-pass.
If the option is set to true, it tiles the reduction dimension and fuses
the operations arising from the input operand of the already tiled
operation. Issue link: iree-org#18005
  • Loading branch information
pashu123 authored Aug 13, 2024
1 parent 6ac6be6 commit ad2f0f8
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,48 +32,118 @@ namespace mlir::iree_compiler {

namespace {

/// Implementation of tile root and fuse producers and consumers greedily.
static LogicalResult tileRootAndFuseProducerConsumerUsingSCF(
RewriterBase &rewriter, TilingInterface root,
const scf::SCFTileAndFuseOptions &options) {

// This transformation is only valid for ops that return values (i.e. not
// valid to use with operations that have memref operands).
if (!root->getNumResults()) {
return rewriter.notifyMatchFailure(
root, "invalid pattern for op with no results");
/// Starting from `op` walk all operands backwards to find all
/// potentially fusable operations, i.e. operations that implement
/// the `TilingInterface`.
static void collectTiledAndFusedOps(Operation *rootOp,
llvm::SmallDenseSet<Operation *> &result) {
SmallVector<Operation *> worklist;
worklist.push_back(rootOp);
result.insert(rootOp);
while (!worklist.empty()) {
Operation *current = worklist.pop_back_val();
for (OpOperand &operand : current->getOpOperands()) {
Operation *producer = operand.get().getDefiningOp();
if (!producer || !isa<TilingInterface>(producer) ||
result.count(producer))
continue;
worklist.push_back(producer);
result.insert(producer);
}
}
}

// 1. Tile root op and Fuse Producers.
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
scf::tileConsumerAndFuseProducersUsingSCF(rewriter, root, options);
/// Tile the root operation and fuse the producers of the root operation.
/// If `onlyFuseProducerInputOperands` is set, only fuse producer input
/// operands. Returns the tiled operation to be used for fusing consumers.
FailureOr<Operation *>
tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp,
int64_t tilingLevel,
bool onlyFuseProducerInputOperands) {
mlir::DominanceInfo dominanceInfo(rootOp);
llvm::SmallDenseSet<Operation *> tiledAndFusedOps;
collectTiledAndFusedOps(rootOp, tiledAndFusedOps);

llvm::DenseSet<Operation *> yieldReplacementsFor;
for (auto op : tiledAndFusedOps) {
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
return dominanceInfo.properlyDominates(rootOp, user);
})) {
yieldReplacementsFor.insert(op);
}
}

SmallVector<OpFoldResult> tileSizes =
getLoweringConfig(rootOp).getTilingLevelSizes(rewriter, tilingLevel,
rootOp);

// Pad the tile sizes with zero.
auto zero = rewriter.getIndexAttr(0);
int64_t numLoops = rootOp.getLoopIteratorTypes().size();
if (tileSizes.size() > numLoops) {
LLVM_DEBUG(llvm::dbgs()
<< "tile sizes size " << tileSizes.size()
<< " exceeds the number of loops " << numLoops << "\n");
return failure();
}
tileSizes.resize(numLoops, zero);

scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes);

scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.setTilingOptions(tilingOptions);

scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
[&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
bool isDestinationOperand) {
Operation *owner = originalProducer.getOwner();
bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
// Do not fuse destination operands if onlyFuseProducerInputOperands is
// true.
bool shouldFuse =
!(onlyFuseProducerInputOperands && isDestinationOperand);
return std::make_tuple(shouldFuse, yieldProducerReplacement);
};
tileAndFuseOptions.setFusionControlFn(controlFn);

FailureOr<scf::SCFTileAndFuseResult> tiledResults =
scf::tileConsumerAndFuseProducersUsingSCF(rewriter, rootOp,
tileAndFuseOptions);
if (failed(tiledResults)) {
return rewriter.notifyMatchFailure(
root, "failed to tile root and fuse producers");
return failure();
}

// 2. Replace the producers with the tiled verison.
SmallVector<Operation *> opsToReplace = {root};
// Perform the replacement of tiled and fused values.
SmallVector<Operation *> opsToReplace{rootOp};
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
for (Operation *toReplace : opsToReplace) {
for (OpResult res : toReplace->getResults())
if (auto replacement = tiledResults->replacements.lookup(res)) {
rewriter.replaceAllUsesWith(res, replacement);
Operation *replacementOp = replacement.getDefiningOp();
rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
Operation *user = use.getOwner();
return dominanceInfo.properlyDominates(replacementOp, user);
});
}

if (toReplace->use_empty()) {
rewriter.eraseOp(toReplace);
}
}

// 3. Typically, the consumers of the tiled operation are slices of the
// results of the tiled operation. These are expressed in IR using
// `tensor.insert_slice` operations, whose outputs are the operands of the
// untiled operation. Create a worklist of these `tensor.insert_siices`
// operations. If the consumers of the source of the `tensor.insert_slices`
// can be tiled such that the tiled value is generated in-place, that
// effectively tiles + fuses the operations.
return tiledResults->tiledAndFusedOps.front();
}

static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {

// Typically, the consumers of the tiled operation are slices of the
// results of the tiled operation. These are expressed in IR using
// `tensor.insert_slice` operations, whose outputs are the operands of the
// untiled operation. Create a worklist of these `tensor.insert_siices`
// operations. If the consumers of the source of the `tensor.insert_slices`
// can be tiled such that the tiled value is generated in-place, that
// effectively tiles + fuses the operations.
auto addCandidateSlices = [](Operation *fusedOp,
std::queue<tensor::InsertSliceOp> &candidates) {
for (auto *userOp : fusedOp->getResults().getUsers()) {
Expand All @@ -86,7 +156,7 @@ static LogicalResult tileRootAndFuseProducerConsumerUsingSCF(
// Collect the candidate slices which can be potential consumers that can be
// fused.
std::queue<tensor::InsertSliceOp> candidates;
addCandidateSlices(tiledResults->tiledAndFusedOps.front(), candidates);
addCandidateSlices(tiledOp, candidates);

while (!candidates.empty()) {

Expand All @@ -112,42 +182,44 @@ static LogicalResult tileRootAndFuseProducerConsumerUsingSCF(
addCandidateSlices(fusedResult->tiledAndFusedConsumerOperand->getOwner(),
candidates);
}
return success();
}

static LogicalResult tileRootAndFuseProducerConsumer(IRRewriter &rewriter,
TilingInterface rootOp,
int64_t tilingLevel) {
/// Implementation of tile root and fuse producers and consumers greedily.
/// If `onlyFuseProducerInputOperands` is set, only fuse producer input operands
/// and disable consumer fusion.
static LogicalResult tileRootAndFuse(IRRewriter &rewriter,
TilingInterface rootOp,
int64_t tilingLevel,
bool onlyFuseProducerInputOperands) {

SmallVector<OpFoldResult> tileSizes =
getLoweringConfig(rootOp).getTilingLevelSizes(rewriter, tilingLevel,
rootOp);
int64_t numLoops = rootOp.getLoopIteratorTypes().size();
if (tileSizes.size() > numLoops)
return failure();
FailureOr<Operation *> tiledOp = tileRootAndFuseProducers(
rewriter, rootOp, tilingLevel, onlyFuseProducerInputOperands);

scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes);
if (failed(tiledOp))
return failure();

scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.setTilingOptions(tilingOptions);
if (!onlyFuseProducerInputOperands)
fuseConsumers(rewriter, tiledOp.value());

return tileRootAndFuseProducerConsumerUsingSCF(rewriter, rootOp,
tileAndFuseOptions);
return success();
}

/// This pass starts with the first TilingInterface operation that has
/// lowering_config attribute, tiles the op and fuses its consumers and
/// producers recursively. The `tilingLevel` must be specified. It picks the
/// `tilingLevel`-th list as tiling sizes from lowering_config.
/// producers recursively. If the `onlyFuseProducerInputOperands` is set, it
/// only fuses producer input operands and disables consumer fusion. The
/// `tilingLevel` must be specified. It picks the `tilingLevel`-th list as
/// tiling sizes from lowering_config.
struct LLVMCPUTileRootAndFuseProducerConsumer
: impl::LLVMCPUTileRootAndFuseProducerConsumerPassBase<
LLVMCPUTileRootAndFuseProducerConsumer> {
using impl::LLVMCPUTileRootAndFuseProducerConsumerPassBase<
LLVMCPUTileRootAndFuseProducerConsumer>::
LLVMCPUTileRootAndFuseProducerConsumerPassBase;
explicit LLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) {
explicit LLVMCPUTileRootAndFuseProducerConsumer(
int64_t tilingLevel, bool onlyFuseProducerInputOperands) {
this->tilingLevel = tilingLevel;
this->onlyFuseProducerInputOperands = onlyFuseProducerInputOperands;
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, affine::AffineDialect,
Expand Down Expand Up @@ -186,9 +258,9 @@ void LLVMCPUTileRootAndFuseProducerConsumer::runOnOperation() {
return signalPassFailure();
}

if (failed(tileRootAndFuseProducerConsumer(
if (failed(tileRootAndFuse(
rewriter, dyn_cast<TilingInterface>(rootOp.value()),
tilingLevel.getValue()))) {
tilingLevel.getValue(), onlyFuseProducerInputOperands.getValue()))) {
funcOp.emitError() << "tiling of level " << tilingLevel.getValue()
<< " failed\n";
return signalPassFailure();
Expand All @@ -212,6 +284,12 @@ void LLVMCPUTileRootAndFuseProducerConsumer::runOnOperation() {

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) {
return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(tilingLevel);
return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(
tilingLevel, /*onlyFuseProducerInputOperands=*/false);
}
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel) {
return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(
tilingLevel, /*onlyFuseProducerInputOperands=*/true);
}
} // namespace mlir::iree_compiler
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ void addConvTileAndDecomposeExpertPassPipeline(
funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
funcPassManager.addPass(createConcretizePadResultShapePass());

funcPassManager.addPass(
createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel()));
funcPassManager.addPass(createLLVMCPUTileRootAndFuseInputOperands(
tilingConfig.getVectorReductionLevel()));
funcPassManager.addPass(
createLLVMCPUTileAndFusePass(tilingConfig.getVectorInnerParallelLevel()));
funcPassManager.addPass(createDecomposeConvolutionToLowerDimOpsPass());
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ createLLVMCPUTileAndFusePass(int64_t tilingLevel);
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel);

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel);

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUVerifyVectorSizeLegalityPass(
int64_t maxAllowedNumberOfNativeVectors);
Expand Down
21 changes: 14 additions & 7 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,20 @@ def LLVMCPUTileAndFusePass :
];
}

def LLVMCPUTileRootAndFuseProducerConsumerPass :
InterfacePass<"iree-llvmcpu-tile-root-and-fuse-producer-consumer", "mlir::FunctionOpInterface"> {
let summary = "Pass to tile root op and fuse with producer and consumer TilingInterface ops.";
let options = [
Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1",
"Use default tiling level used to retrieve the configuration from lowering_config">
];
def LLVMCPUTileRootAndFuseProducerConsumerPass
: InterfacePass<"iree-llvmcpu-tile-root-and-fuse-producer-consumer",
"mlir::FunctionOpInterface"> {
let summary = "Pass to tile root op and fuse with producer and consumer "
"TilingInterface ops.";
let options =
[Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1",
"Use default tiling level used to retrieve the configuration "
"from lowering_config">,
Option<"onlyFuseProducerInputOperands",
"only-fuse-producer-input-operands", "bool",
/*default=*/"false",
"Specifies if we only want to fuse producer's input operands. "
"This is helpful to tile&fuse in case of reduction dimensions.">];
}

def LLVMCPUVerifyVectorSizeLegalityPass :
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=0}), canonicalize)" --split-input-file %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=2 only-fuse-producer-input-operands=true}), canonicalize)" --split-input-file %s | FileCheck %s --check-prefix=CHECK-REDUCTION


#config1 = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
Expand Down Expand Up @@ -30,7 +32,8 @@ func.func @mmt4d_bias_relu(%arg0: tensor<?x?x16x1xf32>, %arg1: tensor<?x?x16x1xf
// CHECK: }

// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]>

#config2 = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
func.func @quantized_matmul(%arg0: tensor<2x4x128x16x1xi8>, %arg1: tensor<2x4x16xf32>, %arg2: tensor<2x4x16xf32>, %arg3: tensor<2x688x128x16x1xi8>, %arg4: tensor<2x688x16xf32>, %arg5: tensor<2x688x16xf32>) -> tensor<2x11008x64xf32> {
Expand Down Expand Up @@ -61,7 +64,7 @@ func.func @quantized_matmul(%arg0: tensor<2x4x128x16x1xi8>, %arg1: tensor<2x4x16
} -> tensor<2x688x128x16x1xf32>
%4 = tensor.empty() : tensor<2x4x688x16x16xf32>
%5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32>
%6 = linalg.batch_mmt4d {lowering_config = #config} ins(%1, %3 : tensor<2x4x128x16x1xf32>, tensor<2x688x128x16x1xf32>) outs(%5 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32>
%6 = linalg.batch_mmt4d {lowering_config = #config2} ins(%1, %3 : tensor<2x4x128x16x1xf32>, tensor<2x688x128x16x1xf32>) outs(%5 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32>
%7 = tensor.empty() : tensor<2x11008x64xf32>
%unpack = tensor.unpack %6 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 16] into %7 : tensor<2x4x688x16x16xf32> -> tensor<2x11008x64xf32>
return %unpack : tensor<2x11008x64xf32>
Expand All @@ -75,3 +78,42 @@ func.func @quantized_matmul(%arg0: tensor<2x4x128x16x1xi8>, %arg1: tensor<2x4x16
// CHECK: linalg.batch_mmt4d
// CHECK: tensor.unpack
// CHECK: }


// -----

#config3 = #iree_codegen.lowering_config<tile_sizes = [[0, 32, 0, 0, 0, 0], [1, 16, 1, 1, 0, 0], [0, 0, 0, 0, 1, 5], [0, 0, 0, 0, 0, 0]]>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @dequant_avgpool(%arg0: tensor<1x320x65x65xi8>) -> tensor<1x320x1x1xf32> {
%cst = arith.constant 1.250000e-01 : f32
%cst_0 = arith.constant 0.000000e+00 : f32
%c5408000 = arith.constant 5408000 : index
%c0 = arith.constant 0 : index
%0 = tensor.empty() : tensor<1x320x1x1xf32>
%1 = tensor.empty() : tensor<65x65xf32>
%2 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<65x65xf32>) -> tensor<65x65xf32>
%3 = tensor.empty() : tensor<1x320x65x65xf32>
%4 = tensor.empty() : tensor<1x320x1x1xf32>
%5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x320x65x65xi8>) outs(%3 : tensor<1x320x65x65xf32>) {
^bb0(%in: i8, %out: f32):
%7 = arith.extsi %in : i8 to i32
%8 = arith.sitofp %7 : i32 to f32
%9 = arith.mulf %8, %cst : f32
linalg.yield %9 : f32
} -> tensor<1x320x65x65xf32>
%6 = linalg.pooling_nchw_sum {lowering_config = #config3} ins(%5, %2 : tensor<1x320x65x65xf32>, tensor<65x65xf32>) outs(%4 : tensor<1x320x1x1xf32>) -> tensor<1x320x1x1xf32>
return %6 : tensor<1x320x1x1xf32>
}

// CHECK-REDUCTION-LABEL: func.func @dequant_avgpool(
// CHECK-REDUCTION-SAME: {
// CHECK-REDUCTION: scf.for
// CHECK-REDUCTION-SAME: {
// CHECK-REDUCTION: scf.for
// CHECK-REDUCTION-SAME: {
// CHECK-REDUCTION: linalg.generic
// CHECK-REDUCTION: %[[POOL:.+]] = linalg.pooling_nchw_sum
// CHECK-REDUCTION: scf.yield %[[POOL]]
// CHECK-REDUCTION: }
// CHECK-REDUCTION: }
// CHECK-REDUCTION: }

0 comments on commit ad2f0f8

Please sign in to comment.