From 94ecb8bde107ff49ba141cec0fb7468d8c450c76 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com>
Date: Tue, 9 Jul 2024 12:54:42 -0700
Subject: [PATCH] [NFC] Modify method for characterizing bit-extension
 operations to handle characterization of bit-truncation as well. (#17833)

Currently the method for characterizing bit-extension (a.k.a. dequantization)
ops only returns a `bool`. Make this method return a richer handle that also
allows classification of bit-truncation operations. Also rename the
`isDequantizationLikeOp` method to `isBitExtendOp`.

Signed-off-by: MaheshRavishankar
---
 .../compiler/Codegen/LLVMGPU/KernelConfig.cpp |  4 +-
 .../Flow/Transforms/BubbleUpExpandShapes.cpp  |  2 +-
 .../Flow/Transforms/FormDispatchRegions.cpp   |  6 +-
 .../FuseMultiUseElementwiseProducer.cpp       |  4 +-
 .../Flow/Transforms/FusionPreprocessing.cpp   |  2 +-
 .../Dialect/Flow/Transforms/FusionUtils.cpp   | 40 ++++-----
 .../Dialect/Flow/Transforms/RegionOpUtils.cpp | 89 +++++++++++++------
 .../Dialect/Flow/Transforms/RegionOpUtils.h   | 42 +++++++--
 .../Dialect/Flow/Transforms/SinkReshapes.cpp  |  2 +-
 9 files changed, 128 insertions(+), 63 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index ed9175eb3a3a..5879ac0de2e4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -447,11 +447,11 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
   Type initElemType = getElementTypeOrSelf(init);
 
   if (auto lhsOp = lhs.getDefiningOp<linalg::GenericOp>()) {
-    if (IREE::Flow::isDequantizationLikeOp(lhsOp))
+    if (IREE::Flow::isBitExtendOp(lhsOp))
       lhsElemType = getElementTypeOrSelf(lhsOp.getDpsInputs()[0]);
   }
   if (auto rhsOp = rhs.getDefiningOp<linalg::GenericOp>()) {
-    if (IREE::Flow::isDequantizationLikeOp(rhsOp))
+    if (IREE::Flow::isBitExtendOp(rhsOp))
       rhsElemType = getElementTypeOrSelf(rhsOp.getDpsInputs()[0]);
   }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp
index e49b9ec6c291..bc108c25ac6f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp
@@ -53,7 +53,7 @@ void BubbleUpExpandShapesPass::runOnOperation() {
         }
 
         // Do not fuse by expand if consumer is dequant.
-        if (isDequantizationLikeOp(consumer)) {
+        if (isBitExtendOp(consumer)) {
           return false;
         }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index 2a9d362e787b..77af6c98276b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -226,7 +226,7 @@ static bool isRootOp(Operation *op) {
     return false;
   }
   // Dequantization-like ops get cloned into dispatches later.
-  if (isDequantizationLikeOp(op)) {
+  if (isBitExtendOp(op)) {
     return false;
   }
   // Any Linalg named op or generic op with reduction iterator types is a root
@@ -539,7 +539,7 @@ isFusableWithConsumer(OpOperand &fusedOperand,
   // If consumer is a dequant operation, dont fuse it. These get cloned
   // into their consumers.
-  if (isDequantizationLikeOp(consumer)) {
+  if (isBitExtendOp(consumer)) {
     return false;
   }
 
@@ -874,7 +874,7 @@ decideFusableLinalgOps(Region &region, DominanceInfo const &dominanceInfo,
       // materializing large tensors between dispatches.
       if (!isa(op) ||
-          isa(op) || isDequantizationLikeOp(&op)) {
+          isa(op) || isBitExtendOp(&op)) {
         continue;
       }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp
index 7c36845ac308..c2ea01edce32 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp
@@ -156,7 +156,7 @@ static FailureOr fuseMultiUseProducers(Operation *funcOp,
 
     // Dequantization-like operations should be fused with consumers to keep
     // the smaller bit width on the dispatch boundary.
-    if (isDequantizationLikeOp(genericOp)) {
+    if (isBitExtendOp(genericOp)) {
       return;
     }
 
@@ -196,7 +196,7 @@ static FailureOr fuseMultiUseProducers(Operation *funcOp,
 
     // 7. Skip dequantization-like `producer` ops as we would rather fuse
     //    by cloning the producer instead of multi-use fusion.
-    if (isDequantizationLikeOp(producer)) {
+    if (isBitExtendOp(producer)) {
       return;
     }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
index 68f6d7f7cb2f..4aa3fc12f432 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
@@ -195,7 +195,7 @@ struct GatherFusionPattern : public OpRewritePattern {
     // Check if the producerOp is fusible
     if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 ||
-        !isElementwise(producerOp) || !isDequantizationLikeOp(producerOp)) {
+        !isElementwise(producerOp) || !isBitExtendOp(producerOp)) {
       return rewriter.notifyMatchFailure(producerOp,
                                          "producer op is not fusible");
     }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
index 8099518ea348..688f536e6a70 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
@@ -53,21 +53,28 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
   }
 
   // If producer does not have a single user, dont fuse.
-  if (!producerOp->hasOneUse())
+  if (!producerOp->hasOneUse()) {
     return false;
+  }
 
-  // Do no fuse dequantization-like operations with producers. The
-  // dequantization ops are cloned into all their use dispatches. So fusing
-  // producer with consumer here would then result in producer also getting
-  // cloned into many dispatches which is against the thumb rule of fusion to
-  // not introduce additional computation (except for dequant ops). If the
-  // consumer has only one use, then this fusion is fine since cloning wont
-  // result in redundant computation of the producer. (Also note that the
-  // producer is always an elementwise operation).
-  if (isDequantizationLikeOp(consumerOp) && !consumerOp->hasOneUse()) {
+  std::optional<BitWidthChangeInfo> consumerBitwidthChangeInfo =
+      isBitExtendOrTruncateOp(consumerOp);
+  // Do not fuse bit-extend-like operations with producers. Such ops are cloned
+  // into all their use dispatches. So fusing producer with consumer here would
+  // then result in producer also getting cloned into many dispatches which is
+  // against the rule of thumb of fusion to not introduce additional computation
+  // (except for bit-extend ops). If the consumer has only one use, then this
+  // fusion is fine since cloning won't result in redundant computation of the
+  // producer. (Also note that the producer is always an elementwise operation).
+  if (consumerBitwidthChangeInfo &&
+      consumerBitwidthChangeInfo->isExtensionOp() && !consumerOp->hasOneUse()) {
     return false;
   }
+  auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp);
+  if (!linalgConsumerOp) {
+    return false;
+  }
 
   // If the producer has a single use (this op), only fuse if
   // - 1) The consumer op is all parallel loops. The parallelism of the consumer
   //      can be used as a way to amortize cost of redundant computation
@@ -75,12 +82,8 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
   //      consumer for the producer result is a permutation. If it is a
   //      broadcast this ends up redundantly computing operations without more
   //      parallelism.
-  if (auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp)) {
-
-    if (linalgConsumerOp.getNumParallelLoops() ==
-        linalgConsumerOp.getNumLoops()) {
-      return true;
-    }
+  if (linalgConsumerOp.getNumParallelLoops() !=
+      linalgConsumerOp.getNumLoops()) {
     if (!linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
              .isPermutation()) {
       return false;
@@ -92,11 +95,8 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
         linalg::isaConvolutionOpInterface(linalgConsumerOp)) {
       return false;
     }
-    return true;
   }
-
-  // All other cases dont fuse.
-  return false;
+  return true;
 }
 
 } // namespace mlir::iree_compiler::IREE::Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 4c9eda8ddeb5..3b6084fa2e90 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -527,66 +527,105 @@ wrapOpInDispatchRegion(RewriterBase &rewriter, Operation *op) {
   return newRegionOp;
 }
 
-bool isDequantizationLikeOp(Operation *op) {
+//===---------------------------------------------------------------------===//
+// Classification of ops that change bit-widths
+//===---------------------------------------------------------------------===//
+
+Type BitWidthChangeInfo::getInputElementType() const {
+  return cast<RankedTensorType>(inputOperand->get().getType()).getElementType();
+}
+
+std::optional<BitWidthChangeInfo> isBitExtendOrTruncateOp(Operation *op) {
   auto genericOp = dyn_cast<linalg::GenericOp>(op);
   if (!genericOp) {
-    return false;
+    return std::nullopt;
   }
   if (genericOp.getNumDpsInits() != 1) {
-    return false;
+    return std::nullopt;
   }
 
   // Check that the all loops are parallel
   unsigned numLoops = genericOp.getNumLoops();
   unsigned numParallelLoops = genericOp.getNumParallelLoops();
   if (numLoops != numParallelLoops) {
-    return false;
+    return std::nullopt;
   }
 
   // Check that all operands that have the highest rank have bit width
   // less than the output bit-width.
-  DenseMap<int64_t, SmallVector<RankedTensorType>> rankBuckets;
-  int64_t maxRank = 0;
+  DenseMap<int64_t, SmallVector<OpOperand *>> rankBuckets;
+  int64_t maxOperandRank = 0;
   for (OpOperand *input : genericOp.getDpsInputOperands()) {
     auto inputType = dyn_cast<RankedTensorType>(input->get().getType());
     if (!inputType) {
       continue;
     }
     int64_t currRank = inputType.getRank();
-    maxRank = std::max(currRank, maxRank);
-    rankBuckets[currRank].push_back(inputType);
+    maxOperandRank = std::max(currRank, maxOperandRank);
+    rankBuckets[currRank].push_back(input);
   }
-  if (rankBuckets[maxRank].empty()) {
-    return false;
+  if (maxOperandRank == 0 || rankBuckets[maxOperandRank].empty()) {
+    return std::nullopt;
   }
 
   unsigned int maxInputElementBitWidth = 0;
-  for (auto t : rankBuckets[maxRank]) {
-    Type elementType = t.getElementType();
+  OpOperand *inputOperand = nullptr;
+  for (OpOperand *operand : rankBuckets[maxOperandRank]) {
+    RankedTensorType tensorType =
+        cast<RankedTensorType>(operand->get().getType());
+    Type elementType = tensorType.getElementType();
     if (!elementType.isIntOrFloat()) {
-      return false;
+      return std::nullopt;
+    }
+    unsigned elementBitWidth = Util::getTypeBitWidth(elementType);
+    if (elementBitWidth > maxInputElementBitWidth) {
+      maxInputElementBitWidth = elementBitWidth;
+      inputOperand = operand;
     }
-    maxInputElementBitWidth =
-        std::max(maxInputElementBitWidth, elementType.getIntOrFloatBitWidth());
   }
+  if (!inputOperand) {
+    return std::nullopt;
+  }
+  Type inputElementType =
+      cast<RankedTensorType>(inputOperand->get().getType()).getElementType();
 
   // Check that the identity input element bitwidth is smaller than the output
   // element bitwidth.
-  Type outputElementType = getElementTypeOrSelf(genericOp->getResultTypes()[0]);
+  RankedTensorType outputType =
+      dyn_cast<RankedTensorType>(genericOp->getResultTypes()[0]);
+  if (!outputType) {
+    return std::nullopt;
+  }
+  Type outputElementType = outputType.getElementType();
   if (!outputElementType.isIntOrFloat()) {
-    return false;
+    return std::nullopt;
   }
-  if (maxInputElementBitWidth >= outputElementType.getIntOrFloatBitWidth()) {
-    return false;
+
+  unsigned inputBitWidth = Util::getTypeBitWidth(inputElementType);
+  unsigned outputBitWidth = Util::getTypeBitWidth(outputElementType);
+  if (inputBitWidth == outputBitWidth) {
+    return std::nullopt;
   }
 
-  // Check if there are any operations from math dialect.
-  for (Operation &op : *genericOp.getBody()) {
-    if (op.getDialect() == op.getContext()->getLoadedDialect("math")) {
-      return false;
+  // Checks specific to bit-extend operations.
+  if (inputBitWidth < outputBitWidth) {
+    // Since these are cloned into dispatches, avoid expensive operations.
+    for (Operation &op : *genericOp.getBody()) {
+      if (op.getDialect() == op.getContext()->getLoadedDialect("math")) {
+        return std::nullopt;
+      }
     }
   }
-  return true;
+
+  // Checks specific to bit-truncate operations.
+  if (outputBitWidth < inputBitWidth) {
+    // For now, enforce that the input and output ranks match for truncates.
+    if (maxOperandRank != outputType.getRank()) {
+      return std::nullopt;
+    }
+  }
+
+  return BitWidthChangeInfo{inputOperand, outputElementType};
 }
 
 //===---------------------------------------------------------------------===//
@@ -604,7 +643,7 @@ bool isClonableIntoDispatchOp(Operation *op) {
           tensor::ExtractSliceOp, complex::CreateOp>(op)) {
     return true;
   }
-  if (isDequantizationLikeOp(op)) {
+  if (isBitExtendOp(op)) {
    return true;
  }
   if (isa(op) || isa(op)) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
index e458ef7fd66d..df8792bc2753 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
@@ -104,15 +104,41 @@ FailureOr wrapOpInDispatchRegion(RewriterBase &rewriter,
 /// into a dispatch region.
 bool isClonableIntoDispatchOp(Operation *op);
 
-/// Returns true if the operation has dequantization-like properties.
+/// Determines if the operation increases/decreases bitwidths of tensors.
 /// This function checks that the genericOp:
-/// 1. Has only one output, and the output has an identity indexing map
-/// 2. Has all parallel loops.
-/// 3. Has exactly one input with an identity indexing map.
-/// 4. All other inputs are projected permutations and not permutations.
-/// 5. The input with an identity indexing map has a smaller element
-///    bitwidth than the output
-bool isDequantizationLikeOp(Operation *op);
+/// 1. Has only one output.
+/// 2. Has all parallel loops.
+/// 3. Compared to the element type of the input with the highest rank,
+///    the output element type has either a higher or lower bitwidth.
+struct BitWidthChangeInfo {
+  // The operand treated as the "input" of the op; its type is guaranteed
+  // to be a `RankedTensorType`.
+  OpOperand *inputOperand = nullptr;
+  // The output element type; guaranteed to be an int or float type.
+  Type outputElementType = nullptr;
+
+  // Helper methods.
+  Type getInputElementType() const;
+  bool isExtensionOp() const {
+    return getInputElementType().getIntOrFloatBitWidth() <
+           outputElementType.getIntOrFloatBitWidth();
+  }
+  bool isTruncationOp() const {
+    return outputElementType.getIntOrFloatBitWidth() <
+           getInputElementType().getIntOrFloatBitWidth();
+  }
+};
+std::optional<BitWidthChangeInfo> isBitExtendOrTruncateOp(Operation *op);
+inline bool isBitExtendOp(Operation *op) {
+  std::optional<BitWidthChangeInfo> bitWidthChangeInfo =
+      isBitExtendOrTruncateOp(op);
+  return bitWidthChangeInfo && bitWidthChangeInfo->isExtensionOp();
+}
+inline bool isBitTruncateOp(Operation *op) {
+  std::optional<BitWidthChangeInfo> bitWidthChangeInfo =
+      isBitExtendOrTruncateOp(op);
+  return bitWidthChangeInfo && bitWidthChangeInfo->isTruncationOp();
+}
 
 /// Collect all ops that should be cloned into the given dispatch region op.
 SmallVector<Operation *> getCloneableOps(Flow::DispatchRegionOp regionOp);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp
index d68320de1ab3..9f0b383ffc0a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp
@@ -76,7 +76,7 @@ static bool shouldSinkExpandShapeOp(OpOperand *opOperand) {
 
   // Do not sink reshapes across dequantize operations since they are
   // cloned into their consumers.
-  if (isDequantizationLikeOp(consumer)) {
+  if (isBitExtendOp(consumer)) {
     return false;
   }
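
Illustrative usage sketch (not part of the patch): a minimal example of how a
caller might branch on the richer classification instead of the old boolean
`isDequantizationLikeOp` query. The function `classifyBitWidthChange` below is
hypothetical, not something this change adds; it relies only on the
`BitWidthChangeInfo`, `isBitExtendOrTruncateOp`, `isBitExtendOp`, and
`isBitTruncateOp` declarations introduced in RegionOpUtils.h above, and it
assumes the IREE Flow transform headers are on the include path.

  #include <optional>

  #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"

  using namespace mlir;
  using namespace mlir::iree_compiler::IREE::Flow;

  // Hypothetical example: inspect how `op` changes element bit-widths.
  static void classifyBitWidthChange(Operation *op) {
    std::optional<BitWidthChangeInfo> info = isBitExtendOrTruncateOp(op);
    if (!info) {
      // Not an all-parallel linalg.generic that changes element bit-width.
      return;
    }
    if (info->isExtensionOp()) {
      // Bit-extend (dequantization-like): such ops are cloned into their
      // consumer dispatches so the narrower type crosses the boundary.
    } else if (info->isTruncationOp()) {
      // Bit-truncate: newly classifiable with this change; callers can also
      // test for this directly with `isBitTruncateOp(op)`.
    }
    // The recognized input operand and the element types are also available.
    (void)info->inputOperand;
    (void)info->getInputElementType();
    (void)info->outputElementType;
  }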