diff --git a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp index a6cee120e1e2..26a72d08ab47 100644 --- a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp @@ -314,19 +314,20 @@ class CollapseInfo { void dump() const; // Update CollapseInfo to ensure that all dimensions collapsable in `this` are - // also collapsable in `consumerInfo`. This means: - // 1. Any dimension not collapsable in `consumerInfo` should not be + // also collapsable in `otherInfo`. This means: + // 1. Any dimension not collapsable in `otherInfo` should not be // collapsable in `this` // 2. For any pair of dimensions in `this`, if they are collapsable in - // `consumerInfo`, they must be collapsable into the same dimension in - // `consumerInfo` to be collapsable into the same dimension in `this`. + // `otherInfo`, they must be collapsable into the same dimension in + // `otherInfo` to be collapsable into the same dimension in `this`. // Returns true if the operation modified the number of collapsable loops. - bool updateFromConsumer(OpOperand *operand, const CollapseInfo &consumerInfo); + bool updateFromOther(FailureOr otherToThisMap, + const CollapseInfo &otherInfo); - // Update `collapsableLoops` by subtracting `uncollapsable` and update the - // reassociation indicies accordingly. - // Returns true if the operation modified the number of collapsable loops. - bool updateCollapseViaSubtract(const CollapsableLoopsSet &uncollapsable); + // Update `this` (which is the info for `op`) when either a producer or + // consumer is not collapsible. This is done by considering all the dims + // accessed by other to be uncollapsible. + bool updateFromUncollapsible(Operation *op, OpOperand *operand); // Get `collapsableLoops` after applying the transformation provided by `map`. // Note: doesn't modify `collapsableLoops`, the tranformation is applied to a @@ -460,48 +461,56 @@ CollapseInfo::getTransformedReassociation(AffineMap map) const { return transformedReassociation; } -bool CollapseInfo::updateFromConsumer(OpOperand *operand, - const CollapseInfo &consumerInfo) { - FailureOr consumerToProducerMap = - getConsumerLoopToProducerLoopsMap(*operand); - if (failed(consumerToProducerMap)) { +bool CollapseInfo::updateFromOther(FailureOr otherToThisMap, + const CollapseInfo &otherInfo) { + if (failed(otherToThisMap)) { return this->clear(); } - CollapsableLoopsSet consumerCollapsable = - consumerInfo.getTransformedCollapsableLoops( - consumerToProducerMap.value()); + CollapsableLoopsSet otherCollapsible = + otherInfo.getTransformedCollapsableLoops(otherToThisMap.value()); - SmallVector consumerReassoc = - consumerInfo.getTransformedReassociation(consumerToProducerMap.value()); + SmallVector otherReassoc = + otherInfo.getTransformedReassociation(otherToThisMap.value()); // Get a map from original index to the index it gets collapsed into - llvm::DenseMap consumerCollapseMap; - for (const auto &[idx, indicies] : llvm::enumerate(consumerReassoc)) { + llvm::DenseMap otherCollapseMap; + for (const auto &[idx, indicies] : llvm::enumerate(otherReassoc)) { for (const auto elem : indicies) { - consumerCollapseMap[elem] = idx; + otherCollapseMap[elem] = idx; } } - // Remove all collapsable loops in `producer` that are not collapsable in - // `consumer` (set intersect) - bool didChange = collapsableLoops.remove_if( - [&](long elem) -> bool { return !consumerCollapsable.contains(elem); }); + // Remove all collapsable loops in `this` that both exist and are not + // collapsable in `other` (set intersect) + bool didChange = collapsableLoops.remove_if([&](long elem) -> bool { + // Exists and is collapsable + if (otherCollapsible.contains(elem)) { + return false; + } + + // Does not exist in `other`. + if (!otherToThisMap->isFunctionOfDim(elem)) { + return false; + } + + return true; + }); // Now update the reassociation indicies given the updated `collapsableLoops` - // and `consumerCollapsableMap`. + // and `otherCollapsableMap`. // The idea is to reconstruct the reassociation indicies, and at each index: // (1) If `index` IS NOT in `collapsableLoops`, split `indicies` and don't add // `index` to either. // - // (2) If `index` IS in `collapsableLoops` but `consumerCollapseMap` maps + // (2) If `index` IS in `collapsableLoops` but `otherCollapseMap` maps // `index` to a different collapsed loop then the other indicies, split // `indicies` and insert `index` into the new one. // // For example: - // producer reassociation = [[0, 1], [2, 3]] - // consumer reassociation = [0, 1, 2, 3] - // then, consumer reassociation gets updated to [[0, 1], [2, 3]] because + // `this` reassociation = [[0, 1], [2, 3]] + // `other` reassociation = [0, 1, 2, 3] + // then, `other` reassociation gets updated to [[0, 1], [2, 3]] because // [0, 1] and [2, 3] get collapsed into different loops // // (3) Otherwise, keep the index @@ -525,22 +534,25 @@ bool CollapseInfo::updateFromConsumer(OpOperand *operand, } newIndicies.clear(); collapseIntoIdx = kUninitialized; + } else if (!otherCollapseMap.contains(index)) { + // (2) `index` does not exist in `other`. + newIndicies.push_back(index); } else if (collapseIntoIdx == kUninitialized) { - // (2) First occurance of collapsable loop, set collapseIntoIdx. - collapseIntoIdx = consumerCollapseMap.at(index); + // (3) First occurance of collapsable loop, set collapseIntoIdx. + collapseIntoIdx = otherCollapseMap.at(index); newIndicies.push_back(index); - } else if (consumerCollapseMap.at(index) != collapseIntoIdx) { - // (3) `index` is collapsable but not collapsable into the other loops. + } else if (otherCollapseMap.at(index) != collapseIntoIdx) { + // (4) `index` is collapsable but not collapsable into the other loops. // So, split them and look for other loops to collapse `index` into. didChange = true; if (newIndicies.size() > 1) { newReassociation.push_back(std::move(newIndicies)); } newIndicies.clear(); - collapseIntoIdx = consumerCollapseMap[index]; + collapseIntoIdx = otherCollapseMap[index]; newIndicies.push_back(index); } else { - // (4) `index` is collapsable and can be collapsed into + // (5) `index` is collapsable and can be collapsed into // `collapseIntoIndex`. newIndicies.push_back(index); } @@ -554,10 +566,17 @@ bool CollapseInfo::updateFromConsumer(OpOperand *operand, return didChange; } -// Update `collapsableLoops` by subtracting `uncollapsable` and update the -// reassociation indicies accordingly. -bool CollapseInfo::updateCollapseViaSubtract( - const CollapsableLoopsSet &uncollapsable) { +bool CollapseInfo::updateFromUncollapsible(Operation *op, OpOperand *operand) { + auto fusionOp = cast(op); + AffineMap map = operand->getOwner() == op + ? fusionOp.getMatchingIndexingMap(operand) + : fusionOp.getIndexingMapMatchingResult( + cast(operand->get())); + + CollapseInfo::CollapsableLoopsSet uncollapsable; + for (auto expr : map.getResults()) { + uncollapsable.insert(cast(expr).getPosition()); + } auto initialSize = collapsableLoops.size(); collapsableLoops.set_subtract(uncollapsable); updateReassociation(); @@ -791,35 +810,18 @@ updateConsumersFromProducers(ArrayRef slice, continue; } - // Track the dimensions that are not collapsable by this current op. - // Initialize this with all loops in thel producer. Note: the dims are - // relative to the consumers iteration space, not the producers. This - // cannot be done via union of producer and consumer collapsable loops - // because the consumer may have loops that the producer does not. - CollapseInfo::CollapsableLoopsSet producerUncollapsable; - for (auto expr : - consumerOp.getMatchingIndexingMap(operand).getResults()) { - producerUncollapsable.insert(cast(expr).getPosition()); - } - - FailureOr mapping = - getProducerLoopToConsumerLoopsMap(*operand); - - // If there is no mapping or we can't find the op, the tensor is - // not collapsable. So, all dimensions of the producer are uncollapsable. - if (!opMap.contains(producerOp) || failed(mapping)) { - didChange |= - consumerInfo.updateCollapseViaSubtract(producerUncollapsable); + // If we can't find the op, the tensor is not collapsable. So, consider + // all the dimensions of the producer to be uncollapsable. + if (!opMap.contains(producerOp)) { + didChange |= consumerInfo.updateFromUncollapsible(consumerOp, operand); continue; } const CollapseInfo &producerInfo = opMap.at(producerOp); - CollapseInfo::CollapsableLoopsSet producerCollapsable = - producerInfo.getTransformedCollapsableLoops(mapping.value()); - producerUncollapsable.set_subtract(producerCollapsable); - + FailureOr consumerToProducerMap = + getProducerLoopToConsumerLoopsMap(*operand); didChange |= - consumerInfo.updateCollapseViaSubtract(producerUncollapsable); + consumerInfo.updateFromOther(consumerToProducerMap, producerInfo); } } return didChange; @@ -837,21 +839,31 @@ updateProducersFromConsumers(ArrayRef slice, // Iterate over `slice` in reverse so that we visit each `op` 's consumer // before visiting `op`. for (auto op : llvm::reverse(slice)) { - auto consumerOp = cast(op); - const CollapseInfo &consumerInfo = opMap.at(consumerOp); + auto producerOp = cast(op); + CollapseInfo &producerInfo = opMap.find(producerOp)->second; - for (auto *operand : consumerOp.getDpsInputOperands()) { - auto definingOp = operand->get().getDefiningOp(); - if (!definingOp || !opMap.contains(definingOp)) { + for (auto &operand : producerOp->getUses()) { + auto *consumerOp = operand.getOwner(); + if (consumerOp->hasTrait()) { + continue; + } + + // If we can't find the op, the tensor is not collapsable. So, consider + // all the dimensions of the consumer to be uncollapsable. + if (!opMap.contains(consumerOp)) { + didChange |= producerInfo.updateFromUncollapsible(producerOp, &operand); continue; } // Get a mapping from the consumer's iteration space to the producer's. - CollapseInfo &producerInfo = opMap.find(definingOp)->second; + const CollapseInfo &consumerInfo = opMap.at(consumerOp); // Only loops collapsable in both the consumer and producer may be // collapsed. - didChange |= producerInfo.updateFromConsumer(operand, consumerInfo); + FailureOr consumerToProducerMap = + getConsumerLoopToProducerLoopsMap(operand); + didChange |= + producerInfo.updateFromOther(consumerToProducerMap, consumerInfo); } } return didChange; diff --git a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir index 880f7f99e0c0..5ae2cb71df1b 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir @@ -16,7 +16,7 @@ util.func public @do_not_collapse_cst_in_place(%arg0: tensor<1x1x2304xf32>) { util.return } // CHECK-LABEL: util.func public @do_not_collapse_cst_in_place -// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]]] +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]] // CHECK-DAG: %[[CST:.+]] = arith.constant // CHECK-DAG: %[[COLLAPSED_ARG0:.+]] = tensor.collapse_shape %[[ARG0]] // CHECK-DAG: %[[COLLAPSED_CST:.+]] = tensor.collapse_shape %[[CST]] @@ -656,3 +656,100 @@ util.func public @collapse(%10: tensor<2x32x32x1280xi8>, %11 : tensor<10240x1280 // CHECK: %[[GEN1:.*]] = linalg.generic // CHECK-SAME: iterator_types = ["parallel", "parallel"] // CHECK: flow.return %[[GEN1]] : tensor<2048x10240xf16> + +// ----- + +util.func public @update_from_producer(%arg0: tensor<2x1x256x16x16xi8>, %arg1: tensor<2x1x256xf32>) -> tensor<1x256x16x16xi8> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = flow.dispatch.region -> (tensor<1x256x16x16xi8>) { + %1 = tensor.empty() : tensor<1x256x16x16xi8> + %2 = tensor.empty() : tensor<1x256x16x16xf32> + %3 = tensor.empty() : tensor<2x1x256x16x16xf32> + %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x1x256x16x16xi8>) outs(%3 : tensor<2x1x256x16x16xf32>) { + ^bb0(%in: i8, %out: f32): + %8 = arith.extsi %in : i8 to i32 + %9 = arith.sitofp %8 : i32 to f32 + linalg.yield %9 : f32 + } -> tensor<2x1x256x16x16xf32> + %5 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1x256x16x16xf32>) -> tensor<1x256x16x16xf32> + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%4, %arg1 : tensor<2x1x256x16x16xf32>, tensor<2x1x256xf32>) outs(%5 : tensor<1x256x16x16xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %8 = arith.mulf %in, %in_0 : f32 + %9 = arith.addf %8, %out : f32 + linalg.yield %9 : f32 + } -> tensor<1x256x16x16xf32> + %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6 : tensor<1x256x16x16xf32>) outs(%1 : tensor<1x256x16x16xi8>) { + ^bb0(%in: f32, %out: i8): + %8 = arith.fptosi %in : f32 to i8 + linalg.yield %8 : i8 + } -> tensor<1x256x16x16xi8> + flow.return %7 : tensor<1x256x16x16xi8> + } + util.return %0 : tensor<1x256x16x16xi8> +} + +// CHECK-LABEL: util.func public @update_from_producer +// CHECK: %[[GEN0:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK: %[[GEN1:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK-SAME: ins(%[[GEN0]] +// CHECK: %[[GEN2:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[GEN1]] +// CHECK: flow.return %[[GEN2]] : tensor<256x256xi8> + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +util.func public @uncollapsable_consumer(%arg0: tensor<1x1x2304xf32>) { + %cst = arith.constant dense<0.000000e+00> : tensor<1x1x2304xf32> + %0 = tensor.empty() : tensor<1x1x2304xf32> + %1 = flow.dispatch.region -> (tensor<1x1x2304xf32>) { + %2 = tensor.empty() : tensor<1x1x2304xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst : tensor<1x1x2304xf32>, tensor<1x1x2304xf32>) outs(%2 : tensor<1x1x2304xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %4 = arith.addf %in, %in_0 : f32 + linalg.yield %4 : f32 + } -> tensor<1x1x2304xf32> + %10 = util.optimization_barrier %3 : tensor<1x1x2304xf32> + flow.return %3 : tensor<1x1x2304xf32> + } + util.return +} +// CHECK-LABEL: util.func public @uncollapsable_consumer +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]] +// CHECK-DAG: %[[CST:.+]] = arith.constant +// CHECK: %{{.+}} = flow.dispatch.region +// CHECK: %[[RES:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]], %[[CST]] +// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[RES]] +// CHECK: flow.return %[[RES]] + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> +util.func public @uncollapsable_consumer_partial(%arg0: tensor<10x20x30x2304xf32>) { + %cst = arith.constant dense<0.000000e+00> : tensor<10x20x30x2304xf32> + %0 = tensor.empty() : tensor<30x2304xf32> + %1 = flow.dispatch.region -> (tensor<30x2304xf32>) { + %2 = tensor.empty() : tensor<30x2304xf32> + %3 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %cst : tensor<10x20x30x2304xf32>, tensor<10x20x30x2304xf32>) outs(%2 : tensor<30x2304xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %4 = arith.addf %in, %in_0 : f32 + linalg.yield %4 : f32 + } -> tensor<30x2304xf32> + %10 = util.optimization_barrier %3 : tensor<30x2304xf32> + flow.return %3 : tensor<30x2304xf32> + } + util.return +} +// CHECK-LABEL: util.func public @uncollapsable_consumer_partial +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]] +// CHECK-DAG: %[[CST:.+]] = arith.constant +// CHECK: %{{.+}} = flow.dispatch.region +// CHECK: %[[RES:.+]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[RES]] +// CHECK: flow.return %[[RES]]