forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Codegen][Tuner] Add pass to link tuning specs (iree-org#19281)
This pass is meant for combining multiple tuning specs (e.g., a user-provided one and a default one). We expect the input module to have nested sub-modules with named sequences marked with the `iree_codegen.tuning_spec_entrypoint` unit attribute. The pass collects all such tuning specs and introduces a new named sequence that includes all the other tuning spec entry points. The order of inclusion is the same as the order in which these nested tuning specs appear in the IR. Issue: iree-org#19214
- Loading branch information
Showing
9 changed files
with
279 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
151 changes: 151 additions & 0 deletions
151
compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include <cassert> | ||
#include "iree/compiler/Codegen/Common/Passes.h" | ||
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/ADT/SmallVectorExtras.h" | ||
#include "mlir/Dialect/Transform/IR/TransformAttrs.h" | ||
#include "mlir/Dialect/Transform/IR/TransformDialect.h" | ||
#include "mlir/Dialect/Transform/IR/TransformOps.h" | ||
#include "mlir/Dialect/Transform/IR/TransformTypes.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/Location.h" | ||
|
||
#define DEBUG_TYPE "iree-codegen-link-tuning-specs" | ||
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") | ||
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
#define GEN_PASS_DEF_LINKTUNINGSPECSPASS | ||
#include "iree/compiler/Codegen/Common/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
using mlir::transform::NamedSequenceOp; | ||
|
||
/// Returns the modules directly nested in `module` that have a symbol name and
/// carry the transform dialect's "with named sequence" attribute. Unnamed
/// nested modules are skipped. Results appear in the order the nested modules
/// are visited in the body of `module`.
static SmallVector<ModuleOp>
findNestedModulesWithNamedSequences(ModuleOp module) {
  SmallVector<ModuleOp> namedSequenceModules;
  for (ModuleOp nested : module.getBody()->getOps<ModuleOp>()) {
    // A symbol name is required so the nested module can later be referenced
    // through a nested symbol reference.
    if (!nested.getSymName().has_value()) {
      continue;
    }
    if (!nested->hasAttr(
            transform::TransformDialect::kWithNamedSequenceAttrName)) {
      continue;
    }
    namedSequenceModules.push_back(nested);
  }
  return namedSequenceModules;
}
|
||
/// Returns the named sequences located directly in the body of `module` that
/// are marked with the `kTuningSpecAttrName` attribute, in the order they are
/// visited.
static SmallVector<NamedSequenceOp> findTuningSpecs(ModuleOp module) {
  SmallVector<NamedSequenceOp> entrypoints;
  for (NamedSequenceOp sequence :
       module.getBody()->getOps<NamedSequenceOp>()) {
    if (sequence->hasAttr(kTuningSpecAttrName)) {
      entrypoints.push_back(sequence);
    }
  }
  return entrypoints;
}
|
||
/// Checks that `op` has the signature expected of a tuning spec: no results
/// and exactly one argument of type `!transform.any_op` that is annotated as
/// readonly. On the first violated requirement, emits a warning on `op` and
/// returns failure; returns success otherwise.
static LogicalResult validateTuningSpec(NamedSequenceOp op) {
  const bool producesResults = !op.getResultTypes().empty();
  if (producesResults) {
    op->emitWarning() << "Tuning spec expected to have no results";
    return failure();
  }

  ArrayRef<Type> inputTypes = op.getArgumentTypes();
  const bool hasSingleAnyOpArg =
      inputTypes.size() == 1 && isa<transform::AnyOpType>(inputTypes.front());
  if (!hasSingleAnyOpArg) {
    op->emitWarning() << "Tuning spec expected to have one argument of type "
                         "'!transform.any_op'";
    return failure();
  }

  // Only reached when argument 0 exists, so querying its attribute is safe.
  Attribute readOnlyAttr =
      op.getArgAttr(0, transform::TransformDialect::kArgReadOnlyAttrName);
  if (!readOnlyAttr) {
    op->emitWarning() << "Tuning spec expected to have one readonly argument";
    return failure();
  }

  return success();
}
|
||
/// Appends a new named sequence called `kKernelConfigSpecName` to the end of
/// `module` and returns it. The new sequence takes a single readonly
/// `!transform.any_op` argument, is marked with `kTuningSpecAttrName`, and
/// contains one `transform.include` per entry of `specsToLink` (in the given
/// order) followed by a `transform.yield`.
///
/// Every spec in `specsToLink` must be nested in a module that has a symbol
/// name, since each include refers to the spec as `@parent_module::@spec`.
static NamedSequenceOp
emitLinkedTuningSpec(ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
  OpBuilder builder(module->getContext());
  builder.setInsertionPointToEnd(module.getBody());

  // Fuse the locations of all linked specs so that the new op points back at
  // every spec it was created from.
  Location loc = builder.getFusedLoc(llvm::map_to_vector(
      specsToLink, [](NamedSequenceOp op) { return op->getLoc(); }));
  FunctionType specType = builder.getFunctionType(
      TypeRange{builder.getType<transform::AnyOpType>()}, TypeRange{});
  auto newSpec = builder.create<NamedSequenceOp>(
      loc, kKernelConfigSpecName, TypeAttr::get(specType),
      /*sym_visibility=*/StringAttr{},
      /*arg_attrs=*/ArrayAttr{},
      /*res_attrs=*/ArrayAttr{});
  newSpec.setArgAttr(0, transform::TransformDialect::kArgReadOnlyAttrName,
                     builder.getUnitAttr());
  newSpec->setAttr(kTuningSpecAttrName, builder.getUnitAttr());

  Region &region = newSpec.getRegion();
  Block *body = builder.createBlock(&region, region.begin(),
                                    newSpec.getArgumentTypes(), loc);
  builder.setInsertionPointToStart(body);

  // Emit one `transform.include` op per child tuning spec. In the future,
  // we may want to switch to a custom transform op for this to perform
  // 'short-circuiting' and apply at most one tuning spec.
  Value operand = body->getArgument(0);
  for (NamedSequenceOp spec : specsToLink) {
    ModuleOp parentModule = spec->getParentOfType<ModuleOp>();
    assert(parentModule);
    StringAttr parentSymbol = parentModule.getSymNameAttr();
    assert(parentSymbol);
    auto symbol = SymbolRefAttr::get(
        parentSymbol, FlatSymbolRefAttr::get(spec.getSymNameAttr()));

    // Suppress silenceable errors so that failures to match in child tuning
    // specs can be ignored.
    builder.create<transform::IncludeOp>(
        loc, TypeRange{}, symbol, transform::FailurePropagationMode::Suppress,
        operand);
  }

  builder.create<transform::YieldOp>(loc);
  return newSpec;
}
|
||
struct LinkTuningSpecsPass final | ||
: impl::LinkTuningSpecsPassBase<LinkTuningSpecsPass> { | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registerTransformDialectTranslationDependentDialects(registry); | ||
} | ||
|
||
void runOnOperation() override { | ||
ModuleOp module = getOperation(); | ||
SmallVector<NamedSequenceOp> tuningSpecs; | ||
|
||
for (ModuleOp nested : findNestedModulesWithNamedSequences(module)) { | ||
llvm::append_range(tuningSpecs, findTuningSpecs(nested)); | ||
} | ||
|
||
for (NamedSequenceOp spec : tuningSpecs) { | ||
LDBG("Found tuning spec: " << spec.getSymName()); | ||
if (failed(validateTuningSpec(spec))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
|
||
if (tuningSpecs.empty()) { | ||
LDBG("No tuning specs found, exiting without linking"); | ||
return; | ||
} | ||
|
||
emitLinkedTuningSpec(module, tuningSpecs); | ||
} | ||
}; | ||
|
||
} // namespace | ||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
99 changes: 99 additions & 0 deletions
99
compiler/src/iree/compiler/Codegen/Common/test/link_tuning_specs.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
// RUN: iree-opt %s --no-implicit-module --iree-codegen-link-tuning-specs --split-input-file \ | ||
// RUN: | FileCheck %s | ||
|
||
// CHECK-LABEL: module @td_module_0 | ||
// | ||
// CHECK: transform.named_sequence @outer_spec | ||
// | ||
// CHECK: transform.named_sequence @__kernel_config | ||
// CHECK-SAME: (%arg0: !transform.any_op {transform.readonly}) | ||
// CHECK-SAME: attributes {iree_codegen.tuning_spec_entrypoint} | ||
// CHECK: transform.include @foo_module::@foo failures(suppress) | ||
// CHECK-NEXT: transform.include @bar_module::@bar failures(suppress) | ||
// CHECK-NEXT: transform.include @baz_module::@baz failures(suppress) | ||
// CHECK-NEXT: transform.yield | ||
|
||
module @td_module_0 attributes { transform.with_named_sequence } { | ||
module @foo_module attributes { transform.with_named_sequence } { | ||
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () | ||
attributes { iree_codegen.tuning_spec_entrypoint } { | ||
transform.print {name = "Foo", skip_regions} | ||
transform.yield | ||
} | ||
} | ||
|
||
module @bar_module attributes { transform.with_named_sequence } { | ||
transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> () | ||
attributes { iree_codegen.tuning_spec_entrypoint } { | ||
transform.match.operation_name %arg0 ["func.func"] : !transform.any_op | ||
transform.print {name = "Bar", skip_regions} | ||
transform.yield | ||
} | ||
} | ||
|
||
module @baz_module attributes { transform.with_named_sequence } { | ||
transform.named_sequence @baz(%arg0: !transform.any_op {transform.readonly}) -> () | ||
attributes { iree_codegen.tuning_spec_entrypoint } { | ||
transform.print {name = "Baz", skip_regions} | ||
transform.yield | ||
} | ||
} | ||
|
||
transform.named_sequence @outer_spec(%module: !transform.any_op {transform.readonly}) -> () | ||
attributes { iree_codegen.tuning_spec_entrypoint } { | ||
transform.yield | ||
} | ||
} | ||
|
||
|
||
// ----- | ||
|
||
// Here, `foo` shouldn't be included because it's not marked with `tuning_spec_entrypoint`. | ||
|
||
// CHECK-LABEL: module @td_module_1 | ||
// CHECK: @foo_module | ||
// CHECK: @__kernel_config | ||
// CHECK-NOT: transform.include @foo_module::@foo failures(suppress) (%arg0) : (!transform.any_op) -> () | ||
// CHECK: transform.include @foo_module::@bar failures(suppress) (%arg0) : (!transform.any_op) -> () | ||
// CHECK-NEXT: transform.yield | ||
|
||
module @td_module_1 attributes { transform.with_named_sequence } { | ||
module @foo_module attributes { transform.with_named_sequence } { | ||
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () { | ||
transform.yield | ||
} | ||
transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> () | ||
attributes { iree_codegen.tuning_spec_entrypoint } { | ||
transform.yield | ||
} | ||
func.func @baz(%arg0: i32) -> () { | ||
return | ||
} | ||
} | ||
} | ||
|
||
|
||
// ----- | ||
|
||
// Make sure we do not crash on modules with no tuning specs. | ||
|
||
// CHECK-LABEL: module @td_module_2 | ||
// CHECK-NOT: @__kernel_config | ||
module @td_module_2 attributes { transform.with_named_sequence } {} | ||
|
||
// ----- | ||
|
||
// Make sure we do not crash on unnamed nested modules. | ||
|
||
// CHECK-LABEL: module @td_module_3 | ||
// CHECK: transform.named_sequence @foo | ||
// CHECK-NOT: @__kernel_config | ||
|
||
module @td_module_3 attributes { transform.with_named_sequence } { | ||
module attributes { transform.with_named_sequence } { | ||
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () | ||
attributes { iree_codegen.tuning_spec_entrypoint } { | ||
transform.yield | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters