diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel index efc338fba873..c9cfaf2ebf9f 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel @@ -89,6 +89,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions:LinalgExtExtensions", "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms", "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils", + "//compiler/src/iree/compiler/Dialect/Stream/IR", "//compiler/src/iree/compiler/Dialect/Util/Analysis", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Dialect/Util/Transforms", diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt index f448835eeabe..c69fe6aad772 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt @@ -116,6 +116,7 @@ iree_cc_library( iree::compiler::Dialect::LinalgExt::TransformExtensions::LinalgExtExtensions iree::compiler::Dialect::LinalgExt::Transforms iree::compiler::Dialect::LinalgExt::Utils + iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::Util::Analysis iree::compiler::Dialect::Util::IR iree::compiler::Dialect::Util::Transforms diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp index f64e419ef906..390e7f6aee52 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp @@ -13,7 +13,10 @@ #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" #include "iree/compiler/Dialect/Util/Analysis/Explorer.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -22,6 +25,8 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#define DEBUG_TYPE "iree-flow-fold-unit-extent-dims" + namespace mlir::iree_compiler::IREE::Flow { #define GEN_PASS_DEF_FOLDUNITEXTENTDIMSPASS @@ -46,15 +51,24 @@ foldUnitDimsOnGlobal(IRRewriter &rewriter, IREE::Util::GlobalOpInterface global, } auto newGlobalType = globalType.clone(newShape); auto initialValue = global.getGlobalInitialValue(); - // TODO: Handle non-uninitialized cases. - auto uninitializedAttr = - llvm::dyn_cast_if_present(initialValue); - if (initialValue && !uninitializedAttr) + if (!initialValue) + return success(); + // TODO: Handle other cases + auto newInitialValue = + llvm::TypeSwitch(initialValue) + .Case([&](Attribute) { + return IREE::Util::UninitializedAttr::get(rewriter.getContext(), + newGlobalType); + }) + .Case( + [&](IREE::Stream::NamedParameterAttr attr) { + return IREE::Stream::NamedParameterAttr::get( + rewriter.getContext(), newGlobalType, attr.getScope(), + attr.getKey(), attr.getConfig()); + }) + .Default([&](Attribute) { return nullptr; }); + if (!newInitialValue) { return success(); - TypedAttr newInitialValue; - if (initialValue) { - newInitialValue = IREE::Util::UninitializedAttr::get(rewriter.getContext(), - newGlobalType); } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(global); @@ -101,19 +115,19 @@ struct FoldUnitExtentDimsPass } // namespace void FoldUnitExtentDimsPass::runOnOperation() { - auto funcOp = getOperation(); + auto moduleOp = getOperation(); MLIRContext *context = &getContext(); - Explorer explorer(funcOp, TraversalAction::RECURSE); + SymbolTable moduleSymbols(moduleOp); + Explorer explorer(moduleOp, TraversalAction::RECURSE); explorer.initialize(); IRRewriter rewriter(context); - SymbolTable moduleSymbols(funcOp); // Fold unit dims of GlobalOpInterface ops. explorer.forEachGlobal([&](const Explorer::GlobalInfo *globalInfo) { IREE::Util::GlobalOpInterface global = globalInfo->op; auto tensorType = dyn_cast(global.getGlobalType()); - if (!tensorType || !global.isGlobalPrivate() || !global.isGlobalMutable()) { + if (!tensorType || !global.isGlobalPrivate()) { return; } if (llvm::none_of(tensorType.getShape(), @@ -142,7 +156,7 @@ void FoldUnitExtentDimsPass::runOnOperation() { }; linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options); linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, + if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(foldUnitDimsPatterns)))) { return signalPassFailure(); } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir index 611c2da1b437..e652c5e39a82 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir @@ -91,3 +91,18 @@ module @no_fold_public { // CHECK: util.func public @no_fold_global_unit_dims // CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32> // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]] + +// ----- + +module @fold_stream_parameter { + util.global private mutable @global = #stream.parameter.named<"module"::"global"> : tensor<1x1x10xf32> + util.func public @fold_stream_parameter() -> tensor<1x1x10xf32> { + %global = util.global.load @global : tensor<1x1x10xf32> + util.return %global : tensor<1x1x10xf32> + } +} + +// CHECK: module @fold_stream_parameter +// CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32> +// CHECK: util.func public @fold_stream_parameter +// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32>