Skip to content

Commit

Permalink
[Codegen] Add pass to normalize loop bounds (iree-org#17542)
Browse files Browse the repository at this point in the history
This is based on the pass added to `iree-amd-aie` here from @jtuyls:


nod-ai/iree-amd-aie@45c9465

This restricts the normalization pattern to cases where the induction
variables are `index` type, however this is nearly always the case in
typical codegen flows.
  • Loading branch information
qedawkins authored Jun 3, 2024
1 parent 7b319cb commit 14fd6ac
Show file tree
Hide file tree
Showing 8 changed files with 343 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ iree_compiler_cc_library(
"MaterializeEncodingIntoNop.cpp",
"MaterializeEncodingIntoPackUnPack.cpp",
"MemrefCopyToLinalg.cpp",
"NormalizeLoopBounds.cpp",
"OptimizeTensorInsertExtractSlices.cpp",
"OptimizeVectorTransferPass.cpp",
"PadDynamicAlloc.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ iree_cc_library(
"MaterializeEncodingIntoNop.cpp"
"MaterializeEncodingIntoPackUnPack.cpp"
"MemrefCopyToLinalg.cpp"
"NormalizeLoopBounds.cpp"
"OptimizeTensorInsertExtractSlices.cpp"
"OptimizeVectorTransferPass.cpp"
"PadDynamicAlloc.cpp"
Expand Down
199 changes: 199 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "iree-codegen-normalize-loop-bounds"

namespace mlir::iree_compiler {

static OpFoldResult emitNormalizedUpperBound(RewriterBase &rewriter,
Location loc, OpFoldResult lb,
OpFoldResult ub,
OpFoldResult step) {
AffineExpr d0, d1, d2;
bindDims(rewriter.getContext(), d0, d1, d2);
return affine::makeComposedFoldedAffineApply(
rewriter, loc, (d0 - d1).ceilDiv(d2), {ub, lb, step});
}

/// Helper structure for storing the newly computed loop bounds.
namespace {
struct LoopRanges {
SmallVector<OpFoldResult> lowerBounds;
SmallVector<OpFoldResult> upperBounds;
SmallVector<OpFoldResult> steps;
};
} // namespace

static FailureOr<LoopRanges>
emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, Block *body,
ValueRange ivs, ArrayRef<OpFoldResult> lbs,
ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> steps) {
Attribute zero = rewriter.getIndexAttr(0);
Attribute one = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> newLbs;
SmallVector<OpFoldResult> newUbs;
SmallVector<OpFoldResult> newSteps;
for (auto &&[iv, lb, ub, step] : llvm::zip(ivs, lbs, ubs, steps)) {
std::optional<int64_t> stepInt = getConstantIntValue(step);
// Bail out on negative steps.
if (!stepInt || stepInt.value() <= 0) {
return failure();
}

// The lower bound and step of a normalized loop is always zero/one.
newLbs.push_back(zero);
newSteps.push_back(one);

// Compute the normalized upper bound.
OpFoldResult newUb = emitNormalizedUpperBound(rewriter, loc, lb, ub, step);
newUbs.push_back(newUb);

// Compute and replace the denormalized loop iterator argument in the loop
// body with an insertion guard.
{
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToStart(body);
AffineExpr idx, stepExpr, lbExpr;
bindDims(rewriter.getContext(), idx, stepExpr, lbExpr);
affine::AffineApplyOp denormalizedIV = affine::makeComposedAffineApply(
rewriter, loc, idx * stepExpr + lbExpr, {iv, step, lb});
SmallPtrSet<Operation *, 2> preserve = {iv.getDefiningOp(),
denormalizedIV};
rewriter.replaceAllUsesExcept(iv, denormalizedIV.getResult(), preserve);
}
}
return LoopRanges{newLbs, newUbs, newSteps};
}

/// Transform a `scf.for` loop with a strictly positive step
/// for %i = %lb to %ub step %s
/// into a 0-based loop with step 1
/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1
/// Insert an `affine.apply` operation to compute the denormalized index value.
static LogicalResult normalizeLoopBounds(RewriterBase &rewriter,
scf::ForOp forOp) {
OpBuilder::InsertionGuard g(rewriter);
// Return if already normalized.
std::optional<int64_t> lbInt = getConstantIntValue(forOp.getLowerBound());
std::optional<int64_t> stepInt = getConstantIntValue(forOp.getStep());
if (lbInt && stepInt && lbInt.value() == 0 && stepInt.value() == 1) {
return success();
}

// Bail out on non-index types because the affine applies that are generated
// require it.
if (!isa<IndexType>(forOp.getInductionVar().getType())) {
return failure();
}

Location loc = forOp.getLoc();

rewriter.setInsertionPoint(forOp);
FailureOr<LoopRanges> newLoopParams = emitNormalizedLoopBounds(
rewriter, loc, forOp.getBody(), forOp.getInductionVar(),
getAsOpFoldResult(forOp.getLowerBound()),
getAsOpFoldResult(forOp.getUpperBound()),
getAsOpFoldResult(forOp.getStep()));
if (failed(newLoopParams)) {
return failure();
}

assert(newLoopParams->lowerBounds.size() == 1 &&
newLoopParams->upperBounds.size() == 1 &&
newLoopParams->steps.size() == 1 &&
"expected single range for scf.for");

rewriter.modifyOpInPlace(forOp, [&]() {
forOp.setLowerBound(getValueOrCreateConstantIndexOp(
rewriter, loc, newLoopParams->lowerBounds.front()));
forOp.setUpperBound(getValueOrCreateConstantIndexOp(
rewriter, loc, newLoopParams->upperBounds.front()));
forOp.setStep(getValueOrCreateConstantIndexOp(
rewriter, loc, newLoopParams->steps.front()));
});
return success();
}

/// Transform a `scf.forall` loop with a strictly positive steps
/// forall (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1)
/// into a 0-based loop with step 1 (normalized)
/// forall (%i, %j) in (ceildiv(%ub0 - %lb0, %s0), ceildiv(%ub1 - %lb1, %s1))
/// Insert `affine.apply` operations to compute the denormalized index values.
static LogicalResult normalizeLoopBounds(RewriterBase &rewriter,
scf::ForallOp forallOp) {
OpBuilder::InsertionGuard g(rewriter);
if (forallOp.isNormalized())
return success();

// `scf.forall` requires that all lbs/ubs/steps/ivs are index type so no need
// to check here.

rewriter.setInsertionPoint(forallOp);
FailureOr<LoopRanges> newLoopParams = emitNormalizedLoopBounds(
rewriter, forallOp.getLoc(), forallOp.getBody(),
forallOp.getInductionVars(), forallOp.getMixedLowerBound(),
forallOp.getMixedUpperBound(), forallOp.getMixedStep());
if (failed(newLoopParams)) {
return failure();
}

rewriter.setInsertionPointAfter(forallOp);
auto newLoop = rewriter.create<scf::ForallOp>(
rewriter.getUnknownLoc(), newLoopParams->lowerBounds,
newLoopParams->upperBounds, newLoopParams->steps, forallOp.getOutputs(),
forallOp.getMapping());
rewriter.eraseOp(newLoop.getTerminator());
rewriter.mergeBlocks(forallOp.getBody(), newLoop.getBody(),
newLoop.getBody()->getArguments());
rewriter.replaceOp(forallOp, newLoop);

return success();
}

namespace {
struct NormalizeLoopBoundsPass
: public NormalizeLoopBoundsPassBase<NormalizeLoopBoundsPass> {
NormalizeLoopBoundsPass(bool nFor, bool nForall)
: normalizeFor(nFor), normalizeForall(nForall) {}
void runOnOperation() override {
Operation *op = getOperation();
IRRewriter rewriter(op);
if (normalizeFor) {
op->walk([&](scf::ForOp forOp) {
(void)normalizeLoopBounds(rewriter, forOp);
});
}
if (normalizeForall) {
op->walk([&](scf::ForallOp forallOp) {
(void)normalizeLoopBounds(rewriter, forallOp);
});
}
}

private:
bool normalizeFor;
bool normalizeForall;
};
} // namespace

std::unique_ptr<Pass> createNormalizeLoopBoundsPass(bool normalizeFor,
bool normalizeForall) {
return std::make_unique<NormalizeLoopBoundsPass>(normalizeFor,
normalizeForall);
}

} // namespace mlir::iree_compiler
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ createMemrefCopyToLinalgPass();
/// Extracts lowering configs and translation info from user configs.
std::unique_ptr<OperationPass<ModuleOp>> createMaterializeUserConfigsPass();

/// Normalizes the iteration range of `scf.for` and `scf.forall` loops to
/// [0, ub) += 1.
std::unique_ptr<Pass>
createNormalizeLoopBoundsPass(bool normalizeFor = true,
bool normalizeForall = true);

/// Pass to optimize vector transfer_read and transfer_write.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createOptimizeVectorTransferPass(bool flatten = false);
Expand Down
16 changes: 16 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,22 @@ def MemrefCopyToLinalgPass :
"mlir::iree_compiler::createMemrefCopyToLinalgPass()";
}

