forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Codegen] Add pass to normalize loop bounds (iree-org#17542)
This is based on the pass added to `iree-amd-aie` here from @jtuyls: nod-ai/iree-amd-aie@45c9465 This restricts the normalization pattern to cases where the induction variables are `index` type, however this is nearly always the case in typical codegen flows.
- Loading branch information
Showing
8 changed files
with
343 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
199 changes: 199 additions & 0 deletions
199
compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/Common/PassDetail.h" | ||
#include "iree/compiler/Codegen/Common/Passes.h" | ||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/Arith/Utils/Utils.h" | ||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/Dialect/Utils/StaticValueUtils.h" | ||
#include "mlir/IR/AffineExpr.h" | ||
#include "mlir/IR/OpDefinition.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
|
||
#define DEBUG_TYPE "iree-codegen-normalize-loop-bounds" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
static OpFoldResult emitNormalizedUpperBound(RewriterBase &rewriter, | ||
Location loc, OpFoldResult lb, | ||
OpFoldResult ub, | ||
OpFoldResult step) { | ||
AffineExpr d0, d1, d2; | ||
bindDims(rewriter.getContext(), d0, d1, d2); | ||
return affine::makeComposedFoldedAffineApply( | ||
rewriter, loc, (d0 - d1).ceilDiv(d2), {ub, lb, step}); | ||
} | ||
|
||
/// Helper structure for storing the newly computed loop bounds. | ||
namespace { | ||
struct LoopRanges { | ||
SmallVector<OpFoldResult> lowerBounds; | ||
SmallVector<OpFoldResult> upperBounds; | ||
SmallVector<OpFoldResult> steps; | ||
}; | ||
} // namespace | ||
|
||
static FailureOr<LoopRanges> | ||
emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, Block *body, | ||
ValueRange ivs, ArrayRef<OpFoldResult> lbs, | ||
ArrayRef<OpFoldResult> ubs, | ||
ArrayRef<OpFoldResult> steps) { | ||
Attribute zero = rewriter.getIndexAttr(0); | ||
Attribute one = rewriter.getIndexAttr(1); | ||
SmallVector<OpFoldResult> newLbs; | ||
SmallVector<OpFoldResult> newUbs; | ||
SmallVector<OpFoldResult> newSteps; | ||
for (auto &&[iv, lb, ub, step] : llvm::zip(ivs, lbs, ubs, steps)) { | ||
std::optional<int64_t> stepInt = getConstantIntValue(step); | ||
// Bail out on negative steps. | ||
if (!stepInt || stepInt.value() <= 0) { | ||
return failure(); | ||
} | ||
|
||
// The lower bound and step of a normalized loop is always zero/one. | ||
newLbs.push_back(zero); | ||
newSteps.push_back(one); | ||
|
||
// Compute the normalized upper bound. | ||
OpFoldResult newUb = emitNormalizedUpperBound(rewriter, loc, lb, ub, step); | ||
newUbs.push_back(newUb); | ||
|
||
// Compute and replace the denormalized loop iterator argument in the loop | ||
// body with an insertion guard. | ||
{ | ||
OpBuilder::InsertionGuard g(rewriter); | ||
rewriter.setInsertionPointToStart(body); | ||
AffineExpr idx, stepExpr, lbExpr; | ||
bindDims(rewriter.getContext(), idx, stepExpr, lbExpr); | ||
affine::AffineApplyOp denormalizedIV = affine::makeComposedAffineApply( | ||
rewriter, loc, idx * stepExpr + lbExpr, {iv, step, lb}); | ||
SmallPtrSet<Operation *, 2> preserve = {iv.getDefiningOp(), | ||
denormalizedIV}; | ||
rewriter.replaceAllUsesExcept(iv, denormalizedIV.getResult(), preserve); | ||
} | ||
} | ||
return LoopRanges{newLbs, newUbs, newSteps}; | ||
} | ||
|
||
/// Transform a `scf.for` loop with a strictly positive step | ||
/// for %i = %lb to %ub step %s | ||
/// into a 0-based loop with step 1 | ||
/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 | ||
/// Insert an `affine.apply` operation to compute the denormalized index value. | ||
static LogicalResult normalizeLoopBounds(RewriterBase &rewriter, | ||
scf::ForOp forOp) { | ||
OpBuilder::InsertionGuard g(rewriter); | ||
// Return if already normalized. | ||
std::optional<int64_t> lbInt = getConstantIntValue(forOp.getLowerBound()); | ||
std::optional<int64_t> stepInt = getConstantIntValue(forOp.getStep()); | ||
if (lbInt && stepInt && lbInt.value() == 0 && stepInt.value() == 1) { | ||
return success(); | ||
} | ||
|
||
// Bail out on non-index types because the affine applies that are generated | ||
// require it. | ||
if (!isa<IndexType>(forOp.getInductionVar().getType())) { | ||
return failure(); | ||
} | ||
|
||
Location loc = forOp.getLoc(); | ||
|
||
rewriter.setInsertionPoint(forOp); | ||
FailureOr<LoopRanges> newLoopParams = emitNormalizedLoopBounds( | ||
rewriter, loc, forOp.getBody(), forOp.getInductionVar(), | ||
getAsOpFoldResult(forOp.getLowerBound()), | ||
getAsOpFoldResult(forOp.getUpperBound()), | ||
getAsOpFoldResult(forOp.getStep())); | ||
if (failed(newLoopParams)) { | ||
return failure(); | ||
} | ||
|
||
assert(newLoopParams->lowerBounds.size() == 1 && | ||
newLoopParams->upperBounds.size() == 1 && | ||
newLoopParams->steps.size() == 1 && | ||
"expected single range for scf.for"); | ||
|
||
rewriter.modifyOpInPlace(forOp, [&]() { | ||
forOp.setLowerBound(getValueOrCreateConstantIndexOp( | ||
rewriter, loc, newLoopParams->lowerBounds.front())); | ||
forOp.setUpperBound(getValueOrCreateConstantIndexOp( | ||
rewriter, loc, newLoopParams->upperBounds.front())); | ||
forOp.setStep(getValueOrCreateConstantIndexOp( | ||
rewriter, loc, newLoopParams->steps.front())); | ||
}); | ||
return success(); | ||
} | ||
|
||
/// Transform a `scf.forall` loop with a strictly positive steps | ||
/// forall (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) | ||
/// into a 0-based loop with step 1 (normalized) | ||
/// forall (%i, %j) in (ceildiv(%ub0 - %lb0, %s0), ceildiv(%ub1 - %lb1, %s1)) | ||
/// Insert `affine.apply` operations to compute the denormalized index values. | ||
static LogicalResult normalizeLoopBounds(RewriterBase &rewriter, | ||
scf::ForallOp forallOp) { | ||
OpBuilder::InsertionGuard g(rewriter); | ||
if (forallOp.isNormalized()) | ||
return success(); | ||
|
||
// `scf.forall` requires that all lbs/ubs/steps/ivs are index type so no need | ||
// to check here. | ||
|
||
rewriter.setInsertionPoint(forallOp); | ||
FailureOr<LoopRanges> newLoopParams = emitNormalizedLoopBounds( | ||
rewriter, forallOp.getLoc(), forallOp.getBody(), | ||
forallOp.getInductionVars(), forallOp.getMixedLowerBound(), | ||
forallOp.getMixedUpperBound(), forallOp.getMixedStep()); | ||
if (failed(newLoopParams)) { | ||
return failure(); | ||
} | ||
|
||
rewriter.setInsertionPointAfter(forallOp); | ||
auto newLoop = rewriter.create<scf::ForallOp>( | ||
rewriter.getUnknownLoc(), newLoopParams->lowerBounds, | ||
newLoopParams->upperBounds, newLoopParams->steps, forallOp.getOutputs(), | ||
forallOp.getMapping()); | ||
rewriter.eraseOp(newLoop.getTerminator()); | ||
rewriter.mergeBlocks(forallOp.getBody(), newLoop.getBody(), | ||
newLoop.getBody()->getArguments()); | ||
rewriter.replaceOp(forallOp, newLoop); | ||
|
||
return success(); | ||
} | ||
|
||
namespace { | ||
struct NormalizeLoopBoundsPass | ||
: public NormalizeLoopBoundsPassBase<NormalizeLoopBoundsPass> { | ||
NormalizeLoopBoundsPass(bool nFor, bool nForall) | ||
: normalizeFor(nFor), normalizeForall(nForall) {} | ||
void runOnOperation() override { | ||
Operation *op = getOperation(); | ||
IRRewriter rewriter(op); | ||
if (normalizeFor) { | ||
op->walk([&](scf::ForOp forOp) { | ||
(void)normalizeLoopBounds(rewriter, forOp); | ||
}); | ||
} | ||
if (normalizeForall) { | ||
op->walk([&](scf::ForallOp forallOp) { | ||
(void)normalizeLoopBounds(rewriter, forallOp); | ||
}); | ||
} | ||
} | ||
|
||
private: | ||
bool normalizeFor; | ||
bool normalizeForall; | ||
}; | ||
} // namespace | ||
|
||
std::unique_ptr<Pass> createNormalizeLoopBoundsPass(bool normalizeFor, | ||
bool normalizeForall) { | ||
return std::make_unique<NormalizeLoopBoundsPass>(normalizeFor, | ||
normalizeForall); | ||
} | ||
|
||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
118 changes: 118 additions & 0 deletions
118
compiler/src/iree/compiler/Codegen/Common/test/normalize_loop_bounds.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
|
||
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-normalize-loop-bounds, cse)" --allow-unregistered-dialect --verify-diagnostics %s | FileCheck %s | ||
module { | ||
func.func @for_normalize_step() { | ||
%c0 = arith.constant 0 : index | ||
%c2 = arith.constant 2 : index | ||
%c8 = arith.constant 8 : index | ||
scf.for %arg0 = %c0 to %c8 step %c2 { | ||
"iree.keep"(%arg0) : (index) -> () | ||
} | ||
return | ||
} | ||
} | ||
|
||
// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 2)> | ||
// CHECK-LABEL: func.func @for_normalize_step | ||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index | ||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index | ||
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index | ||
// CHECK: scf.for %[[ARG:.+]] = %[[C0]] to %[[C4]] step %[[C1]] | ||
// CHECK-NEXT: affine.apply #[[$MAP]](%[[ARG]]) | ||
|
||
// ----- | ||
|
||
module { | ||
func.func @for_normalize_lowerbound() { | ||
%c1 = arith.constant 1 : index | ||
%c2 = arith.constant 2 : index | ||
%c8 = arith.constant 8 : index | ||
scf.for %arg0 = %c2 to %c8 step %c1 { | ||
"iree.keep"(%arg0) : (index) -> () | ||
} | ||
return | ||
} | ||
} | ||
|
||
// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 + 2)> | ||
// CHECK-LABEL: func.func @for_normalize_lowerbound | ||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index | ||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index | ||
// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index | ||
// CHECK: scf.for %[[ARG:.+]] = %[[C0]] to %[[C6]] step %[[C1]] | ||
// CHECK-NEXT: affine.apply #[[$MAP]](%[[ARG]]) | ||
|
||
// ----- | ||
|
||
module { | ||
func.func @for_normalize_lowerbound_and_step() { | ||
%c1 = arith.constant 1 : index | ||
%c4 = arith.constant 4 : index | ||
%c13 = arith.constant 13 : index | ||
scf.for %arg0 = %c1 to %c13 step %c4 { | ||
"iree.keep"(%arg0) : (index) -> () | ||
} | ||
return | ||
} | ||
} | ||
|
||
// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 4 + 1)> | ||
// CHECK-LABEL: func.func @for_normalize_lowerbound_and_step | ||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index | ||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index | ||
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index | ||
// CHECK: scf.for %[[ARG:.+]] = %[[C0]] to %[[C3]] step %[[C1]] | ||
// CHECK-NEXT: affine.apply #[[$MAP]](%[[ARG]]) | ||
|
||
// ----- | ||
|
||
module { | ||
func.func @forall_normalize_step() { | ||
scf.forall (%arg0, %arg1) = (0, 0) to (8, 16) step (8, 8) { | ||
"iree.keep"(%arg0, %arg1) : (index, index) -> () | ||
} | ||
return | ||
} | ||
} | ||
|
||
// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 8)> | ||
// CHECK-LABEL: func.func @forall_normalize_step | ||
// CHECK: scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) in (1, 2) | ||
// CHECK-DAG: affine.apply #[[$MAP]](%[[ARG0]]) | ||
// CHECK-DAG: affine.apply #[[$MAP]](%[[ARG1]]) | ||
|
||
// ----- | ||
|
||
module { | ||
func.func @forall_normalize_lowerbound() { | ||
scf.forall (%arg0, %arg1) = (2, 4) to (8, 16) step (1, 1) { | ||
"iree.keep"(%arg0, %arg1) : (index, index) -> () | ||
} | ||
return | ||
} | ||
} | ||
|
||
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 + 4)> | ||
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 + 2)> | ||
// CHECK-LABEL: func.func @forall_normalize_lowerbound | ||
// CHECK: scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) in (6, 12) | ||
// CHECK-DAG: affine.apply #[[$MAP1]](%[[ARG0]]) | ||
// CHECK-DAG: affine.apply #[[$MAP0]](%[[ARG1]]) | ||
|
||
// ----- | ||
|
||
module { | ||
func.func @forall_normalize_lowerbound_and_step() { | ||
scf.forall (%arg0, %arg1) = (2, 4) to (8, 16) step (2, 4) { | ||
"iree.keep"(%arg0, %arg1) : (index, index) -> () | ||
} | ||
return | ||
} | ||
} | ||
|
||
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4 + 4)> | ||
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 2 + 2)> | ||
// CHECK-LABEL: func.func @forall_normalize_lowerbound | ||
// CHECK: scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) in (3, 3) | ||
// CHECK-DAG: affine.apply #[[$MAP1]](%[[ARG0]]) | ||
// CHECK-DAG: affine.apply #[[$MAP0]](%[[ARG1]]) |