diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
index b1e3083f57fb..416695d08476 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
@@ -8,17 +8,22 @@
 #include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
-#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
@@ -28,8 +33,6 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

-#include <deque>
-
 #define DEBUG_TYPE "iree-flow-collapse-dimensions"

 namespace mlir::iree_compiler::IREE::Flow {
@@ -46,6 +49,10 @@ struct CollapseDimensionsPass
 };
 } // namespace

+//===---------------------------------------------------------------------===//
+// Helper functions
+//===---------------------------------------------------------------------===//
+
 /// Searches the same sequence in all the affine maps and collapses these
 /// dimensions. It only applies these to "parallel" loops without mixing them
 /// with "reduction" types. It is expected that the `genericOp` has projected
@@ -105,26 +112,9 @@ getCollapsibleLoops(linalg::GenericOp genericOp) {
         (rDimsSet.count(prePos) && rDimsSet.count(nextPos));
   };

-  // Find all dims that are used to iterate over operands that aren't produced
-  // outside of the dispatch
-  auto regionOp = cast<DispatchRegionOp>(genericOp->getParentOp());
-  llvm::SmallSet<unsigned, 8> preservedDims;
-  for (OpOperand *operand : genericOp.getDpsInputOperands()) {
-    auto definingOp = operand->get().getDefiningOp();
-    if (!definingOp ||
-        definingOp->getParentOfType<DispatchRegionOp>() != regionOp)
-      continue;
-    for (AffineExpr expr :
-         genericOp.getMatchingIndexingMap(operand).getResults()) {
-      preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
-    }
-  }
-
   ReassociationIndices range;
   AffineExpr preExpr;
   // Find the largest sequence of dimensions that are
-  //  - Not used to index operands with defining ops
-  //  AND
   //  - Either preserved in all maps, or
   //  - are completely absent
   // This sequence can be collapsed. To find the sequence,
@@ -138,8 +128,7 @@ getCollapsibleLoops(linalg::GenericOp genericOp) {
   for (auto nextExpr : genericOp.getIndexingMapsArray().front().getResults()) {
     unsigned position = cast<AffineDimExpr>(nextExpr).getPosition();
     if (!range.empty()) {
-      if (preservedDims.contains(position) ||
-          !hasAllMapsSameSequence(preExpr, nextExpr) ||
+      if (!hasAllMapsSameSequence(preExpr, nextExpr) ||
           !hasSameIteratorType(preExpr, nextExpr)) {
         if (range.size() > 1) {
           contiguousLoops.push_back({range.begin(), range.end()});
@@ -153,17 +142,6 @@ getCollapsibleLoops(linalg::GenericOp genericOp) {
   if (range.size() > 1)
     contiguousLoops.push_back(range);

-  LLVM_DEBUG({
-    llvm::dbgs() << "Collapsing dimensions if possible: ";
-    for (auto indices : contiguousLoops) {
-      llvm::dbgs() << "[";
-      for (auto idx : indices)
-        llvm::dbgs() << idx << ",";
-      llvm::dbgs() << "]\t";
-    }
-    llvm::dbgs() << "\n";
-  });
-
   return contiguousLoops;
 }

@@ -186,30 +164,307 @@ static bool isEligibleForCollapse(linalg::GenericOp genericOp) {
     return false;
   }

-  // TODO(#17948) GPU codegen fails when trying to collapse the
-  // dimensions of an elementwise op in the case of elementwise(contraction).
-  // For now, don't collapse when there is a linalgOp producer.
-  if (llvm::any_of(genericOp.getDpsInputs(), [](Value val) -> bool {
-        return val.getDefiningOp<linalg::LinalgOp>();
-      })) {
-    return false;
-  }
-
   // TODO(guray) Collapsing caused performance regression in a cpu
   // benchmark, so we disable it.
   if (genericOp.hasIndexSemantics())
     return false;

+  // TODO(#17948) GPU codegen fails when we collapse the dimensions of softmax.
+  if (llvm::any_of(genericOp.getDpsInputOperands(),
+                   [&](OpOperand *operand) -> bool {
+                     auto genericOperand =
+                         operand->get().getDefiningOp<linalg::GenericOp>();
+                     if (!genericOperand)
+                       return false;
+
+                     if (genericOperand.getNumReductionLoops() == 0)
+                       return false;
+
+                     return genericOp.getMatchingIndexingMap(operand)
+                         .isProjectedPermutation();
+                   })) {
+    return false;
+  }
+
   return true;
 }

-/// Traverses all the the Ops in DispatchRegionOps and finds linalg.generic Op
-static FailureOr<linalg::GenericOp>
-findRootGenericOp(DispatchRegionOp regionOp) {
-  if (!llvm::hasSingleElement(regionOp.getBody())) {
+// For the `operand` with producers and consumers of type `genericOp`, get
+// the map of producer loop -> consumer loop.
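+//
+// As an illustrative sketch (hypothetical maps, not taken from this code):
+// with a producer result map of affine_map<(d0, d1) -> (d0, d1)> and a
+// consumer operand map of affine_map<(d0, d1, d2) -> (d1, d2)>, inverting
+// the producer result map and composing it with the consumer operand map
+// yields affine_map<(d0, d1, d2) -> (d1, d2)>: producer loop 0 maps to
+// consumer loop 1, and producer loop 1 maps to consumer loop 2.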
+static FailureOr<AffineMap>
+getProducerLoopToConsumerLoopsMap(OpOperand &operand) {
+  linalg::GenericOp consumer =
+      dyn_cast<linalg::GenericOp>(operand.getOwner());
+  if (!consumer) {
+    return failure();
+  }
+  linalg::GenericOp producer =
+      dyn_cast_or_null<linalg::GenericOp>(operand.get().getDefiningOp());
+  if (!producer) {
+    return failure();
+  }
+
+  AffineMap consumerOperandMap = consumer.getMatchingIndexingMap(&operand);
+  if (!consumerOperandMap.isProjectedPermutation()) {
+    return failure();
+  }
+
+  AffineMap producerResultMap =
+      producer.getIndexingMapMatchingResult(cast<OpResult>(operand.get()));
+  if (!producerResultMap.isProjectedPermutation()) {
+    return failure();
+  }
+
+  AffineMap inverseProducerResultMap =
+      inverseAndBroadcastProjectedPermutation(producerResultMap);
+  if (!inverseProducerResultMap) {
+    return failure();
+  }
+
+  AffineMap producerLoopToConsumerLoop =
+      inverseProducerResultMap.compose(consumerOperandMap);
+  return producerLoopToConsumerLoop;
+}
+
+static FailureOr<AffineMap>
+getConsumerLoopToProducerLoopsMap(OpOperand &operand) {
+  linalg::GenericOp consumer =
+      dyn_cast<linalg::GenericOp>(operand.getOwner());
+  if (!consumer) {
+    return failure();
+  }
+  linalg::GenericOp producer =
+      dyn_cast_or_null<linalg::GenericOp>(operand.get().getDefiningOp());
+  if (!producer) {
+    return failure();
+  }
+
+  AffineMap consumerOperandMap = consumer.getMatchingIndexingMap(&operand);
+  if (!consumerOperandMap.isProjectedPermutation()) {
+    return failure();
+  }
+
+  AffineMap producerResultMap =
+      producer.getIndexingMapMatchingResult(cast<OpResult>(operand.get()));
+  if (!producerResultMap.isProjectedPermutation()) {
+    return failure();
+  }
+
+  AffineMap inverseConsumerOperandMap =
+      inverseAndBroadcastProjectedPermutation(consumerOperandMap);
+  if (!inverseConsumerOperandMap) {
+    return failure();
+  }
+
+  AffineMap consumerLoopToProducerLoop =
+      inverseConsumerOperandMap.compose(producerResultMap);
+  return consumerLoopToProducerLoop;
+}
+
+//===---------------------------------------------------------------------===//
+// CollapseInfo
+//===---------------------------------------------------------------------===//
+
+namespace {
+class CollapseInfo {
+public:
+  using CollapsableLoopsSet = llvm::SmallSetVector<int64_t, 8>;
+
+  CollapseInfo() = default;
+  CollapseInfo(ArrayRef<ReassociationIndices> reassociation)
+      : reassociation(reassociation),
+        collapsableLoops(getCollapsedFromReassociation(reassociation)) {}
+
+  // Print the current operation & reassociation indices
+  void print(raw_ostream &os) const;
+
+  // Debug print the current operation & reassociation indices
+  void dump() const;
+
+  // Update `collapsableLoops` by taking the set intersection with
+  // `otherCollapsable` and update the reassociation indices accordingly.
+  void updateCollapseViaIntersect(const CollapsableLoopsSet &otherCollapsable);
+
+  // Update `collapsableLoops` by subtracting `uncollapsable` and update the
+  // reassociation indices accordingly.
+  void updateCollapseViaSubtract(const CollapsableLoopsSet &uncollapsable);
+
+  // Get `collapsableLoops` after applying the transformation provided by
+  // `map`. Note: doesn't modify `collapsableLoops`; the transformation is
+  // applied to a copy.
+  FailureOr<CollapsableLoopsSet>
+  getTransformedCollapsableLoops(AffineMap map) const;
+
+  // Clear internal data
+  void clear() {
+    reassociation.clear();
+    collapsableLoops.clear();
+  }
+
+  const CollapsableLoopsSet &getCollapsibleLoops() const {
+    return collapsableLoops;
+  }
+
+  const SmallVector<ReassociationIndices> &getReassocation() const {
+    return reassociation;
+  }
+
+private:
+  // Get a set of all elements in `reassociation`
+  static CollapsableLoopsSet
+  getCollapsedFromReassociation(ArrayRef<ReassociationIndices> reassociation) {
+    CollapsableLoopsSet collapsed;
+    for (auto &indices : reassociation) {
+      for (int64_t index : indices) {
+        collapsed.insert(index);
+      }
+    }
+    return collapsed;
+  }
+
+  // Update `reassociation` by removing indices that are no longer in
+  // `collapsableLoops` and splitting the reassociation indices accordingly
+  void updateReassociation();
+
+private:
+  // A vector of `ReassociationIndices` representing contiguous dimensions
+  // that can be collapsed together.
+  SmallVector<ReassociationIndices> reassociation;
+
+  // Note: `collapsableLoops` does not directly map to `reassociation`
+  // because parallel and reduction iteration dimensions must be kept
+  // separate.
+  CollapsableLoopsSet collapsableLoops;
+};
+} // namespace
+
+// Removes any indices in `reassociation` that are not in `collapsableLoops`.
+// The reassociation indices are split along the uncollapsable element because
+// the dims aren't contiguous and cannot be collapsed. Single element
+// reassociation indices are cleaned up.
+void CollapseInfo::updateReassociation() {
+  SmallVector<ReassociationIndices> newReassociation;
+  for (auto &indices : reassociation) {
+
+    // Holds dimensions that should be collapsed together
+    ReassociationIndices newIndices;
+    for (int64_t index : indices) {
+      // This index is collapsable and should be kept in the reassociation
+      // indices.
+      if (collapsableLoops.contains(index)) {
+        newIndices.push_back(index);
+        continue;
+      }
+
+      // Because `index` isn't collapsable, the indices in `newIndices` are
+      // no longer adjacent to the upcoming indices. If there is >1 index to
+      // collapse, add it to the new reassociation. Otherwise, discard it
+      // because there is no dimension to collapse with.
+      if (newIndices.size() > 1) {
+        newReassociation.push_back(newIndices);
+      }
+      newIndices.clear();
+    }
+
+    if (newIndices.size() > 1) {
+      newReassociation.push_back(newIndices);
+    }
+  }
+  reassociation = std::move(newReassociation);
+}
+
+// Given an AffineMap `map`, get the transformed `collapsableLoops`. For
+// example, if this `CollapseInfo` represents an elementwise linalg generic
+// operating on a 3d tensor (so its collapsableLoops might be {0, 1, 2}), the
+// map would be used to map the loops to the iteration space of its producer
+// or consumer.
+//
+// Consider a consumer with 6 loops that accesses the result of said operation
+// with
+//   affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5)>
+//
+// Then:
+//   collapsableLoops = {0, 1, 2}
+//   map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5)>
+//
+// Therefore, the collapsable loops with respect to the consumer are {1, 2, 5}.
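+//
+// A sketch of that lookup with the (hypothetical) values above: the loop
+// indices in `collapsableLoops` select results of `map`, so
+//   map.getResult(0) = d1 -> insert 1
+//   map.getResult(1) = d2 -> insert 2
+//   map.getResult(2) = d5 -> insert 5
+// and results that are not AffineDimExprs (broadcast dims) are skipped.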
+FailureOr<CollapseInfo::CollapsableLoopsSet>
+CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const {
+  if (!map) {
     return failure();
   }
+  CollapsableLoopsSet transformedLoops;
+  for (auto index : collapsableLoops) {
+    assert(index < map.getNumResults() && "index has no valid mapping");
+    auto dimExpr = dyn_cast<AffineDimExpr>(map.getResult(index));
+    if (!dimExpr) {
+      continue;
+    }
+
+    transformedLoops.insert(dimExpr.getPosition());
+  }
+  return transformedLoops;
+}
+
+// Update `collapsableLoops` by taking the set intersection with
+// `otherCollapsable` and update the reassociation indices accordingly.
+void CollapseInfo::updateCollapseViaIntersect(
+    const CollapsableLoopsSet &otherCollapsable) {
+
+  CollapsableLoopsSet toRemove;
+  for (auto elem : collapsableLoops) {
+    if (!otherCollapsable.contains(elem)) {
+      toRemove.insert(elem);
+    }
+  }
+  collapsableLoops.set_subtract(toRemove);
+  updateReassociation();
+}
+
+// Update `collapsableLoops` by subtracting `uncollapsable` and update the
+// reassociation indices accordingly.
+void CollapseInfo::updateCollapseViaSubtract(
+    const CollapsableLoopsSet &uncollapsable) {
+
+  collapsableLoops.set_subtract(uncollapsable);
+  updateReassociation();
+}
+
+void CollapseInfo::print(raw_ostream &os) const {
+  os << "[CollapseDimensions] CollapseInfo:\n";
+
+  os << "Reassociation: ";
+  os << "[";
+  for (auto &vec : reassociation) {
+    os << "[";
+    bool first = true;
+    for (auto elem : vec) {
+      if (!first) {
+        os << ", ";
+      }
+      first = false;
+      os << elem;
+    }
+    os << "]";
+  }
+  os << "]";
+  os << "\n";
+
+  os << "Collapsable: ";
+  os << "{";
+  bool first = true;
+  for (auto elem : collapsableLoops) {
+    if (!first) {
+      os << ", ";
+    }
+    first = false;
+    os << elem;
+  }
+  os << "}";
+}
+
+void CollapseInfo::dump() const { print(llvm::dbgs()); }
+
+/// Traverses all the Ops in DispatchRegionOps and finds a linalg.generic Op
+/// which is the sole producer of the flow.return's operand.
+static FailureOr<linalg::GenericOp>
+findRootGenericOp(DispatchRegionOp regionOp) {
   // Check the yielded value is from a single `linalg.generic`.
   auto returnOp =
       cast<Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
@@ -227,49 +482,19 @@ findRootGenericOp(DispatchRegionOp regionOp) {
     }
   }

-  // Check that the output is either a `tensor.empty` or a `linalg.fill` op by
-  // traversing the operations that define the `init` operands of the
-  // `collapsibleOp`.
-  std::deque<Operation *> worklist;
-  llvm::SmallDenseSet<Operation *> visited;
-  auto addDefiningOpToWorklist = [&](Value v) {
-    Operation *definingOp = v.getDefiningOp();
-    if (definingOp &&
-        definingOp->getParentOfType<DispatchRegionOp>() == regionOp &&
-        !visited.count(definingOp)) {
-      worklist.push_back(definingOp);
-      visited.insert(definingOp);
-    }
-  };
-  for (Value initOperand : collapsibleOp.getDpsInits()) {
-    addDefiningOpToWorklist(initOperand);
-  }
-
-  while (!worklist.empty()) {
-    Operation *op = worklist.front();
-    worklist.pop_front();
-    if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
-      addDefiningOpToWorklist(fillOp.getDpsInitOperand(0)->get());
-      continue;
-    }
-    if (isa<tensor::EmptyOp>(op)) {
-      continue;
-    }
-    return failure();
-  }
   return collapsibleOp;
 }

+//===---------------------------------------------------------------------===//
+// Reshape Hoisting
+//===---------------------------------------------------------------------===//
+
 /// Hoist `tensor.collapse_shape` ops at the beginning of the `dispatchOp`
 /// and `tensor.expand_shape` ops at the end of the `dispatchOp`, out of the
 /// dispatch.
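+/// As a rough sketch (hypothetical IR, for illustration only): a region whose
+/// first computation is
+///   %c = tensor.collapse_shape %arg0 [[0, 1]] : tensor<4x8xf32> into tensor<32xf32>
+/// is rewritten so the collapse_shape is applied to the region's operand
+/// outside the dispatch, and a trailing tensor.expand_shape on a yielded
+/// value is similarly recreated on the region's result after the dispatch.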
 static FailureOr<DispatchRegionOp>
 hoistTensorReshapesOutOfDispatchRegion(RewriterBase &rewriter,
                                        DispatchRegionOp dispatchOp) {
-  // Only do this for `dispatchOp` with a single operation.
-  if (!llvm::hasSingleElement(dispatchOp.getBody())) {
-    return failure();
-  }
   Block &body = dispatchOp.getBody().front();
   auto returnOp = cast<Flow::ReturnOp>(body.getTerminator());

@@ -416,36 +641,205 @@ hoistTensorReshapesOutOfDispatchRegion(RewriterBase &rewriter,
   return newDispatchOp;
 }

-/// Traverses DispatchRegionOps to find linalg genericOps and collapses
-/// dimensions without modifying operands with producers
-static bool collapseDimensions(IRRewriter &rewriter,
-                               DispatchRegionOp &regionOp) {
-  // Step 1. Find the root linalg.generic Op
-  std::optional<linalg::GenericOp> genericOp = findRootGenericOp(regionOp);
-  if (!genericOp.has_value())
-    return false;
+//===---------------------------------------------------------------------===//
+// Collapse shape propagation
+//===---------------------------------------------------------------------===//
+
+// For each consumer, use its producers to constrain which dimensions it will
+// collapse. `slice` is expected to be topologically sorted (getBackwardSlice
+// does this automatically).
+static void updateConsumersFromProducers(
+    ArrayRef<Operation *> slice,
+    llvm::MapVector<linalg::GenericOp, CollapseInfo> &opMap) {
+
+  // Slice is topologically sorted to ensure that `op`'s producers have been
+  // updated before we visit it.
+  for (auto op : slice) {
+    auto genericOp = cast<linalg::GenericOp>(op);
+    CollapseInfo &consumerInfo = opMap[genericOp];
+
+    for (auto operand : genericOp.getDpsInputOperands()) {
+      auto definingOp = operand->get().getDefiningOp();
+      if (!definingOp || isNonNullAndOutsideDispatch(definingOp)) {
+        continue;
+      }
+
+      // Track the dimensions that are not collapsable by this current op.
+      // Initialize this with all loops in the producer. Note: the dims are
+      // relative to the consumer's iteration space, not the producer's. This
+      // cannot be done via union of producer and consumer collapsable loops
+      // because the consumer may have loops that the producer does not.
+      CollapseInfo::CollapsableLoopsSet producerUncollapsable;
+      for (auto expr :
+           genericOp.getMatchingIndexingMap(operand).getResults()) {
+        producerUncollapsable.insert(cast<AffineDimExpr>(expr).getPosition());
+      }
+
+      auto producerOp = dyn_cast<linalg::GenericOp>(definingOp);
+      FailureOr<AffineMap> mapping =
+          getProducerLoopToConsumerLoopsMap(*operand);
+
+      // If the producer is not a generic or there is no mapping, the tensor
+      // is not collapsable. So, all dimensions of the producer are
+      // uncollapsable.
+      if (!producerOp || failed(mapping)) {
+        consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
+        continue;
+      }
+
+      CollapseInfo &producerInfo = opMap[producerOp];
+      FailureOr<CollapseInfo::CollapsableLoopsSet> producerCollapsable =
+          producerInfo.getTransformedCollapsableLoops(mapping.value());
+      if (!failed(producerCollapsable)) {
+        producerUncollapsable.set_subtract(producerCollapsable.value());
+      }
+
+      consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
+    }
+  }
+}

-  // Step 2. Check whether it is possible to collapse
-  if (!isEligibleForCollapse(genericOp.value()))
+// For each producer, use its consumers to constrain which dimensions it will
+// collapse. `slice` is expected to be topologically sorted (getBackwardSlice
+// does this automatically).
+static void updateProducersFromConsumers(
+    ArrayRef<Operation *> slice,
+    llvm::MapVector<linalg::GenericOp, CollapseInfo> &opMap) {
+
+  // Iterate over `slice` in reverse so that we visit each `op`'s consumers
+  // before visiting `op`.
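+  //
+  // For example (a hypothetical chain): with generics A -> B -> C and C as
+  // the root, the topologically sorted slice is [A, B, C], so the reversed
+  // walk visits C, then B (constrained by C), then A (constrained by B),
+  // propagating collapse constraints from the root toward the leaves.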
+  for (auto op : llvm::reverse(slice)) {
+    auto genericConsumer = cast<linalg::GenericOp>(op);
+    const CollapseInfo &consumerInfo = opMap[genericConsumer];
+    for (auto operand : genericConsumer.getDpsInputOperands()) {
+      auto definingOp = operand->get().getDefiningOp();
+      if (!definingOp || isNonNullAndOutsideDispatch(definingOp)) {
+        continue;
+      }
+      auto genericProducer = dyn_cast<linalg::GenericOp>(definingOp);
+      if (!genericProducer) {
+        continue;
+      }
+
+      // Get a mapping from the consumer's iteration space to the producer's.
+      CollapseInfo &producerInfo = opMap[genericProducer];
+      FailureOr<AffineMap> consumerToProducerMap =
+          getConsumerLoopToProducerLoopsMap(*operand);
+      if (failed(consumerToProducerMap)) {
+        producerInfo.clear();
+        continue;
+      }
+
+      // Use the map to get the consumer's collapsable loops in terms of the
+      // producer.
+      auto consumerCollapsable = consumerInfo.getTransformedCollapsableLoops(
+          consumerToProducerMap.value());
+      if (failed(consumerCollapsable)) {
+        producerInfo.clear();
+        continue;
+      }
+      // Only loops collapsable in both the consumer and producer may be
+      // collapsed.
+      producerInfo.updateCollapseViaIntersect(consumerCollapsable.value());
+    }
+  }
+}
+
+// Construct a DAG of `linalg.generic` operations with one root op. Find
+// dimensions that can be collapsed all the way from the root to the leaves,
+// ensuring that all `collapse_shape` ops can be hoisted out of the dispatch.
+static bool collapseDimensionsForDispatch(IRRewriter &rewriter,
+                                          DispatchRegionOp &regionOp) {
+  // Only collapse dispatches with a single block.
+  if (!llvm::hasSingleElement(regionOp.getBody())) {
     return false;
-  SmallVector<ReassociationIndices> collapseIndices;
-  collapseIndices = getCollapsibleLoops(genericOp.value());
-  if (collapseIndices.empty())
+  }
+  // Step 1. Find the root linalg.generic Op
+  std::optional<linalg::GenericOp> rootGenericOp = findRootGenericOp(regionOp);
+  if (!rootGenericOp.has_value())
    return false;

-  // Step 3. Collapse dimensions
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(genericOp.value());
+  // Step 2. Get slice of all linalg.generic ops in the dispatch
+  BackwardSliceOptions sliceOptions;
+  sliceOptions.inclusive = true;
+  sliceOptions.omitBlockArguments = true;
+  sliceOptions.filter = [&](Operation *op) -> bool {
+    auto genericOp = dyn_cast<linalg::GenericOp>(op);
+    auto parentOp = op->getParentOfType<DispatchRegionOp>();
+    return genericOp && isEligibleForCollapse(genericOp) &&
+           parentOp == regionOp;
+  };
+  SetVector<Operation *> slice;
+  getBackwardSlice(rootGenericOp->getOperation(), &slice, sliceOptions);
+
+  // Step 3. Populate each op's info with its maximally collapsable
+  // reassociation indices
+  llvm::MapVector<linalg::GenericOp, CollapseInfo> opMap;
+  for (auto *op : slice) {
+    auto genericOp = cast<linalg::GenericOp>(op);
+    opMap[genericOp] = CollapseInfo(getCollapsibleLoops(genericOp));
+  }

-  FailureOr<linalg::CollapseResult> maybeReplacements =
-      mlir::linalg::collapseOpIterationDims(genericOp.value(), collapseIndices,
-                                            rewriter);
-  if (failed(maybeReplacements))
-    return false;
-  rewriter.replaceOp(genericOp.value(), maybeReplacements->results);
-  return true;
+  LLVM_DEBUG({
+    llvm::dbgs() << "[CollapseDims] : After initializing opMap\n";
+    for (auto &[op, info] : opMap) {
+      info.dump();
+      llvm::dbgs() << "\n";
+      op.dump();
+      llvm::dbgs() << "\n";
+    }
+    llvm::dbgs() << "\n";
+  });
+
+  // Step 4. For each producer, reduce the number of collapsed dimensions
+  // based on the dimensions that its consumers can collapse.
+  updateProducersFromConsumers(slice.getArrayRef(), opMap);
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "[CollapseDims] : After updating producers: \n";
+    for (auto &[op, info] : opMap) {
+      info.dump();
+      llvm::dbgs() << "\n";
+      op.dump();
+      llvm::dbgs() << "\n";
+    }
+    llvm::dbgs() << "\n";
+  });
+
+  // Step 5. For each consumer, update its CollapseInfo to only collapse
+  // dimensions that all of its producers can collapse. This ensures that all
+  // reshapes can be propagated to leaves and be hoisted out of the dispatch.
+  updateConsumersFromProducers(slice.getArrayRef(), opMap);
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "[CollapseDims] : After updating consumers: \n";
+    for (auto &[op, info] : opMap) {
+      info.dump();
+      llvm::dbgs() << "\n";
+      op.dump();
+      llvm::dbgs() << "\n";
+    }
+    llvm::dbgs() << "\n";
+  });
+  bool didCollapse = false;
+
+  // Step 6. Collapse dimensions based on each op's CollapseInfo
+  for (auto &[genericOp, info] : opMap) {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(genericOp);
+    FailureOr<linalg::CollapseResult> maybeReplacements =
+        mlir::linalg::collapseOpIterationDims(genericOp,
+                                              info.getReassocation(),
+                                              rewriter);
+    if (failed(maybeReplacements))
+      continue;
+    didCollapse = true;
+    rewriter.replaceOp(genericOp, maybeReplacements->results);
+  }
+  return didCollapse;
 }

+//===---------------------------------------------------------------------===//
+// Passes
+//===---------------------------------------------------------------------===//
+
 void CollapseDimensionsPass::runOnOperation() {
   mlir::FunctionOpInterface funcOp = getOperation();
   MLIRContext *context = funcOp->getContext();
@@ -453,7 +847,7 @@ void CollapseDimensionsPass::runOnOperation() {

   SmallVector<DispatchRegionOp> modifiedDispatchOps;
   funcOp->walk([&](DispatchRegionOp dispatchOp) {
-    if (collapseDimensions(rewriter, dispatchOp)) {
+    if (collapseDimensionsForDispatch(rewriter, dispatchOp)) {
       modifiedDispatchOps.push_back(dispatchOp);
     }
   });
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_dimensions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_dimensions.mlir
index 211c9a078ece..8f34e796bb89 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_dimensions.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_dimensions.mlir
@@ -27,6 +27,7 @@ util.func public @do_not_collapse_cst_in_place(%arg0: tensor<1x1x2304xf32>) {

 // -----

+
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d1)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
@@ -39,8 +40,8 @@ util.func public @unpack_collapse(%arg0: tensor<2x320x128x128xf32>, %arg1: tenso
     indexing_maps = [#map, #map1, #map2, #map1, #map],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
   }
-  ins(%arg0, %arg1, %unpack, %arg2 : tensor<2x320x128x128xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>)
-  outs(%1 : tensor<2x320x128x128xf16>) {
+  ins(%arg0, %arg1, %unpack, %arg2 : tensor<2x320x128x128xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>)
+  outs(%1 : tensor<2x320x128x128xf16>) {
   ^bb0(%in: f32, %in_0: f32, %in_1: f32, %in_2: f32, %out: f16):
     %3 = arith.addf %in_1, %in_2 : f32
     %4 = arith.addf %in, %in_0 : f32
@@ -55,5 +56,393 @@ util.func public @unpack_collapse(%arg0: tensor<2x320x128x128xf32>, %arg1: tenso
 }

 // CHECK-LABEL: util.func public @unpack_collapse
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x320x128x128xf32>
+// CHECK:       %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]]
+// CHECK:       flow.dispatch.region
+// 
CHECK: %[[GEN:.+]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[COLLAPSED]], {{.*}} : tensor<2x320x16384xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<2x320x16384xf16>) +// CHECK: flow.return %[[GEN]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d1)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> +util.func public @unpack_elementwise_collapse(%arg0: tensor<2x320x128x128xf32>, %arg1: tensor<320xf32>, %arg2: tensor<320xf32>, %arg3: tensor<1x5x2x64xf32>) -> tensor<2x320x128x128xf16> { + %0 = flow.dispatch.region -> (tensor<2x320x128x128xf16>) { + %1 = tensor.empty() : tensor<2x320xf32> + %2 = tensor.empty() : tensor<2x320x128x128xf16> + %empty = tensor.empty() : tensor<2x320x128x128xf32> + %cst = arith.constant 3.14 : f32 + + %elementwise = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + + %unpack = tensor.unpack %arg3 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 64] into %1 : tensor<1x5x2x64xf32> -> tensor<2x320xf32> + + %3 = linalg.generic {indexing_maps = [#map, #map1, #map2, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%elementwise, %arg1, %unpack, %arg2 : tensor<2x320x128x128xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>) outs(%2 : tensor<2x320x128x128xf16>) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %in_2: f32, %out: f16): + %4 = arith.addf %in_1, %in_2 : f32 + %5 = arith.addf %in, %in_0 : f32 + %6 = arith.truncf %4 : f32 to f16 + %7 = arith.truncf %5 : f32 to f16 + %8 = arith.addf %7, %6 : f16 + linalg.yield %8 : f16 + } -> tensor<2x320x128x128xf16> + flow.return %3 : tensor<2x320x128x128xf16> + } + util.return %0 : tensor<2x320x128x128xf16> +} + +// CHECK-LABEL: util.func public @unpack_elementwise_collapse +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x320x128x128xf32> +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] +// CHECK: flow.dispatch.region +// CHECK: %[[ELEMENTWISE:.+]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[COLLAPSED]] : tensor<2x320x16384xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<2x320x16384xf32>) // CHECK: %[[GEN:.+]] = linalg.generic -// CHECK-SAME: tensor<2x320x16384xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32> +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins({{.*}} : tensor<2x320x16384xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<2x320x16384xf16>) +// CHECK: flow.return %[[GEN]] + + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d1)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> +util.func public @prevent_collapse(%arg0: tensor<2x320x128x128xf32>, %arg1: tensor<320xf32>, %arg2: tensor<320xf32>, %arg3: tensor<1x5x2x64xf32>) -> tensor<2x320x128x128xf16> { + %0 = flow.dispatch.region -> (tensor<2x320x128x128xf16>) { + %1 = tensor.empty() : tensor<2x320xf32> + %2 = tensor.empty() : tensor<2x320x128x128xf16> + %empty = tensor.empty() : tensor<2x320x128x128xf32> + %cst = arith.constant 3.14 : f32 + + 
%elementwise = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + + %barrier = util.optimization_barrier %elementwise : tensor<2x320x128x128xf32> + %unpack = tensor.unpack %arg3 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 64] into %1 : tensor<1x5x2x64xf32> -> tensor<2x320xf32> + + %3 = linalg.generic {indexing_maps = [#map, #map1, #map2, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%barrier, %arg1, %unpack, %arg2 : tensor<2x320x128x128xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>) outs(%2 : tensor<2x320x128x128xf16>) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %in_2: f32, %out: f16): + %4 = arith.addf %in_1, %in_2 : f32 + %5 = arith.addf %in, %in_0 : f32 + %6 = arith.truncf %4 : f32 to f16 + %7 = arith.truncf %5 : f32 to f16 + %8 = arith.addf %7, %6 : f16 + linalg.yield %8 : f16 + } -> tensor<2x320x128x128xf16> + flow.return %3 : tensor<2x320x128x128xf16> + } + util.return %0 : tensor<2x320x128x128xf16> +} + +// CHECK-LABEL: util.func public @prevent_collapse +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x320x128x128xf32> +// CHECK: %[[ELEMENTWISE:.+]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]] : tensor<2x320x128x128xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<2x320x128x128xf32>) +// CHECK: %[[GEN:.+]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] +// CHECK-SAME: ins({{.*}} : tensor<2x320x128x128xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<2x320x128x128xf16>) + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +util.func public @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> { + %cst = arith.constant dense_resource<__elided__> : tensor<4096x32xf32> + %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32xf32> + %0 = flow.dispatch.region -> (tensor<1x1x4096xf32>) { + %cst_1 = arith.constant 0.000000e+00 : f32 + %1 = tensor.empty() : tensor<1x1x4096xf32> + %2 = tensor.empty() : tensor<4096x32x128xf32> + %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32> + %4 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%2 : tensor<4096x32x128xf32>) { + ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32): + %6 = arith.extui %in : i8 to i32 + %7 = arith.uitofp %6 : i32 to f32 + %8 = arith.subf %7, %in_3 : f32 + %9 = arith.mulf %8, %in_2 : f32 + linalg.yield %9 : f32 + } -> tensor<4096x32x128xf32> + %5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %4 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%3 : tensor<1x1x4096xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %6 = arith.mulf 
%in, %in_2 : f32 + %7 = arith.addf %6, %out : f32 + linalg.yield %7 : f32 + } -> tensor<1x1x4096xf32> + flow.return %5 : tensor<1x1x4096xf32> + } + util.return %0 : tensor<1x1x4096xf32> +} + +// CHECK-LABEL: util.func public @quantized_matmul +// CHECK-SAME: %[[ARG0:.*]]: tensor<4096x32x128xi8> +// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x32x128xf32> +// CHECK: %[[CST:.*]] = arith.constant dense_resource<__elided__> : tensor<4096x32xf32> +// CHECK: %[[CST_0:.*]] = arith.constant dense_resource<__elided__> : tensor<4096x32xf32> +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG1]] +// CHECK: flow.dispatch.region +// CHECK: %[[VAL0:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[CST]], %[[CST_0]] : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<4096x32x128xf32>) +// CHECK: %[[VAL2:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction"] +// CHECK: ins(%[[COLLAPSED]], %[[VAL0]] : tensor<1x32x128xf32>, tensor<4096x32x128xf32>) +// CHECK: outs(%{{.*}} : tensor<1x4096xf32>) +// CHECK: flow.return + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +util.func public @elementwise_chain(%arg0: tensor<2x320x128x128xf32>) -> tensor<2x320x128x128xf32> { + %0 = flow.dispatch.region -> (tensor<2x320x128x128xf32>) { + %empty = tensor.empty() : tensor<2x320x128x128xf32> + %cst = arith.constant 3.14 : f32 + + %elementwise1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + %elementwise2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%elementwise1 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + %elementwise3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%elementwise2 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + %elementwise4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%elementwise3 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + + flow.return %elementwise4 : tensor<2x320x128x128xf32> + } + util.return %0 : tensor<2x320x128x128xf32> +} + +// CHECK-LABEL: util.func public @elementwise_chain +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape +// CHECK: flow.dispatch.region +// CHECK: %[[VAL0:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel"] +// CHECK-SAME: ins(%[[COLLAPSED]] : tensor<10485760xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<10485760xf32>) +// CHECK: %[[VAL1:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel"] +// CHECK-SAME: ins(%[[VAL0]] : tensor<10485760xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<10485760xf32>) +// CHECK: 
%[[VAL2:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel"] +// CHECK-SAME: ins(%[[VAL1]] : tensor<10485760xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<10485760xf32>) +// CHECK: %[[VAL3:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel"] +// CHECK-SAME: ins(%[[VAL2]] : tensor<10485760xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<10485760xf32>) +// CHECK: flow.return %[[VAL3]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +util.func public @elementwise_dag(%arg0: tensor<2x320x128x128xf32>) -> tensor<2x320x128x128xf32> { + %0 = flow.dispatch.region -> (tensor<2x320x128x128xf32>) { + %empty = tensor.empty() : tensor<2x320x128x128xf32> + %cst = arith.constant 3.14 : f32 + + %elementwise1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + %elementwise2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%elementwise1 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + %elementwise3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%elementwise1 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + %elementwise4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%elementwise3, %elementwise2 : tensor<2x320x128x128xf32>, tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %in_1 : f32, %out : f32): + %22 = arith.mulf %in_1, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + + flow.return %elementwise4 : tensor<2x320x128x128xf32> + } + util.return %0 : tensor<2x320x128x128xf32> +} + +// CHECK-LABEL: util.func public @elementwise_dag +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape +// CHECK: flow.dispatch.region +// CHECK: %[[VAL0:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel"] +// CHECK-SAME: ins(%[[COLLAPSED]] : tensor<10485760xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<10485760xf32>) +// CHECK: %[[VAL1:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel"] +// CHECK-SAME: ins(%[[VAL0]] : tensor<10485760xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<10485760xf32>) +// CHECK: %[[VAL2:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel"] +// CHECK-SAME: ins(%[[VAL0]] : tensor<10485760xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<10485760xf32>) +// CHECK: %[[VAL3:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel"] +// CHECK-SAME: ins(%[[VAL2]], %[[VAL1]] : tensor<10485760xf32>, tensor<10485760xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<10485760xf32>) +// CHECK: flow.return %[[VAL3]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)> +util.func public @elementwise_dag_transpose(%arg0: tensor<2x320x128x128xf32>) -> tensor<2x320x128x128xf32> { + %0 = flow.dispatch.region -> 
(tensor<2x320x128x128xf32>) { + %empty = tensor.empty() : tensor<2x320x128x128xf32> + %cst = arith.constant 3.14 : f32 + + // Check that reducing dims propagates more than 1 op away + %elementwise0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + %elementwise1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%elementwise0: tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + %elementwise2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%elementwise1 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + %elementwise3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%elementwise1 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + %elementwise4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%elementwise3, %elementwise2 : tensor<2x320x128x128xf32>, tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %in_1 : f32, %out : f32): + %22 = arith.mulf %in_1, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + + // Check that reducing dims propagates more than 1 op away + %elementwise5 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%elementwise4 : tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) { + ^bb0(%in : f32, %out : f32): + %22 = arith.mulf %cst, %in : f32 + linalg.yield %22 : f32 + } -> tensor<2x320x128x128xf32> + + flow.return %elementwise5 : tensor<2x320x128x128xf32> + } + util.return %0 : tensor<2x320x128x128xf32> +} + +// CHECK-LABEL: util.func public @elementwise_dag_transpose +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape +// CHECK: flow.dispatch.region +// CHECK: %[[VAL0:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[COLLAPSED]] : tensor<640x128x128xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<640x128x128xf32>) +// CHECK: %[[VAL1:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[VAL0]] : tensor<640x128x128xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<640x128x128xf32>) +// CHECK: %[[VAL2:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[VAL1]] : tensor<640x128x128xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<640x128x128xf32>) +// CHECK: %[[VAL3:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[VAL1]] : tensor<640x128x128xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<640x128x128xf32>) +// CHECK: %[[VAL4:.*]] = 
linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[VAL3]], %[[VAL2]] : tensor<640x128x128xf32>, tensor<640x128x128xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<640x128x128xf32>) +// CHECK: %[[VAL5:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[VAL4]] : tensor<640x128x128xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<640x128x128xf32>) +// CHECK: flow.return %[[VAL5]] + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> + +util.func public @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> { + %cst = arith.constant dense_resource<__elided__> : tensor<4096x32xf32> + %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32xf32> + %0 = flow.dispatch.region -> (tensor<1x1x4096xf32>) { + %cst_1 = arith.constant 0.000000e+00 : f32 + %1 = tensor.empty() : tensor<1x1x4096xf32> + %2 = tensor.empty() : tensor<4096x32x128xf32> + %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32> + %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0: tensor<4096x32x128xi8>) outs(%2 : tensor<4096x32x128xf32>) { + ^bb0(%in: i8, %out: f32): + %6 = arith.extui %in : i8 to i32 + %7 = arith.uitofp %6 : i32 to f32 + linalg.yield %7 : f32 + } -> tensor<4096x32x128xf32> + %5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %4 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%3 : tensor<1x1x4096xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %6 = arith.mulf %in, %in_2 : f32 + %7 = arith.addf %6, %out : f32 + linalg.yield %7 : f32 + } -> tensor<1x1x4096xf32> + flow.return %5 : tensor<1x1x4096xf32> + } + util.return %0 : tensor<1x1x4096xf32> +} + +// CHECK-LABEL: util.func public @quantized_matmul +// CHECK-SAME: %[[ARG0:.*]]: tensor<4096x32x128xi8> +// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x32x128xf32> +// CHECK-DAG: %[[COLLAPSED0:.*]] = tensor.collapse_shape %[[ARG0]] +// CHECK-DAG: %[[COLLAPSED1:.*]] = tensor.collapse_shape %[[ARG1]] +// CHECK: flow.dispatch.region +// CHECK: %[[VAL0:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[COLLAPSED0]] : tensor<4096x4096xi8>) +// CHECK-SAME: outs(%{{.*}} : tensor<4096x4096xf32>) +// CHECK: %[[VAL2:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK: ins(%[[COLLAPSED1]], %[[VAL0]] : tensor<1x4096xf32>, tensor<4096x4096xf32>) +// CHECK: outs(%{{.*}} : tensor<1x4096xf32>) +// CHECK: flow.return diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir index c2cb8b6067f0..376f1d82a064 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir @@ -507,51 +507,6 @@ util.func public @input_broadcast(%arg0: tensor<4x8xf32>, %arg1: tensor<4xf32>) // ----- -// Do nothing if the 
dispatch is not a single elementwise op (with tensor.empty/linalg.fill producers) - -#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> -#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> -#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)> -#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> -util.func public @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> { - %cst = arith.constant dense_resource<__elided__> : tensor<4096x32xf32> - %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32xf32> - %0 = flow.dispatch.region -> (tensor<1x1x4096xf32>) { - %cst_1 = arith.constant 0.000000e+00 : f32 - %1 = tensor.empty() : tensor<1x1x4096xf32> - %2 = tensor.empty() : tensor<4096x32x128xf32> - %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32> - %4 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%2 : tensor<4096x32x128xf32>) { - ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32): - %6 = arith.extui %in : i8 to i32 - %7 = arith.uitofp %6 : i32 to f32 - %8 = arith.subf %7, %in_3 : f32 - %9 = arith.mulf %8, %in_2 : f32 - linalg.yield %9 : f32 - } -> tensor<4096x32x128xf32> - %5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %4 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%3 : tensor<1x1x4096xf32>) { - ^bb0(%in: f32, %in_2: f32, %out: f32): - %6 = arith.mulf %in, %in_2 : f32 - %7 = arith.addf %6, %out : f32 - linalg.yield %7 : f32 - } -> tensor<1x1x4096xf32> - flow.return %5 : tensor<1x1x4096xf32> - } - util.return %0 : tensor<1x1x4096xf32> -} - -// CHECK-LABEL: util.func public @quantized_matmul -// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region -// CHECK: linalg.generic -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK: linalg.generic -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] -// CHECK: flow.return -// CHECK: util.return %[[DISPATCH]] - -// ----- - util.func public @batchnorm_failure_repro(%arg0 : tensor<2x4xf32>, %arg1 : tensor<4xf32>) -> tensor<2x4xf32> { %0 = tensor.empty() : tensor<2x4xf32> %1 = linalg.generic {