Skip to content

Commit

Permalink
[GlobalOpt] Switch to new pass generation tablegen definitions. (iree…
Browse files Browse the repository at this point in the history
…-org#18163)

This is mostly an NFC change. The revision applies a little cleanups:

- Remove `enableQuantizedMatmulReassociation` option from
FuseDequantizationMatmulPass. It should be controled by pipeline.
- Move testing options to tablegen definitions for
PropagateLinalgTransposePass
- Switch a couple of passes to follow `create.*Pass` naming convention.
- Switch namespaces to the new single-line syntax for
FuseSiluHorizontalMatmulPass

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
  • Loading branch information
hanhanW authored Aug 8, 2024
1 parent fbf677d commit 2695fe9
Show file tree
Hide file tree
Showing 27 changed files with 183 additions and 371 deletions.
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ iree_gentbl_cc_library(
iree_compiler_cc_library(
name = "PassHeaders",
hdrs = [
"PassDetail.h",
"Passes.h",
"Passes.h.inc",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ iree_cc_library(
NAME
PassHeaders
HDRS
"PassDetail.h"
"Passes.h"
"Passes.h.inc"
DEPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"

namespace mlir::iree_compiler::GlobalOptimization {

#define GEN_PASS_DEF_CLEANUPNUMERICNARROWINGPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"

namespace {

class CleanupNumericNarrowingPass
: public CleanupNumericNarrowingBase<CleanupNumericNarrowingPass> {
: public impl::CleanupNumericNarrowingPassBase<
CleanupNumericNarrowingPass> {
void runOnOperation() override {
getOperation()->walk([](IREE::Util::NumericOptionalNarrowOp op) {
op.getResult().replaceAllUsesWith(op.getOperand());
Expand All @@ -23,9 +26,4 @@ class CleanupNumericNarrowingPass
};

} // namespace

std::unique_ptr<Pass> createCleanupNumericNarrowingPass() {
return std::make_unique<CleanupNumericNarrowingPass>();
}

} // namespace mlir::iree_compiler::GlobalOptimization
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand All @@ -14,6 +13,9 @@

namespace mlir::iree_compiler::GlobalOptimization {

#define GEN_PASS_DEF_CONVERT1X1FILTERCONV2DTOMATMULPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"

namespace {

// Converts linalg.conv_2d_input_nhwc_filter_nhwc op to linalg.matmul
Expand Down Expand Up @@ -157,7 +159,7 @@ class Convert1x1FilterConvToMatmul : public OpRewritePattern<Conv2DOpType> {
};

struct Convert1X1FilterConv2DToMatmulPass
: public Convert1X1FilterConv2DToMatmulBase<
: public impl::Convert1X1FilterConv2DToMatmulPassBase<
Convert1X1FilterConv2DToMatmulPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
Expand All @@ -176,9 +178,4 @@ struct Convert1X1FilterConv2DToMatmulPass
}
};
} // namespace

std::unique_ptr<Pass> createConvert1X1FilterConv2DToMatmulPass() {
return std::make_unique<Convert1X1FilterConv2DToMatmulPass>();
}

} // namespace mlir::iree_compiler::GlobalOptimization
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace mlir::iree_compiler::GlobalOptimization {

#define GEN_PASS_DEF_DATALAYOUTPROPAGATIONPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"

namespace {

struct DataLayoutPropagationPass
: public DataLayoutPropagationBase<DataLayoutPropagationPass> {
: public impl::DataLayoutPropagationPassBase<DataLayoutPropagationPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
FunctionOpInterface funcOp = getOperation();
Expand All @@ -43,9 +43,4 @@ struct DataLayoutPropagationPass
};

} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createDataLayoutPropagationPass() {
return std::make_unique<DataLayoutPropagationPass>();
}
} // namespace mlir::iree_compiler::GlobalOptimization
15 changes: 10 additions & 5 deletions compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -16,6 +15,9 @@

namespace mlir::iree_compiler::GlobalOptimization {

#define GEN_PASS_DEF_DECOMPOSECONCATPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"

namespace {

static Value createTranspose(OpBuilder &builder, Value source,
Expand Down Expand Up @@ -78,13 +80,16 @@ struct TransposeInnerConcatenation : public OpRewritePattern<tensor::ConcatOp> {
}
};

struct DecomposeConcatPass : public DecomposeConcatBase<DecomposeConcatPass> {
struct DecomposeConcatPass
: public impl::DecomposeConcatPassBase<DecomposeConcatPass> {
using impl::DecomposeConcatPassBase<
DecomposeConcatPass>::DecomposeConcatPassBase;
explicit DecomposeConcatPass(bool enableConcatTransposition) {
this->enableConcatTransposition = enableConcatTransposition;
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
}
DecomposeConcatPass(bool enableConcatTransposition) {
this->enableConcatTransposition = enableConcatTransposition;
}
DecomposeConcatPass(const DecomposeConcatPass &pass)
: DecomposeConcatPass(pass.enableConcatTransposition) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand All @@ -19,6 +18,9 @@

namespace mlir::iree_compiler::GlobalOptimization {

#define GEN_PASS_DEF_DEMOTECONTRACTIONINPUTSTOBF16PASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"

namespace {

// For narrowable inputs, selects
Expand Down Expand Up @@ -133,12 +135,13 @@ struct DemoteContractionInputsToBF16Pattern
};

class DemoteContractionInputsToBF16Pass
: public DemoteContractionInputsToBF16Base<
: public impl::DemoteContractionInputsToBF16PassBase<
DemoteContractionInputsToBF16Pass> {

public:
using impl::DemoteContractionInputsToBF16PassBase<
DemoteContractionInputsToBF16Pass>::DemoteContractionInputsToBF16PassBase;
explicit DemoteContractionInputsToBF16Pass(const DemotionOption &option) {
this->demoteOnly.setValue(option);
this->demoteOnly = option;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
//===----------------------------------------------------------------------===//

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -27,6 +26,9 @@

namespace mlir::iree_compiler::GlobalOptimization {

#define GEN_PASS_DEF_DETACHELEMENTWISEFROMNAMEDOPSPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"

namespace {

struct DetachElementwisePattern
Expand Down Expand Up @@ -185,7 +187,7 @@ struct DetachSplatConstantOutsOperands
};

struct DetachElementwiseFromNamedOpsPass
: public DetachElementwiseFromNamedOpsBase<
: public impl::DetachElementwiseFromNamedOpsPassBase<
DetachElementwiseFromNamedOpsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, linalg::LinalgDialect,
Expand All @@ -206,9 +208,4 @@ struct DetachElementwiseFromNamedOpsPass
};

} // namespace

std::unique_ptr<Pass> createDetachElementwiseFromNamedOpsPass() {
return std::make_unique<DetachElementwiseFromNamedOpsPass>();
}

} // namespace mlir::iree_compiler::GlobalOptimization
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -14,9 +13,13 @@

namespace mlir::iree_compiler::GlobalOptimization {

#define GEN_PASS_DEF_ERASEUNUSEDLINALGOPERANDSPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"

namespace {
struct EraseUnusedLinalgOperandsPass
: public EraseUnusedLinalgOperandsBase<EraseUnusedLinalgOperandsPass> {
: public impl::EraseUnusedLinalgOperandsPassBase<
EraseUnusedLinalgOperandsPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
Expand All @@ -28,10 +31,4 @@ struct EraseUnusedLinalgOperandsPass
}
};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>>
createEraseUnusedLinalgOperands() {
return std::make_unique<EraseUnusedLinalgOperandsPass>();
}

} // namespace mlir::iree_compiler::GlobalOptimization
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/Transforms/Patterns.h"
#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "iree/compiler/Utils/IntegerSet.h"
#include "llvm/ADT/BreadthFirstIterator.h"
Expand All @@ -29,6 +28,10 @@
#define DEBUG_TYPE "iree-global-opt-expand-tensor-shapes"

namespace mlir::iree_compiler::GlobalOptimization {

#define GEN_PASS_DEF_EXPANDTENSORSHAPESPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"

namespace {

// TODO(benvanik): factor out into a generic util pass base that lets us share
Expand Down Expand Up @@ -624,10 +627,8 @@ static void expandTensorDims(Operation *op, SymbolTable &symbolTable,
// results are always wrapped in a flow.tensor.tie_shape, with the
// elision/deduplication/etc left until cleanup.
class ExpandTensorShapesPass
: public ExpandTensorShapesBase<ExpandTensorShapesPass> {
: public impl::ExpandTensorShapesPassBase<ExpandTensorShapesPass> {
public:
ExpandTensorShapesPass() = default;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mlir::arith::ArithDialect>();
registry.insert<IREE::Flow::FlowDialect>();
Expand Down Expand Up @@ -661,9 +662,4 @@ class ExpandTensorShapesPass
};

} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createExpandTensorShapesPass() {
return std::make_unique<ExpandTensorShapesPass>();
}

} // namespace mlir::iree_compiler::GlobalOptimization
Loading

0 comments on commit 2695fe9

Please sign in to comment.