Skip to content

Commit

Permalink
Use llvm::filter_to_vector. NFC. (iree-org#19297)
Browse files Browse the repository at this point in the history
I recently added this to SmallVectorExtras:
llvm/llvm-project#117460.
  • Loading branch information
kuhar authored Nov 26, 2024
1 parent ef4ecf3 commit 4e3e898
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,17 @@ using mlir::transform::NamedSequenceOp;
static SmallVector<ModuleOp>
findNestedModulesWithNamedSequences(ModuleOp module) {
Block *body = module.getBody();
return llvm::to_vector(
llvm::make_filter_range(body->getOps<ModuleOp>(), [](ModuleOp op) {
return op.getSymName().has_value() &&
op->hasAttr(
transform::TransformDialect::kWithNamedSequenceAttrName);
}));
return llvm::filter_to_vector(body->getOps<ModuleOp>(), [](ModuleOp op) {
return op.getSymName().has_value() &&
op->hasAttr(transform::TransformDialect::kWithNamedSequenceAttrName);
});
}

static SmallVector<NamedSequenceOp> findTuningSpecs(ModuleOp module) {
Block *body = module.getBody();
return llvm::to_vector(llvm::make_filter_range(
return llvm::filter_to_vector(
body->getOps<NamedSequenceOp>(),
[](NamedSequenceOp op) { return op->hasAttr(kTuningSpecAttrName); }));
[](NamedSequenceOp op) { return op->hasAttr(kTuningSpecAttrName); });
}

static LogicalResult validateTuningSpec(NamedSequenceOp op) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
}

auto forAllOps = body.getOps<scf::ForallOp>();
SmallVector<scf::ForallOp> workgroupForAllOps = llvm::to_vector(
llvm::make_filter_range(forAllOps, [&](scf::ForallOp forAllOp) {
SmallVector<scf::ForallOp> workgroupForAllOps =
llvm::filter_to_vector(forAllOps, [&](scf::ForallOp forAllOp) {
auto mapping = forAllOp.getMapping();
if (!mapping) {
return false;
Expand All @@ -277,7 +277,7 @@ static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
return false;
}
return true;
}));
});

if (workgroupForAllOps.empty()) {
// If there are no workgroup distribution loops, set the default
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
Expand All @@ -26,10 +27,9 @@ namespace mlir::iree_compiler {
static llvm::SmallVector<unsigned>
pruneUnitTripParallelLoops(llvm::ArrayRef<unsigned> parallelLoops,
llvm::ArrayRef<int64_t> loopRanges) {
return llvm::to_vector(
llvm::make_filter_range(parallelLoops, [&loopRanges](unsigned loopDim) {
return loopRanges[loopDim] != 1;
}));
return llvm::filter_to_vector(parallelLoops, [&loopRanges](unsigned loopDim) {
return loopRanges[loopDim] != 1;
});
}

/// Returns the partitionable loops for all Linalg ops.
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -2784,10 +2785,10 @@ adjustTileSizesForUnPackOp(mlir::FunctionOpInterface entryPointFn,
// Remove the "enable_loop_peeling" attr from pipelineConfig
auto enableLoopPeelingAttrName =
getEnableLoopPeelingAttrName(rootOp->getContext());
auto newPipelineConfigEntries = llvm::to_vector(llvm::make_filter_range(
auto newPipelineConfigEntries = llvm::filter_to_vector(
pipelineConfig.getValue(), [&](NamedAttribute entry) {
return entry.getName() != enableLoopPeelingAttrName;
}));
});

pipelineConfig =
DictionaryAttr::get(rootOp->getContext(), newPipelineConfigEntries);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
Expand Down Expand Up @@ -322,10 +323,10 @@ struct SPIRVMaterializeExecutableConditionsPass final

// Drop the fine-grained SPIR-V target and add the course-grained device
// queries as a list.
auto dictKeyValues = llvm::to_vector(llvm::make_filter_range(
auto dictKeyValues = llvm::filter_to_vector(
configuration.getValue(), [](NamedAttribute attr) {
return attr.getName() != spirv::getTargetEnvAttrName();
}));
});
dictKeyValues.emplace_back(builder.getStringAttr("iree.spirv.features"),
builder.getStrArrayAttr(queries));
variantOp.setTargetAttr(IREE::HAL::ExecutableTargetAttr::get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ static void computeRegionValueAliases(Operation *regionOp,

// Filter out to only resource results - some regions may return additional
// things like stream.async.execute returning a timepoint.
auto resourceResults = llvm::to_vector_of<OpResult>(
llvm::make_filter_range(regionOp->getResults(), [](OpResult result) {
auto resourceResults =
llvm::filter_to_vector(regionOp->getResults(), [](OpResult result) {
return llvm::isa<IREE::Stream::ResourceType>(result.getType());
}));
});

// Start with outputs so that we handle tied values that may lead all the way
// back up the chain to the stream inputs.
Expand Down Expand Up @@ -1145,12 +1145,12 @@ static std::optional<ConstantAllocation>
extractConstantsWithLifetime(IREE::Stream::AsyncExecuteOp executeOp,
IREE::Stream::Lifetime lifetime,
OpBuilder &externalBuilder) {
auto constantOps = llvm::to_vector(llvm::make_filter_range(
auto constantOps = llvm::filter_to_vector(
executeOp.getOps<IREE::Stream::AsyncConstantOp>(),
[&](IREE::Stream::AsyncConstantOp op) {
return cast<IREE::Stream::ResourceType>(op.getResult().getType())
.getLifetime() == lifetime;
}));
});
if (constantOps.empty())
return {};

Expand Down

0 comments on commit 4e3e898

Please sign in to comment.