def NormalizeLoopBoundsPass :
Pass<"iree-codegen-normalize-loop-bounds", ""> {
let summary = "Normalize the loop bounds of `scf.for` and `scf.forall`";
let constructor = "mlir::iree_compiler::createNormalizeLoopBoundsPass()";
let options = [
Option<"normalizeFor", "normalize-for", "bool", "true",
"Enable normalization for `scf.for` loops">,
Option<"normalizeForall", "normalize-forall", "bool", "true",
"Enable normalization for `scf.forall` loops">,
];
let dependentDialects = [
"affine::AffineDialect",
"arith::ArithDialect"
];
}

def OptimizeVectorTransfer :
InterfacePass<"iree-codegen-optimize-vector-transfer", "mlir::FunctionOpInterface"> {
let summary =
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ iree_lit_test_suite(
"lower_ukernel_to_calls.mlir",
"materialize_encoding_into_nop.mlir",
"materialize_user_configs.mlir",
"normalize_loop_bounds.mlir",
"optimize_tensor_insert_extract_slices.mlir",
"pad_dynamic_alloc.mlir",
"polynomial_approximation.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ iree_lit_test_suite(
"lower_ukernel_to_calls.mlir"
"materialize_encoding_into_nop.mlir"
"materialize_user_configs.mlir"
"normalize_loop_bounds.mlir"
"optimize_tensor_insert_extract_slices.mlir"
"pad_dynamic_alloc.mlir"
"polynomial_approximation.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@

// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-normalize-loop-bounds, cse)" --allow-unregistered-dialect --verify-diagnostics %s | FileCheck %s
module {
func.func @for_normalize_step() {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c8 = arith.constant 8 : index
scf.for %arg0 = %c0 to %c8 step %c2 {
"iree.keep"(%arg0) : (index) -> ()
}
return
}
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-LABEL: func.func @for_normalize_step
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK: scf.for %[[ARG:.+]] = %[[C0]] to %[[C4]] step %[[C1]]
// CHECK-NEXT: affine.apply #[[$MAP]](%[[ARG]])

// -----

module {
func.func @for_normalize_lowerbound() {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c8 = arith.constant 8 : index
scf.for %arg0 = %c2 to %c8 step %c1 {
"iree.keep"(%arg0) : (index) -> ()
}
return
}
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 + 2)>
// CHECK-LABEL: func.func @for_normalize_lowerbound
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index
// CHECK: scf.for %[[ARG:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
// CHECK-NEXT: affine.apply #[[$MAP]](%[[ARG]])

// -----

module {
func.func @for_normalize_lowerbound_and_step() {
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c13 = arith.constant 13 : index
scf.for %arg0 = %c1 to %c13 step %c4 {
"iree.keep"(%arg0) : (index) -> ()
}
return
}
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 4 + 1)>
// CHECK-LABEL: func.func @for_normalize_lowerbound_and_step
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[ARG:.+]] = %[[C0]] to %[[C3]] step %[[C1]]
// CHECK-NEXT: affine.apply #[[$MAP]](%[[ARG]])

// -----

module {
func.func @forall_normalize_step() {
scf.forall (%arg0, %arg1) = (0, 0) to (8, 16) step (8, 8) {
"iree.keep"(%arg0, %arg1) : (index, index) -> ()
}
return
}
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 8)>
// CHECK-LABEL: func.func @forall_normalize_step
// CHECK: scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) in (1, 2)
// CHECK-DAG: affine.apply #[[$MAP]](%[[ARG0]])
// CHECK-DAG: affine.apply #[[$MAP]](%[[ARG1]])

// -----

module {
func.func @forall_normalize_lowerbound() {
scf.forall (%arg0, %arg1) = (2, 4) to (8, 16) step (1, 1) {
"iree.keep"(%arg0, %arg1) : (index, index) -> ()
}
return
}
}

// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 + 4)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 + 2)>
// CHECK-LABEL: func.func @forall_normalize_lowerbound
// CHECK: scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) in (6, 12)
// CHECK-DAG: affine.apply #[[$MAP1]](%[[ARG0]])
// CHECK-DAG: affine.apply #[[$MAP0]](%[[ARG1]])

// -----

module {
func.func @forall_normalize_lowerbound_and_step() {
scf.forall (%arg0, %arg1) = (2, 4) to (8, 16) step (2, 4) {
"iree.keep"(%arg0, %arg1) : (index, index) -> ()
}
return
}
}

// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4 + 4)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 2 + 2)>
// CHECK-LABEL: func.func @forall_normalize_lowerbound
// CHECK: scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) in (3, 3)
// CHECK-DAG: affine.apply #[[$MAP1]](%[[ARG0]])
// CHECK-DAG: affine.apply #[[$MAP0]](%[[ARG1]])

0 comments on commit 14fd6ac

Please sign in to comment.