diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp index 01025907d1a8..a0f776444241 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp @@ -86,11 +86,11 @@ inferVectorSizesFromIR(linalg::LinalgOp linalgOp) { // Return the vector sizes from the local lowering config or try to infer them // from the tensor shapes and tiled loops in the IR. static FailureOr -getVectorSizes(linalg::LinalgOp linalgOp) { +getVectorSizes(linalg::LinalgOp linalgOp, bool useConfiguredVectorSizes) { // Get vector sizes from the lowering config, if available in the op itself. IREE::Codegen::LoweringConfigAttr loweringConfig = getLoweringConfig(linalgOp); - if (loweringConfig) { + if (useConfiguredVectorSizes && loweringConfig) { TilingConfig tilingConfig(loweringConfig); auto [vectorSizes, scalableFlags] = tilingConfig.getVectorTileSizes(); // Replace zeros in canonical vector shape to turn it into a valid shape. @@ -128,6 +128,7 @@ class GenericVectorizationPass using GenericVectorizationBase::GenericVectorizationBase; GenericVectorizationPass(const GenericVectorizationPassOptions &options) { this->enableVectorMasking.setValue(options.enableVectorMasking); + this->useConfiguredVectorSizes.setValue(options.useConfiguredVectorSizes); this->vectorizePadding.setValue(options.vectorizePadding); this->vectorizeGatherAccesses.setValue(options.vectorizeGatherAccesses); this->enableCleanup.setValue(options.enableCleanup); @@ -162,7 +163,7 @@ void GenericVectorizationPass::runOnOperation() { // Do not vectorize the op if the vector size is greater than or equal // to limit. if (enableVectorMasking) { - auto vectorSizesAndScalableDims = getVectorSizes(linalgOp); + auto vectorSizesAndScalableDims = getVectorSizes(linalgOp, useConfiguredVectorSizes); if (succeeded(vectorSizesAndScalableDims)) { auto [sizes, scalableDims] = *vectorSizesAndScalableDims; vectorSizes.append(sizes.begin(), sizes.end()); diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h index 0db187d9b3bc..73d0ad68be68 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.h +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h @@ -150,6 +150,9 @@ createFuseTensorPadWithConsumerPass(); struct GenericVectorizationPassOptions { bool enableVectorMasking = false; + // Controls whether the op lowering configuration (if present) should be used + // to specify the masked vector sizes. + bool useConfiguredVectorSizes = true; bool vectorizePadding = false; bool vectorizeGatherAccesses = false; // The flag controls whether it touches the structure generated from tiling, diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 4ca69077a602..5327b42ccc63 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -260,6 +260,8 @@ def GenericVectorization : let options = [ Option<"enableVectorMasking", "enable-vector-masking", "bool",/*default=*/"false", "Enable vector masking during vectorization.">, + Option<"useConfiguredVectorSizes", "use-configured-vector-sizes", "bool",/*default=*/"true", + "Control whether the op lowering config represents a set of masked vector sizes">, Option<"vectorizePadding", "vectorize-padding", "bool", /*default=*/"false", "Rewrite all tensor.pad ops in the function to vector form.">, Option<"vectorizeGatherAccesses", "vectorize-gather-accesses", "bool", /*default=*/"false", diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp index 426dc43f189b..fa2893a846a6 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp @@ -546,6 +546,7 @@ void addSPIRVSubgroupReducePassPipeline(OpPassManager &pm) { { GenericVectorizationPassOptions options; options.enableVectorMasking = true; + options.useConfiguredVectorSizes = false; options.vectorizePadding = true; options.vectorizeGatherAccesses = true; options.enableCleanup = false;