Skip to content

Commit

Permalink
[Flow][Global Opt] Fold unit dims of stream.parameter.named (iree-o…
Browse files Browse the repository at this point in the history
…rg#17824)

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
  • Loading branch information
IanWood1 authored Jul 9, 2024
1 parent 3f6bf8c commit e24ea82
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions:LinalgExtExtensions",
"//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
"//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/Analysis",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ iree_cc_library(
iree::compiler::Dialect::LinalgExt::TransformExtensions::LinalgExtExtensions
iree::compiler::Dialect::LinalgExt::Transforms
iree::compiler::Dialect::LinalgExt::Utils
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::Analysis
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@

#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -22,6 +25,8 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-flow-fold-unit-extent-dims"

namespace mlir::iree_compiler::IREE::Flow {

#define GEN_PASS_DEF_FOLDUNITEXTENTDIMSPASS
Expand All @@ -46,15 +51,24 @@ foldUnitDimsOnGlobal(IRRewriter &rewriter, IREE::Util::GlobalOpInterface global,
}
auto newGlobalType = globalType.clone(newShape);
auto initialValue = global.getGlobalInitialValue();
// TODO: Handle non-uninitialized cases.
auto uninitializedAttr =
llvm::dyn_cast_if_present<IREE::Util::UninitializedAttr>(initialValue);
if (initialValue && !uninitializedAttr)
if (!initialValue)
return success();
// TODO: Handle other cases
auto newInitialValue =
llvm::TypeSwitch<Attribute, Attribute>(initialValue)
.Case<IREE::Util::UninitializedAttr>([&](Attribute) {
return IREE::Util::UninitializedAttr::get(rewriter.getContext(),
newGlobalType);
})
.Case<IREE::Stream::NamedParameterAttr>(
[&](IREE::Stream::NamedParameterAttr attr) {
return IREE::Stream::NamedParameterAttr::get(
rewriter.getContext(), newGlobalType, attr.getScope(),
attr.getKey(), attr.getConfig());
})
.Default([&](Attribute) { return nullptr; });
if (!newInitialValue) {
return success();
TypedAttr newInitialValue;
if (initialValue) {
newInitialValue = IREE::Util::UninitializedAttr::get(rewriter.getContext(),
newGlobalType);
}
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(global);
Expand Down Expand Up @@ -101,19 +115,19 @@ struct FoldUnitExtentDimsPass
} // namespace

void FoldUnitExtentDimsPass::runOnOperation() {
auto funcOp = getOperation();
auto moduleOp = getOperation();
MLIRContext *context = &getContext();

Explorer explorer(funcOp, TraversalAction::RECURSE);
SymbolTable moduleSymbols(moduleOp);
Explorer explorer(moduleOp, TraversalAction::RECURSE);
explorer.initialize();
IRRewriter rewriter(context);
SymbolTable moduleSymbols(funcOp);

// Fold unit dims of GlobalOpInterface ops.
explorer.forEachGlobal([&](const Explorer::GlobalInfo *globalInfo) {
IREE::Util::GlobalOpInterface global = globalInfo->op;
auto tensorType = dyn_cast<RankedTensorType>(global.getGlobalType());
if (!tensorType || !global.isGlobalPrivate() || !global.isGlobalMutable()) {
if (!tensorType || !global.isGlobalPrivate()) {
return;
}
if (llvm::none_of(tensorType.getShape(),
Expand Down Expand Up @@ -142,7 +156,7 @@ void FoldUnitExtentDimsPass::runOnOperation() {
};
linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options);
linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns);
if (failed(applyPatternsAndFoldGreedily(funcOp,
if (failed(applyPatternsAndFoldGreedily(moduleOp,
std::move(foldUnitDimsPatterns)))) {
return signalPassFailure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,18 @@ module @no_fold_public {
// CHECK: util.func public @no_fold_global_unit_dims
// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]]

// -----

module @fold_stream_parameter {
util.global private mutable @global = #stream.parameter.named<"module"::"global"> : tensor<1x1x10xf32>
util.func public @fold_stream_parameter() -> tensor<1x1x10xf32> {
%global = util.global.load @global : tensor<1x1x10xf32>
util.return %global : tensor<1x1x10xf32>
}
}

// CHECK: module @fold_stream_parameter
// CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32>
// CHECK: util.func public @fold_stream_parameter
// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32>

0 comments on commit e24ea82

Please sign in to comment.