Skip to content

Commit

Permalink
[Dispatch] Two fixes for CollapseDimensionsPass (iree-org#19598)
Browse files Browse the repository at this point in the history
iree-org#19113 uncovered some problems with
the logic in this pass.

Fixes two problems:
1. If a consumer cannot be collapsed, producers can only collapse
dimensions not touched by the consumer.
2. When updating which consumer loops can be collapsed, the
reassociation of the producer must be taken into account since it's
possible they are not all contiguous (e.g. a transpose on an input).
This is the same logic as in `updateFromConsumer`.

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
  • Loading branch information
IanWood1 authored Jan 6, 2025
1 parent 763406f commit cdf24b9
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 73 deletions.
156 changes: 84 additions & 72 deletions compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,19 +314,20 @@ class CollapseInfo {
void dump() const;

// Update CollapseInfo to ensure that all dimensions collapsable in `this` are
// also collapsable in `consumerInfo`. This means:
// 1. Any dimension not collapsable in `consumerInfo` should not be
// also collapsable in `otherInfo`. This means:
// 1. Any dimension not collapsable in `otherInfo` should not be
// collapsable in `this`
// 2. For any pair of dimensions in `this`, if they are collapsable in
// `consumerInfo`, they must be collapsable into the same dimension in
// `consumerInfo` to be collapsable into the same dimension in `this`.
// `otherInfo`, they must be collapsable into the same dimension in
// `otherInfo` to be collapsable into the same dimension in `this`.
// Returns true if the operation modified the number of collapsable loops.
bool updateFromConsumer(OpOperand *operand, const CollapseInfo &consumerInfo);
bool updateFromOther(FailureOr<AffineMap> otherToThisMap,
const CollapseInfo &otherInfo);

// Update `collapsableLoops` by subtracting `uncollapsable` and update the
// reassociation indicies accordingly.
// Returns true if the operation modified the number of collapsable loops.
bool updateCollapseViaSubtract(const CollapsableLoopsSet &uncollapsable);
// Update `this` (which is the info for `op`) when either a producer or
// consumer is not collapsible. This is done by considering all the dims
// accessed by other to be uncollapsible.
bool updateFromUncollapsible(Operation *op, OpOperand *operand);

// Get `collapsableLoops` after applying the transformation provided by `map`.
// Note: doesn't modify `collapsableLoops`, the transformation is applied to a
Expand Down Expand Up @@ -460,48 +461,56 @@ CollapseInfo::getTransformedReassociation(AffineMap map) const {
return transformedReassociation;
}

bool CollapseInfo::updateFromConsumer(OpOperand *operand,
const CollapseInfo &consumerInfo) {
FailureOr<AffineMap> consumerToProducerMap =
getConsumerLoopToProducerLoopsMap(*operand);
if (failed(consumerToProducerMap)) {
bool CollapseInfo::updateFromOther(FailureOr<AffineMap> otherToThisMap,
const CollapseInfo &otherInfo) {
if (failed(otherToThisMap)) {
return this->clear();
}

CollapsableLoopsSet consumerCollapsable =
consumerInfo.getTransformedCollapsableLoops(
consumerToProducerMap.value());
CollapsableLoopsSet otherCollapsible =
otherInfo.getTransformedCollapsableLoops(otherToThisMap.value());

SmallVector<ReassociationIndices> consumerReassoc =
consumerInfo.getTransformedReassociation(consumerToProducerMap.value());
SmallVector<ReassociationIndices> otherReassoc =
otherInfo.getTransformedReassociation(otherToThisMap.value());

// Get a map from original index to the index it gets collapsed into
llvm::DenseMap<long, long> consumerCollapseMap;
for (const auto &[idx, indicies] : llvm::enumerate(consumerReassoc)) {
llvm::DenseMap<long, long> otherCollapseMap;
for (const auto &[idx, indicies] : llvm::enumerate(otherReassoc)) {
for (const auto elem : indicies) {
consumerCollapseMap[elem] = idx;
otherCollapseMap[elem] = idx;
}
}

// Remove all collapsable loops in `producer` that are not collapsable in
// `consumer` (set intersect)
bool didChange = collapsableLoops.remove_if(
[&](long elem) -> bool { return !consumerCollapsable.contains(elem); });
// Remove all collapsable loops in `this` that both exist and are not
// collapsable in `other` (set intersect)
bool didChange = collapsableLoops.remove_if([&](long elem) -> bool {
// Exists and is collapsable
if (otherCollapsible.contains(elem)) {
return false;
}

// Does not exist in `other`.
if (!otherToThisMap->isFunctionOfDim(elem)) {
return false;
}

return true;
});

// Now update the reassociation indicies given the updated `collapsableLoops`
// and `consumerCollapsableMap`.
// and `otherCollapsableMap`.
// The idea is to reconstruct the reassociation indicies, and at each index:
// (1) If `index` IS NOT in `collapsableLoops`, split `indicies` and don't add
// `index` to either.
//
// (2) If `index` IS in `collapsableLoops` but `consumerCollapseMap` maps
// (2) If `index` IS in `collapsableLoops` but `otherCollapseMap` maps
// `index` to a different collapsed loop then the other indicies, split
// `indicies` and insert `index` into the new one.
//
// For example:
// producer reassociation = [[0, 1], [2, 3]]
// consumer reassociation = [0, 1, 2, 3]
// then, consumer reassociation gets updated to [[0, 1], [2, 3]] because
// `this` reassociation = [[0, 1, 2, 3]]
// `other` reassociation = [[0, 1], [2, 3]]
// then, `this` reassociation gets updated to [[0, 1], [2, 3]] because
// [0, 1] and [2, 3] get collapsed into different loops
//
// (3) Otherwise, keep the index
Expand All @@ -525,22 +534,25 @@ bool CollapseInfo::updateFromConsumer(OpOperand *operand,
}
newIndicies.clear();
collapseIntoIdx = kUninitialized;
} else if (!otherCollapseMap.contains(index)) {
// (2) `index` does not exist in `other`.
newIndicies.push_back(index);
} else if (collapseIntoIdx == kUninitialized) {
// (2) First occurance of collapsable loop, set collapseIntoIdx.
collapseIntoIdx = consumerCollapseMap.at(index);
// (3) First occurance of collapsable loop, set collapseIntoIdx.
collapseIntoIdx = otherCollapseMap.at(index);
newIndicies.push_back(index);
} else if (consumerCollapseMap.at(index) != collapseIntoIdx) {
// (3) `index` is collapsable but not collapsable into the other loops.
} else if (otherCollapseMap.at(index) != collapseIntoIdx) {
// (4) `index` is collapsable but not collapsable into the other loops.
// So, split them and look for other loops to collapse `index` into.
didChange = true;
if (newIndicies.size() > 1) {
newReassociation.push_back(std::move(newIndicies));
}
newIndicies.clear();
collapseIntoIdx = consumerCollapseMap[index];
collapseIntoIdx = otherCollapseMap[index];
newIndicies.push_back(index);
} else {
// (4) `index` is collapsable and can be collapsed into
// (5) `index` is collapsable and can be collapsed into
// `collapseIntoIndex`.
newIndicies.push_back(index);
}
Expand All @@ -554,10 +566,17 @@ bool CollapseInfo::updateFromConsumer(OpOperand *operand,
return didChange;
}

// Update `collapsableLoops` by subtracting `uncollapsable` and update the
// reassociation indicies accordingly.
bool CollapseInfo::updateCollapseViaSubtract(
const CollapsableLoopsSet &uncollapsable) {
bool CollapseInfo::updateFromUncollapsible(Operation *op, OpOperand *operand) {
auto fusionOp = cast<LinalgFusionOpInterface>(op);
AffineMap map = operand->getOwner() == op
? fusionOp.getMatchingIndexingMap(operand)
: fusionOp.getIndexingMapMatchingResult(
cast<OpResult>(operand->get()));

CollapseInfo::CollapsableLoopsSet uncollapsable;
for (auto expr : map.getResults()) {
uncollapsable.insert(cast<AffineDimExpr>(expr).getPosition());
}
auto initialSize = collapsableLoops.size();
collapsableLoops.set_subtract(uncollapsable);
updateReassociation();
Expand Down Expand Up @@ -791,35 +810,18 @@ updateConsumersFromProducers(ArrayRef<Operation *> slice,
continue;
}

// Track the dimensions that are not collapsable by this current op.
// Initialize this with all loops in thel producer. Note: the dims are
// relative to the consumers iteration space, not the producers. This
// cannot be done via union of producer and consumer collapsable loops
// because the consumer may have loops that the producer does not.
CollapseInfo::CollapsableLoopsSet producerUncollapsable;
for (auto expr :
consumerOp.getMatchingIndexingMap(operand).getResults()) {
producerUncollapsable.insert(cast<AffineDimExpr>(expr).getPosition());
}

FailureOr<AffineMap> mapping =
getProducerLoopToConsumerLoopsMap(*operand);

// If there is no mapping or we can't find the op, the tensor is
// not collapsable. So, all dimensions of the producer are uncollapsable.
if (!opMap.contains(producerOp) || failed(mapping)) {
didChange |=
consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
// If we can't find the op, the tensor is not collapsable. So, consider
// all the dimensions of the producer to be uncollapsable.
if (!opMap.contains(producerOp)) {
didChange |= consumerInfo.updateFromUncollapsible(consumerOp, operand);
continue;
}

const CollapseInfo &producerInfo = opMap.at(producerOp);
CollapseInfo::CollapsableLoopsSet producerCollapsable =
producerInfo.getTransformedCollapsableLoops(mapping.value());
producerUncollapsable.set_subtract(producerCollapsable);

FailureOr<AffineMap> consumerToProducerMap =
getProducerLoopToConsumerLoopsMap(*operand);
didChange |=
consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
consumerInfo.updateFromOther(consumerToProducerMap, producerInfo);
}
}
return didChange;
Expand All @@ -837,21 +839,31 @@ updateProducersFromConsumers(ArrayRef<Operation *> slice,
// Iterate over `slice` in reverse so that we visit each `op` 's consumer
// before visiting `op`.
for (auto op : llvm::reverse(slice)) {
auto consumerOp = cast<DestinationStyleOpInterface>(op);
const CollapseInfo &consumerInfo = opMap.at(consumerOp);
auto producerOp = cast<LinalgFusionOpInterface>(op);
CollapseInfo &producerInfo = opMap.find(producerOp)->second;

for (auto *operand : consumerOp.getDpsInputOperands()) {
auto definingOp = operand->get().getDefiningOp();
if (!definingOp || !opMap.contains(definingOp)) {
for (auto &operand : producerOp->getUses()) {
auto *consumerOp = operand.getOwner();
if (consumerOp->hasTrait<OpTrait::IsTerminator>()) {
continue;
}

// If we can't find the op, the tensor is not collapsable. So, consider
// all the dimensions of the consumer to be uncollapsable.
if (!opMap.contains(consumerOp)) {
didChange |= producerInfo.updateFromUncollapsible(producerOp, &operand);
continue;
}

// Get a mapping from the consumer's iteration space to the producer's.
CollapseInfo &producerInfo = opMap.find(definingOp)->second;
const CollapseInfo &consumerInfo = opMap.at(consumerOp);

// Only loops collapsable in both the consumer and producer may be
// collapsed.
didChange |= producerInfo.updateFromConsumer(operand, consumerInfo);
FailureOr<AffineMap> consumerToProducerMap =
getConsumerLoopToProducerLoopsMap(operand);
didChange |=
producerInfo.updateFromOther(consumerToProducerMap, consumerInfo);
}
}
return didChange;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ util.func public @do_not_collapse_cst_in_place(%arg0: tensor<1x1x2304xf32>) {
util.return
}
// CHECK-LABEL: util.func public @do_not_collapse_cst_in_place
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]]]
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK-DAG: %[[CST:.+]] = arith.constant
// CHECK-DAG: %[[COLLAPSED_ARG0:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-DAG: %[[COLLAPSED_CST:.+]] = tensor.collapse_shape %[[CST]]
Expand Down Expand Up @@ -656,3 +656,100 @@ util.func public @collapse(%10: tensor<2x32x32x1280xi8>, %11 : tensor<10240x1280
// CHECK: %[[GEN1:.*]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK: flow.return %[[GEN1]] : tensor<2048x10240xf16>

// -----

// Test input: a dispatch region holding three chained linalg.generic ops.
//   %4: 5-D all-parallel elementwise op with identity maps (i8 -> f32).
//   %6: reads %4 through the permuted map (d4, d0, d1, d2, d3) and %arg1
//       through (d4, d0, d1); d4 is its only reduction iterator.
//   %7: 4-D all-parallel elementwise op with identity maps (f32 -> i8).
// NOTE(review): going by the test name, this presumably exercises updating
// the consumer's collapsable loops from a producer that is read through a
// non-identity (transposed) map -- confirm against the pass.
util.func public @update_from_producer(%arg0: tensor<2x1x256x16x16xi8>, %arg1: tensor<2x1x256xf32>) -> tensor<1x256x16x16xi8> {
%cst = arith.constant 0.000000e+00 : f32
%0 = flow.dispatch.region -> (tensor<1x256x16x16xi8>) {
%1 = tensor.empty() : tensor<1x256x16x16xi8>
%2 = tensor.empty() : tensor<1x256x16x16xf32>
%3 = tensor.empty() : tensor<2x1x256x16x16xf32>
// Producer: sign-extend each i8 element to i32, then convert it to f32.
%4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x1x256x16x16xi8>) outs(%3 : tensor<2x1x256x16x16xf32>) {
^bb0(%in: i8, %out: f32):
%8 = arith.extsi %in : i8 to i32
%9 = arith.sitofp %8 : i32 to f32
linalg.yield %9 : f32
} -> tensor<2x1x256x16x16xf32>
// Zero-filled accumulator used as the init of the reduction below.
%5 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1x256x16x16xf32>) -> tensor<1x256x16x16xf32>
// Multiply-accumulate over d4; note that %4's leading tensor dimension is
// indexed by the last loop (d4) rather than the first.
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%4, %arg1 : tensor<2x1x256x16x16xf32>, tensor<2x1x256xf32>) outs(%5 : tensor<1x256x16x16xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%8 = arith.mulf %in, %in_0 : f32
%9 = arith.addf %8, %out : f32
linalg.yield %9 : f32
} -> tensor<1x256x16x16xf32>
// Final consumer: convert the f32 reduction result to i8.
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6 : tensor<1x256x16x16xf32>) outs(%1 : tensor<1x256x16x16xi8>) {
^bb0(%in: f32, %out: i8):
%8 = arith.fptosi %in : f32 to i8
linalg.yield %8 : i8
} -> tensor<1x256x16x16xi8>
flow.return %7 : tensor<1x256x16x16xi8>
}
util.return %0 : tensor<1x256x16x16xi8>
}

// CHECK-LABEL: util.func public @update_from_producer
// CHECK: %[[GEN0:.*]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK: %[[GEN1:.*]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[GEN0]]
// CHECK: %[[GEN2:.*]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[GEN1]]
// CHECK: flow.return %[[GEN2]] : tensor<256x256xi8>

// -----

// Test input: one 3-D all-parallel elementwise add whose result %3 is both
// returned from the dispatch region and used by a util.optimization_barrier
// inside the region.
// NOTE(review): going by the test name, the barrier is presumably the
// consumer that the pass cannot collapse -- confirm against the pass.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
util.func public @uncollapsable_consumer(%arg0: tensor<1x1x2304xf32>) {
%cst = arith.constant dense<0.000000e+00> : tensor<1x1x2304xf32>
%0 = tensor.empty() : tensor<1x1x2304xf32>
%1 = flow.dispatch.region -> (tensor<1x1x2304xf32>) {
%2 = tensor.empty() : tensor<1x1x2304xf32>
// Elementwise %arg0 + %cst; all three operands use the identity map.
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst : tensor<1x1x2304xf32>, tensor<1x1x2304xf32>) outs(%2 : tensor<1x1x2304xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%4 = arith.addf %in, %in_0 : f32
linalg.yield %4 : f32
} -> tensor<1x1x2304xf32>
// Second use of %3 inside the region; the barrier result %10 is unused.
%10 = util.optimization_barrier %3 : tensor<1x1x2304xf32>
flow.return %3 : tensor<1x1x2304xf32>
}
util.return
}
// CHECK-LABEL: util.func public @uncollapsable_consumer
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK-DAG: %[[CST:.+]] = arith.constant
// CHECK: %{{.+}} = flow.dispatch.region
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[CST]]
// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[RES]]
// CHECK: flow.return %[[RES]]

// -----

// Test input: a 4-D linalg.generic with two parallel loops (d0, d1) and two
// reduction loops (d2, d3). Both inputs are read through the permuted map
// (d2, d3, d0, d1); the result is written through (d0, d1) and is also used
// by a util.optimization_barrier inside the region.
// NOTE(review): presumably only the loops not touched by the barrier's
// operand (the reduction loops d2, d3) remain collapsable -- confirm against
// the pass.
#map0 = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
util.func public @uncollapsable_consumer_partial(%arg0: tensor<10x20x30x2304xf32>) {
%cst = arith.constant dense<0.000000e+00> : tensor<10x20x30x2304xf32>
%0 = tensor.empty() : tensor<30x2304xf32>
%1 = flow.dispatch.region -> (tensor<30x2304xf32>) {
%2 = tensor.empty() : tensor<30x2304xf32>
// Yields %in + %in_0 for each point; d2 and d3 are the reduction iterators.
%3 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %cst : tensor<10x20x30x2304xf32>, tensor<10x20x30x2304xf32>) outs(%2 : tensor<30x2304xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%4 = arith.addf %in, %in_0 : f32
linalg.yield %4 : f32
} -> tensor<30x2304xf32>
// Second use of %3 inside the region; the barrier result %10 is unused.
%10 = util.optimization_barrier %3 : tensor<30x2304xf32>
flow.return %3 : tensor<30x2304xf32>
}
util.return
}
// CHECK-LABEL: util.func public @uncollapsable_consumer_partial
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK-DAG: %[[CST:.+]] = arith.constant
// CHECK: %{{.+}} = flow.dispatch.region
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[RES]]
// CHECK: flow.return %[[RES]]

0 comments on commit cdf24b9

Please sign in to comment.