Skip to content

Commit

Permalink
Make vector masking optionally inference only
Browse files Browse the repository at this point in the history
  • Loading branch information
qedawkins committed Nov 10, 2023
1 parent 7df3702 commit 8c727b1
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ inferVectorSizesFromIR(linalg::LinalgOp linalgOp) {
// Return the vector sizes from the local lowering config or try to infer them
// from the tensor shapes and tiled loops in the IR.
static FailureOr<SizesAndScalableFlags>
getVectorSizes(linalg::LinalgOp linalgOp) {
getVectorSizes(linalg::LinalgOp linalgOp, bool useConfiguredVectorSizes) {
// Get vector sizes from the lowering config, if available in the op itself.
IREE::Codegen::LoweringConfigAttr loweringConfig =
getLoweringConfig(linalgOp);
if (loweringConfig) {
if (useConfiguredVectorSizes && loweringConfig) {
TilingConfig tilingConfig(loweringConfig);
auto [vectorSizes, scalableFlags] = tilingConfig.getVectorTileSizes();
// Replace zeros in canonical vector shape to turn it into a valid shape.
Expand Down Expand Up @@ -128,6 +128,7 @@ class GenericVectorizationPass
using GenericVectorizationBase::GenericVectorizationBase;
GenericVectorizationPass(const GenericVectorizationPassOptions &options) {
this->enableVectorMasking.setValue(options.enableVectorMasking);
this->useConfiguredVectorSizes.setValue(options.useConfiguredVectorSizes);
this->vectorizePadding.setValue(options.vectorizePadding);
this->vectorizeGatherAccesses.setValue(options.vectorizeGatherAccesses);
this->enableCleanup.setValue(options.enableCleanup);
Expand Down Expand Up @@ -162,7 +163,7 @@ void GenericVectorizationPass::runOnOperation() {
// Do not vectorize the op if the vector size is greater than or equal
// to limit.
if (enableVectorMasking) {
auto vectorSizesAndScalableDims = getVectorSizes(linalgOp);
auto vectorSizesAndScalableDims = getVectorSizes(linalgOp, useConfiguredVectorSizes);
if (succeeded(vectorSizesAndScalableDims)) {
auto [sizes, scalableDims] = *vectorSizesAndScalableDims;
vectorSizes.append(sizes.begin(), sizes.end());
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ createFuseTensorPadWithConsumerPass();

struct GenericVectorizationPassOptions {
bool enableVectorMasking = false;
// Controls whether the op lowering configuration (if present) should be used
// to specify the masked vector sizes.
bool useConfiguredVectorSizes = true;
bool vectorizePadding = false;
bool vectorizeGatherAccesses = false;
// The flag controls whether it touches the structure generated from tiling,
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,8 @@ def GenericVectorization :
let options = [
Option<"enableVectorMasking", "enable-vector-masking", "bool",/*default=*/"false",
"Enable vector masking during vectorization.">,
Option<"useConfiguredVectorSizes", "use-configured-vector-sizes", "bool",/*default=*/"true",
"Control whether the op lowering config represents a set of masked vector sizes">,
Option<"vectorizePadding", "vectorize-padding", "bool", /*default=*/"false",
"Rewrite all tensor.pad ops in the function to vector form.">,
Option<"vectorizeGatherAccesses", "vectorize-gather-accesses", "bool", /*default=*/"false",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ void addSPIRVSubgroupReducePassPipeline(OpPassManager &pm) {
{
GenericVectorizationPassOptions options;
options.enableVectorMasking = true;
options.useConfiguredVectorSizes = false;
options.vectorizePadding = true;
options.vectorizeGatherAccesses = true;
options.enableCleanup = false;
Expand Down

0 comments on commit 8c727b1

Please sign in to comment.