[DT][NFC] Refactor encoding utilities. (1/n) (iree-org#19310)
This revision moves the encoding utilities into the Encoding dialect and
the Codegen dialect:

1. Move the `TileMxNxK` struct and the `getEncodingInfoForMatmul` method
to the Codegen dialect (i.e., `Dialect/Codegen/*`).
2. Move `isNarrowNResult` to the Encoding dialect, since it depends on no
dialect other than the Encoding dialect.
3. Move `lowerContractionOpWithEncoding` to the Codegen dialect utils in
preparation: all the materialization logic will eventually move to the
Codegen dialect, and the two locations share these utilities during the
transition period (see the sketch after this list).
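
As a rough sketch, the relocated utilities end up with signatures along
these lines (paraphrased from the call sites in the diff below; the new
`Dialect/Codegen` header itself is not shown on this page, so the exact
declarations are illustrative):

```cpp
// Sketch of the Codegen dialect utils after the move. TileMxNxK carries the
// matmul tile sizes; getEncodingInfoForMatmul loses its unused `rank`
// argument; lowerContractionOpWithEncoding takes a resolver callback instead
// of the type converter (see the ResolveEncodingInfoFn sketch below).
namespace mlir::iree_compiler::IREE::Codegen {

struct TileMxNxK {
  int64_t M = 1;
  int64_t N = 1;
  int64_t K = 1;
};

MaterializeEncodingInfo
getEncodingInfoForMatmul(IREE::Encoding::EncodingAttr encoding,
                         TileMxNxK tileMxNxK);

FailureOr<Operation *>
lowerContractionOpWithEncoding(RewriterBase &rewriter,
                               linalg::LinalgOp linalgOp, ValueRange operands,
                               bool transposeNarrowN,
                               ResolveEncodingInfoFn getEncodingInfo);

} // namespace mlir::iree_compiler::IREE::Codegen
```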

To accomplish (3), the revision introduces the `ResolveEncodingInfoFn`
function type, which decouples the utility from
`MaterializeEncodingTypeConverter`. This is required because the type
converter uses HAL, while we do not want the Codegen dialect to depend on
HAL. The indirection will no longer be needed once all the logic moves to
the attribute implementations.
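
For illustration, the decoupling amounts to the pattern below, a minimal
sketch assembled from the hunks in this diff (`getEncodingInfoWrapper` is
the lambda added to `MaterializeContractionOp`):

```cpp
// IREECodegenTypes.h: a plain callback that resolves the materialized layout
// for a tensor type. It knows nothing about MaterializeEncodingTypeConverter,
// so the Codegen dialect picks up no HAL dependency.
using ResolveEncodingInfoFn =
    std::function<FailureOr<MaterializeEncodingInfo>(RankedTensorType type)>;

// Call site in Codegen/Common: wrap the HAL-aware type converter in the
// callback before handing it to the dialect utility.
auto getEncodingInfoWrapper =
    [&](RankedTensorType type) -> FailureOr<MaterializeEncodingInfo> {
  return converter->getEncodingInfo(type);
};
FailureOr<Operation *> convertedOp =
    IREE::Codegen::lowerContractionOpWithEncoding(
        rewriter, op, operands, converter->getTransposeNarrowN(),
        getEncodingInfoWrapper);
```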

Minor cleanups:

- Remove the unused `rank` argument from `getEncodingInfoForMatmul`.
- Add the `static` keyword to the local `getExpandedType` function.

Note that the `lowerSetEncodingOpToPackOp` and
`lowerUnsetEncodingToUnpackOp` functions are not moved yet, because moving
them requires more changes; they will be moved in a separate patch.

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
hanhanW authored Nov 27, 2024
1 parent 516ff10 commit 991594e
Showing 12 changed files with 271 additions and 234 deletions.
@@ -31,6 +31,7 @@
namespace mlir::iree_compiler {

using IREE::Codegen::MaterializeEncodingInfo;
+using IREE::Codegen::TileMxNxK;

#define GEN_PASS_DEF_CPUMATERIALIZEDEVICEENCODINGPASS
#define GEN_PASS_DEF_CPUMATERIALIZEHOSTENCODINGPASS
@@ -445,8 +446,7 @@ materializeEncodingForTarget(RankedTensorType tensorType,

// Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
// based on its operand index in the matmul.
-  auto rank = tensorType.getRank();
-  return getEncodingInfoForMatmul(encoding, rank, chosenTileMxNxK);
+  return IREE::Codegen::getEncodingInfoForMatmul(encoding, chosenTileMxNxK);
}

static FailureOr<MaterializeEncodingValueInfo>
48 changes: 0 additions & 48 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -153,52 +153,4 @@ RankedTensorType dropEncoding(RankedTensorType type) {
return RankedTensorType::get(type.getShape(), type.getElementType());
}

-MaterializeEncodingInfo getEncodingInfoForMatmul(EncodingAttr encoding,
-                                                 int64_t rank,
-                                                 TileMxNxK tileMxNxK) {
-  MaterializeEncodingInfo encodingInfo;
-  auto cDims = getEncodingContractionDims(encoding);
-  // The following expects M, N, K, and Batch sizes of at most 1 for now
-  assert(cDims->m.size() <= 1 && cDims->n.size() <= 1 && cDims->k.size() == 1 &&
-         cDims->batch.size() <= 1 &&
-         "Expected at most one M, N, K, and Batch dimension");
-  std::optional<unsigned> batchDim =
-      cDims->batch.empty() ? std::nullopt
-                           : encoding.mapDimToOperandIndex(cDims->batch[0]);
-  std::optional<unsigned> mDim =
-      cDims->m.empty() ? std::nullopt
-                       : encoding.mapDimToOperandIndex(cDims->m[0]);
-  std::optional<unsigned> nDim =
-      cDims->n.empty() ? std::nullopt
-                       : encoding.mapDimToOperandIndex(cDims->n[0]);
-  std::optional<unsigned> kDim = encoding.mapDimToOperandIndex(cDims->k[0]);
-  if (batchDim.has_value()) {
-    encodingInfo.outerDimsPerm.push_back(batchDim.value());
-  }
-  if (mDim.has_value()) {
-    encodingInfo.outerDimsPerm.push_back(mDim.value());
-    encodingInfo.innerDimsPos.push_back(mDim.value());
-    encodingInfo.innerTileSizes.push_back(tileMxNxK.M);
-  }
-  if (nDim.has_value()) {
-    encodingInfo.outerDimsPerm.push_back(nDim.value());
-    encodingInfo.innerDimsPos.push_back(nDim.value());
-    encodingInfo.innerTileSizes.push_back(tileMxNxK.N);
-  }
-  if (kDim.has_value()) {
-    encodingInfo.outerDimsPerm.push_back(kDim.value());
-    encodingInfo.innerDimsPos.push_back(kDim.value());
-    encodingInfo.innerTileSizes.push_back(tileMxNxK.K);
-  }
-  return encodingInfo;
-}
-
-bool isNarrowNResult(EncodingAttr encoding) {
-  if (encoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RESULT) {
-    return false;
-  }
-
-  return IREE::Encoding::getMatmulNarrowDim(encoding).isN();
-}
-
} // namespace mlir::iree_compiler
14 changes: 0 additions & 14 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -85,16 +85,6 @@ class OpMaterializeEncodingPattern : public OpConversionPattern<OpTy> {
/// Returns the RankedTensorType without encodings.
RankedTensorType dropEncoding(RankedTensorType type);

-struct TileMxNxK {
-  int64_t M = 1;
-  int64_t N = 1;
-  int64_t K = 1;
-};
-
-IREE::Codegen::MaterializeEncodingInfo
-getEncodingInfoForMatmul(IREE::Encoding::EncodingAttr encoding, int64_t rank,
-                         TileMxNxK tileMxNxK);
-
/// Utility method to convert from `set_encoding` op to `pack` operation.
/// For now this takes a `paddingValue` as input. The source is also taken
/// as input so that these could be used with `OpConversionPatterns`.
@@ -126,10 +116,6 @@ void populateShapeIndependentMaterializeEncodingPatterns(
MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn);

-// Returns true if `encoding` represents a narrow-N matmul RESULT, e.g. the
-// result of a matvec.
-bool isNarrowNResult(IREE::Encoding::EncodingAttr encoding);
-
} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGUTILS_H_
@@ -43,6 +43,7 @@ namespace mlir::iree_compiler {
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"

using IREE::Codegen::MaterializeEncodingInfo;
+using IREE::Codegen::TileMxNxK;
using IREE::Codegen::TileSwizzle;

static IREE::GPU::MMAAttr chooseIntrinsicMMAAttr(TypeRange eTypes,
@@ -245,10 +246,10 @@ materializeEncodingForTarget(RankedTensorType tensorType,

// Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
// based on its operand index in the matmul.
-  auto rank = tensorType.getRank();
TileMxNxK innerTile;
std::tie(innerTile.M, innerTile.N, innerTile.K) = mma.getMNKShape();
-  auto encodingInfo = getEncodingInfoForMatmul(encoding, rank, innerTile);
+  auto encodingInfo =
+      IREE::Codegen::getEncodingInfoForMatmul(encoding, innerTile);
auto fragment =
static_cast<IREE::GPU::MMAFragment>(encoding.getOperandIndex().getInt());
encodingInfo.swizzle = getSwizzle(mma, fragment);
@@ -61,19 +61,6 @@ getSwizzledShape(ArrayRef<OpFoldResult> packedShape,
return newShape;
}

-static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
-                                         ValueRange convertedInputOperands,
-                                         ValueRange convertedOutputOperands) {
-  SmallVector<Value> operands;
-  operands.append(convertedInputOperands.begin(), convertedInputOperands.end());
-  operands.append(convertedOutputOperands.begin(),
-                  convertedOutputOperands.end());
-  return mlir::clone(builder, op,
-                     {dropEncoding(cast<RankedTensorType>(
-                         convertedOutputOperands[0].getType()))},
-                     operands);
-}
-
static FailureOr<SmallVector<OpFoldResult>>
getInnerTileSizesOfr(OpBuilder &rewriter, Location loc,
RankedTensorType tensorType,
@@ -111,91 +98,6 @@ getInnerTileSizesOfr(OpBuilder &rewriter, Location loc,
return result;
}

-RankedTensorType getExpandedType(RankedTensorType type, bool isBatched,
-                                 bool isTransposed,
-                                 SmallVectorImpl<ReassociationIndices> &ri) {
-  if (!isBatched) {
-    ri.assign({{0, 1}, {2, 3}});
-    if (!isTransposed) {
-      return RankedTensorType::get(
-          {1, type.getDimSize(0), 1, type.getDimSize(1)},
-          type.getElementType());
-    }
-    return RankedTensorType::get({type.getDimSize(0), 1, type.getDimSize(1), 1},
-                                 type.getElementType());
-  }
-
-  ri.assign({{0}, {1, 2}, {3, 4}});
-  if (!isTransposed) {
-    return RankedTensorType::get(
-        {type.getDimSize(0), 1, type.getDimSize(1), 1, type.getDimSize(2)},
-        type.getElementType());
-  }
-  return RankedTensorType::get(
-      {type.getDimSize(0), type.getDimSize(1), 1, type.getDimSize(2), 1},
-      type.getElementType());
-}
-
-/// Given an input Value and a desired output element type, create and return
-/// an element-wise linalg::GenericOp that extends the input Value to the
-/// output element type.
-static Value createElementWiseExtUIOp(RewriterBase &rewriter, Value input,
-                                      Location loc, Type outElemType) {
-  auto inputType = cast<RankedTensorType>(input.getType());
-  SmallVector<AffineMap> maps(
-      2, rewriter.getMultiDimIdentityMap(inputType.getRank()));
-  SmallVector<utils::IteratorType> iteratorTypes(inputType.getRank(),
-                                                 utils::IteratorType::parallel);
-  auto castedType = inputType.clone(outElemType);
-  SmallVector<OpFoldResult> inputMixedSizes =
-      tensor::getMixedSizes(rewriter, loc, input);
-  Value init =
-      rewriter.create<tensor::EmptyOp>(loc, inputMixedSizes, outElemType);
-  return rewriter
-      .create<linalg::GenericOp>(
-          loc, castedType, input, init, maps, iteratorTypes,
-          [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
-            Value castRes =
-                b.create<arith::ExtUIOp>(nestedLoc, outElemType, args[0])
-                    ->getResult(0);
-            b.create<linalg::YieldOp>(nestedLoc, castRes);
-          })
-      .getResult(0);
-}
-
-/// If needed, expand and the input Value, and return the resulting input with
-/// the canonical mmt4d input shape. If the input element type is unsigned,
-/// create a producer Linalg::GenericOp on the input that unsigned extends the
-/// input to the output element type. This extension is required to keep the
-/// unsignedness information on the input for ukernels. If `transpose` is true,
-/// the `linalgOp`'s indexing maps are transposed.
-static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp,
-                             bool transpose, RewriterBase &rewriter,
-                             SmallVectorImpl<ReassociationIndices> &ri,
-                             ArrayRef<Type> elemTypes, int operandIdx) {
-  assert(linalgOp.getNumDpsInputs() == 2);
-  assert(linalgOp.getNumDpsInits() == 1);
-  auto cDims = linalg::inferContractionDims(linalgOp);
-  Location loc = linalgOp->getLoc();
-  Value expandedValue = value;
-  // If vecmat with non-rhs operandIdx or matvec with non-lhs operandIdx, the
-  // operand is a vector and must be extended
-  if ((cDims->m.empty() && operandIdx != 1) ||
-      (cDims->n.empty() && operandIdx != 0)) {
-    auto type = cast<RankedTensorType>(value.getType());
-    RankedTensorType newType = getExpandedType(
-        type, /*isBatched=*/!cDims->batch.empty(),
-        /*isTransposed=*/operandIdx == 2 && (transpose ^ cDims->n.empty()), ri);
-    expandedValue =
-        rewriter.create<tensor::ExpandShapeOp>(loc, newType, value, ri);
-  }
-  if (elemTypes[operandIdx].isUnsignedInteger()) {
-    return createElementWiseExtUIOp(rewriter, expandedValue, loc,
-                                    elemTypes.back());
-  }
-  return expandedValue;
-}
-
static void transposeInPlace(MaterializeEncodingInfo &info) {
// Vector cases: nothing to do.
if (info.innerTileSizes.size() < 2) {
@@ -297,75 +199,6 @@ FailureOr<tensor::UnPackOp> lowerUnsetEncodingToUnpackOp(
encodingInfo->outerDimsPerm);
}

-static FailureOr<Operation *> lowerContractionOpWithEncoding(
-    RewriterBase &rewriter, linalg::LinalgOp linalgOp, ValueRange operands,
-    const MaterializeEncodingTypeConverter &typeConverter) {
-  if (!linalgOp.hasPureTensorSemantics())
-    return failure();
-
-  auto inputs = linalgOp.getDpsInputOperands();
-  auto outputs = linalgOp.getDpsInits();
-
-  auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
-  auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
-  auto resultType = cast<RankedTensorType>(outputs[0].getType());
-  auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType);
-  auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType);
-  auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType);
-  if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
-    return failure();
-  }
-
-  if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS ||
-      rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS ||
-      resultEncoding.getOperandIndex().getValue() !=
-          IREE::Encoding::MATMUL_RESULT) {
-    return failure();
-  }
-
-  FailureOr<MaterializeEncodingInfo> encodingInfo =
-      typeConverter.getEncodingInfo(
-          cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
-
-  Operation *result;
-  if (failed(encodingInfo)) {
-    result = dropEncodingAndCloneOp(rewriter, linalgOp,
-                                    operands.take_front(inputs.size()),
-                                    operands.drop_front(inputs.size()));
-  } else {
-    bool transpose =
-        typeConverter.getTransposeNarrowN() && isNarrowNResult(resultEncoding);
-    SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
-    SmallVector<ReassociationIndices> ri;
-    Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, rewriter,
-                                   ri, elemTypes, /*operandIdx=*/0);
-    Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, rewriter,
-                                   ri, elemTypes, /*operandIdx=*/1);
-    Value newResult =
-        getMmt4dOperand(operands[2], linalgOp, transpose, rewriter, ri,
-                        elemTypes, /*operandIdx=*/2);
-    if (transpose) {
-      std::swap(newLhs, newRhs);
-    }
-    Type newResultType = newResult.getType();
-    auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
-    if (cDims->batch.empty()) {
-      result = rewriter.create<linalg::Mmt4DOp>(
-          linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
-          ValueRange{newResult});
-    } else {
-      result = rewriter.create<linalg::BatchMmt4DOp>(
-          linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
-          ValueRange{newResult});
-    }
-    if (!ri.empty()) {
-      result = rewriter.create<tensor::CollapseShapeOp>(
-          linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
-    }
-  }
-  return result;
-}
-
/// Utility method to convert `tensor.empty` with encoding to a `tensor.empty`
/// of the materialized type.
static FailureOr<Operation *>
@@ -901,8 +734,17 @@ class MaterializeContractionOp

auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
this->getTypeConverter());
+    // TODO(hanchung): This is a transition state for moving the implementation
+    // details to backend attributes. We won't need the function type argument
+    // after all the backends that support encodings implement the attribute.
+    auto getEncodingInfoWrapper =
+        [&](RankedTensorType type) -> FailureOr<MaterializeEncodingInfo> {
+      return converter->getEncodingInfo(type);
+    };
FailureOr<Operation *> convertedOp =
-        lowerContractionOpWithEncoding(rewriter, op, operands, *converter);
+        IREE::Codegen::lowerContractionOpWithEncoding(
+            rewriter, op, operands, converter->getTransposeNarrowN(),
+            getEncodingInfoWrapper);
if (failed(convertedOp)) {
return failure();
}
@@ -10,6 +10,7 @@
#include <cstdint>

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

namespace mlir::iree_compiler::IREE::Codegen {
@@ -89,5 +90,8 @@ struct MaterializeEncodingInfo {
std::optional<TileSwizzle> swizzle;
};

+using ResolveEncodingInfoFn =
+    std::function<FailureOr<MaterializeEncodingInfo>(RankedTensorType type)>;
+
} // namespace mlir::iree_compiler::IREE::Codegen
#endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_IR_IREECODEGENTYPES_H_
@@ -22,8 +22,12 @@ iree_compiler_cc_library(
],
deps = [
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:TensorDialect",
],
)
@@ -19,8 +19,12 @@ iree_cc_library(
"Utils.cpp"
DEPS
LLVMSupport
+    MLIRArithDialect
MLIRIR
+    MLIRLinalgDialect
+    MLIRTensorDialect
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
+    iree::compiler::Dialect::Encoding::IR
PUBLIC
)
