[NFC] Modify method for characterizing bit-extension operations to handle characterization of bit-truncation as well. (iree-org#17833)

Currently the method for characterizing bit-extension ops (a.k.a.
dequantization ops) only returns a `bool`. Make this method return a
richer handle, which also allows classification of bit-truncation
operations. Also rename the `isDequantizationLikeOp` method to
`isBitExtendOp`.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
MaheshRavishankar authored Jul 9, 2024
1 parent 4d204ea commit 94ecb8b
Showing 9 changed files with 128 additions and 63 deletions.
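
For context, here is a minimal sketch of the call-site migration, using only the names this commit adds (declared in RegionOpUtils.h further down); the surrounding code is hypothetical:

// Before: a single boolean query that only recognized bit extension.
//   if (IREE::Flow::isDequantizationLikeOp(op)) { ... }
// After: a richer handle that classifies both directions of bit-width change.
std::optional<IREE::Flow::BitWidthChangeInfo> info =
    IREE::Flow::isBitExtendOrTruncateOp(op);
if (info && info->isExtensionOp()) {
  // Bit-extend (dequantization-like), e.g. i4 -> f32.
  Type narrowType = info->getInputElementType();
}
if (info && info->isTruncationOp()) {
  // Bit-truncate, e.g. f32 -> i8.
}
// Call sites that only need the old boolean behavior use the inline wrapper:
if (IREE::Flow::isBitExtendOp(op)) { /* ... */ }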
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -447,11 +447,11 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
Type initElemType = getElementTypeOrSelf(init);

if (auto lhsOp = lhs.getDefiningOp<linalg::GenericOp>()) {
-if (IREE::Flow::isDequantizationLikeOp(lhsOp))
+if (IREE::Flow::isBitExtendOp(lhsOp))
lhsElemType = getElementTypeOrSelf(lhsOp.getDpsInputs()[0]);
}
if (auto rhsOp = rhs.getDefiningOp<linalg::GenericOp>()) {
-if (IREE::Flow::isDequantizationLikeOp(rhsOp))
+if (IREE::Flow::isBitExtendOp(rhsOp))
rhsElemType = getElementTypeOrSelf(rhsOp.getDpsInputs()[0]);
}

@@ -53,7 +53,7 @@ void BubbleUpExpandShapesPass::runOnOperation() {
}

// Do not fuse by expand if consumer is dequant.
-if (isDequantizationLikeOp(consumer)) {
+if (isBitExtendOp(consumer)) {
return false;
}

@@ -226,7 +226,7 @@ static bool isRootOp(Operation *op) {
return false;
}
// Dequantization-like ops get cloned into dispatches later.
-if (isDequantizationLikeOp(op)) {
+if (isBitExtendOp(op)) {
return false;
}
// Any Linalg named op or generic op with reduction iterator types is a root
@@ -539,7 +539,7 @@ isFusableWithConsumer(OpOperand &fusedOperand,

// If consumer is a dequant operation, dont fuse it. These get cloned
// into their consumers.
-if (isDequantizationLikeOp(consumer)) {
+if (isBitExtendOp(consumer)) {
return false;
}

@@ -874,7 +874,7 @@ decideFusableLinalgOps(Region &region, DominanceInfo const &dominanceInfo,
// materializing large tensors between dispatches.
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp,
IREE::Encoding::SetEncodingOp>(op) ||
-isa<linalg::FillOp>(op) || isDequantizationLikeOp(&op)) {
+isa<linalg::FillOp>(op) || isBitExtendOp(&op)) {
continue;
}

@@ -156,7 +156,7 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,

// Dequantization-like operations should be fused with consumers to keep
// the smaller bit width on the dispatch boundary.
-if (isDequantizationLikeOp(genericOp)) {
+if (isBitExtendOp(genericOp)) {
return;
}

@@ -196,7 +196,7 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,

// 7. Skip dequantization-like `producer` ops as we would rather fuse
// by cloning the producer instead of multi-use fusion.
-if (isDequantizationLikeOp(producer)) {
+if (isBitExtendOp(producer)) {
return;
}

@@ -195,7 +195,7 @@ struct GatherFusionPattern : public OpRewritePattern<tensor::ExtractOp> {

// Check if the producerOp is fusible
if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 ||
-!isElementwise(producerOp) || !isDequantizationLikeOp(producerOp)) {
+!isElementwise(producerOp) || !isBitExtendOp(producerOp)) {
return rewriter.notifyMatchFailure(producerOp,
"producer op is not fusible");
}
40 changes: 20 additions & 20 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
@@ -53,34 +53,37 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
}

// If producer does not have a single user, dont fuse.
-if (!producerOp->hasOneUse())
+if (!producerOp->hasOneUse()) {
return false;
+}

-// Do no fuse dequantization-like operations with producers. The
-// dequantization ops are cloned into all their use dispatches. So fusing
-// producer with consumer here would then result in producer also getting
-// cloned into many dispatches which is against the thumb rule of fusion to
-// not introduce additional computation (except for dequant ops). If the
-// consumer has only one use, then this fusion is fine since cloning wont
-// result in redundant computation of the producer. (Also note that the
-// producer is always an elementwise operation).
-if (isDequantizationLikeOp(consumerOp) && !consumerOp->hasOneUse()) {
+std::optional<BitWidthChangeInfo> consumerBitwidthChangeInfo =
+    isBitExtendOrTruncateOp(consumerOp);
+// Do not fuse bit-extend-like operations with producers. Such ops are cloned
+// into all their use dispatches. So fusing producer with consumer here would
+// then result in producer also getting cloned into many dispatches, which is
+// against the rule of thumb for fusion to not introduce additional computation
+// (except for bit-extend ops). If the consumer has only one use, then this
+// fusion is fine since cloning won't result in redundant computation of the
+// producer. (Also note that the producer is always an elementwise operation.)
+if (consumerBitwidthChangeInfo &&
+    consumerBitwidthChangeInfo->isExtensionOp() && !consumerOp->hasOneUse()) {
return false;
}

+auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp);
+if (!linalgConsumerOp) {
+return false;
+}
// If the producer has a single use (this op), only fuse if
// - 1) The consumer op is all parallel loops. The parallelism of the consumer
// can be used as a way to amortize cost of redundant computation
// - 2) If consumer op is a reduction, only fuse if the indexing map in the
// consumer for the producer result is a permutation. If it is a
// broadcast this ends up redundantly computing operations without more
// parallelism.
-if (auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp)) {

-if (linalgConsumerOp.getNumParallelLoops() ==
-linalgConsumerOp.getNumLoops()) {
-return true;
-}
+if (linalgConsumerOp.getNumParallelLoops() !=
+linalgConsumerOp.getNumLoops()) {
if (!linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
.isPermutation()) {
return false;
@@ -92,11 +95,8 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
linalg::isaConvolutionOpInterface(linalgConsumerOp)) {
return false;
}
-return true;
-}

-// All other cases dont fuse.
-return false;
+return true;
}

} // namespace mlir::iree_compiler::IREE::Flow
@@ -527,66 +527,105 @@ wrapOpInDispatchRegion(RewriterBase &rewriter, Operation *op) {
return newRegionOp;
}

-bool isDequantizationLikeOp(Operation *op) {
+//===---------------------------------------------------------------------===//
+// Classification of ops that change bit-widths
+//===---------------------------------------------------------------------===//
+
+Type BitWidthChangeInfo::getInputElementType() const {
+return cast<RankedTensorType>(inputOperand->get().getType()).getElementType();
+}
+
+std::optional<BitWidthChangeInfo> isBitExtendOrTruncateOp(Operation *op) {
auto genericOp = dyn_cast<linalg::GenericOp>(op);
if (!genericOp) {
-return false;
+return std::nullopt;
}
if (genericOp.getNumDpsInits() != 1) {
-return false;
+return std::nullopt;
}

// Check that the all loops are parallel
unsigned numLoops = genericOp.getNumLoops();
unsigned numParallelLoops = genericOp.getNumParallelLoops();
if (numLoops != numParallelLoops) {
-return false;
+return std::nullopt;
}

// Check that all operands that have the highest rank have bit width
// less than the output bit-width.
-DenseMap<int64_t, SmallVector<RankedTensorType>> rankBuckets;
-int64_t maxRank = 0;
+DenseMap<int64_t, SmallVector<OpOperand *>> rankBuckets;
+int64_t maxOperandRank = 0;
for (OpOperand *input : genericOp.getDpsInputOperands()) {
auto inputType = dyn_cast<RankedTensorType>(input->get().getType());
if (!inputType) {
continue;
}
int64_t currRank = inputType.getRank();
-maxRank = std::max(currRank, maxRank);
-rankBuckets[currRank].push_back(inputType);
+maxOperandRank = std::max(currRank, maxOperandRank);
+rankBuckets[currRank].push_back(input);
}
-if (rankBuckets[maxRank].empty()) {
-return false;
+if (maxOperandRank == 0 || rankBuckets[maxOperandRank].empty()) {
+return std::nullopt;
}

unsigned int maxInputElementBitWidth = 0;
-for (auto t : rankBuckets[maxRank]) {
-Type elementType = t.getElementType();
+OpOperand *inputOperand;
+for (OpOperand *operand : rankBuckets[maxOperandRank]) {
+RankedTensorType tensorType =
+cast<RankedTensorType>(operand->get().getType());
+Type elementType = tensorType.getElementType();
if (!elementType.isIntOrFloat()) {
-return false;
+return std::nullopt;
}
+unsigned elementBitWidth = Util::getTypeBitWidth(elementType);
+if (elementBitWidth > maxInputElementBitWidth) {
+maxInputElementBitWidth = elementBitWidth;
+inputOperand = operand;
+}
-maxInputElementBitWidth =
-std::max(maxInputElementBitWidth, elementType.getIntOrFloatBitWidth());
}
+if (!inputOperand) {
+return std::nullopt;
+}
+Type inputElementType =
+cast<RankedTensorType>(inputOperand->get().getType()).getElementType();

// Check that the identity input element bitwidth is smaller than the output
// element bitwidth.
-Type outputElementType = getElementTypeOrSelf(genericOp->getResultTypes()[0]);
+RankedTensorType outputType =
+dyn_cast<RankedTensorType>(genericOp->getResultTypes()[0]);
+if (!outputType) {
+return std::nullopt;
+}
+Type outputElementType = outputType.getElementType();
if (!outputElementType.isIntOrFloat()) {
-return false;
+return std::nullopt;
}
-if (maxInputElementBitWidth >= outputElementType.getIntOrFloatBitWidth()) {
-return false;

+unsigned inputBitWidth = Util::getTypeBitWidth(inputElementType);
+unsigned outputBitWidth = Util::getTypeBitWidth(outputElementType);
+if (inputBitWidth == outputBitWidth) {
+return std::nullopt;
}

-// Check if there are any operations from math dialect.
-for (Operation &op : *genericOp.getBody()) {
-if (op.getDialect() == op.getContext()->getLoadedDialect("math")) {
-return false;
+// Checks specific to bit extend operations.
+if (inputBitWidth < outputBitWidth) {
+// Since these are cloned into dispatches, avoid expensive operations.
+for (Operation &op : *genericOp.getBody()) {
+if (op.getDialect() == op.getContext()->getLoadedDialect("math")) {
+return std::nullopt;
+}
}
}
-return true;

+// Checks specific to bit truncate operations.
+if (outputBitWidth < inputBitWidth) {
+// For now enforce that the input and output ranks match for truncates.
+if (maxOperandRank != outputType.getRank()) {
+return std::nullopt;
+}
+}

+return BitWidthChangeInfo{inputOperand, outputElementType};
}

//===---------------------------------------------------------------------===//
@@ -604,7 +643,7 @@ bool isClonableIntoDispatchOp(Operation *op) {
tensor::ExtractSliceOp, complex::CreateOp>(op)) {
return true;
}
-if (isDequantizationLikeOp(op)) {
+if (isBitExtendOp(op)) {
return true;
}
if (isa<arith::ConstantOp>(op) || isa<complex::ConstantOp>(op)) {
42 changes: 34 additions & 8 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
@@ -104,15 +104,41 @@ FailureOr<Flow::DispatchRegionOp> wrapOpInDispatchRegion(RewriterBase &rewriter,
/// into a dispatch region.
bool isClonableIntoDispatchOp(Operation *op);

-/// Returns true if the operation has dequantization-like properties.
+/// Returns true if the operation increases/decreases bitwidths of tensors.
/// This function checks that the genericOp:
-/// 1. Has only one output, and the output has an identity indexing map
-/// 2. Has all parallel loops.
-/// 3. Has exactly one input with an identity indexing map.
-/// 4. All other inputs are projected permutations and not permutations.
-/// 5. The input with an identity indexing map has a smaller element
-///    bitwidth than the output
-bool isDequantizationLikeOp(Operation *op);
+/// 1. Has only one output.
+/// 2. Has all parallel loops.
+/// 3. Compared to the element type of the input with highest rank,
+///    the output element type has either a higher or lower bitwidth.
+struct BitWidthChangeInfo {
+// The operand the recognizer treats as the "input".
+// Is guaranteed to be a `RankedTensorType`.
+OpOperand *inputOperand = nullptr;
+// The output element type is int or float type.
+Type outputElementType = nullptr;
+
+// Helper methods.
+Type getInputElementType() const;
+bool isExtensionOp() const {
+return getInputElementType().getIntOrFloatBitWidth() <
+outputElementType.getIntOrFloatBitWidth();
+}
+bool isTruncationOp() const {
+return outputElementType.getIntOrFloatBitWidth() <
+getInputElementType().getIntOrFloatBitWidth();
+}
+};
+std::optional<BitWidthChangeInfo> isBitExtendOrTruncateOp(Operation *op);
+inline bool isBitExtendOp(Operation *op) {
+std::optional<BitWidthChangeInfo> bitWidthChangeInfo =
+isBitExtendOrTruncateOp(op);
+return bitWidthChangeInfo && bitWidthChangeInfo->isExtensionOp();
+}
+inline bool isBitTruncateOp(Operation *op) {
+std::optional<BitWidthChangeInfo> bitWidthChangeInfo =
+isBitExtendOrTruncateOp(op);
+return bitWidthChangeInfo && bitWidthChangeInfo->isTruncationOp();
+}

/// Collect all ops that should be cloned into the given dispatch region op.
SmallVector<Operation *> getCloneableOps(Flow::DispatchRegionOp regionOp);
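
As a usage illustration, here is a hedged sketch of how a caller might consume the richer handle declared above, mirroring the KernelConfig.cpp change at the top of this commit (`operand` and `pickIntrinsicForElementType` are made-up stand-ins, not part of the commit):

if (auto producer = operand->get().getDefiningOp<linalg::GenericOp>()) {
  std::optional<IREE::Flow::BitWidthChangeInfo> info =
      IREE::Flow::isBitExtendOrTruncateOp(producer);
  if (info && info->isExtensionOp()) {
    // The extend op is cloned into the dispatch, so configuration decisions
    // should use the narrow input element type, not the widened result type.
    Type narrowElemType = info->getInputElementType();
    pickIntrinsicForElementType(narrowElemType);
  }
}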
@@ -76,7 +76,7 @@ static bool shouldSinkExpandShapeOp(OpOperand *opOperand) {

// Do not sink reshapes across dequantize operations since they are
// cloned into their consumers.
-if (isDequantizationLikeOp(consumer)) {
+if (isBitExtendOp(consumer)) {
return false;
}
