diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp index 18fabd49a4a5..c7517d8bca1a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp @@ -31,6 +31,7 @@ namespace mlir::iree_compiler { using IREE::Codegen::MaterializeEncodingInfo; +using IREE::Codegen::TileMxNxK; #define GEN_PASS_DEF_CPUMATERIALIZEDEVICEENCODINGPASS #define GEN_PASS_DEF_CPUMATERIALIZEHOSTENCODINGPASS @@ -445,8 +446,7 @@ materializeEncodingForTarget(RankedTensorType tensorType, // Map the matmul TileMxNxK to an actual tile shape for the tensor at hand, // based on its operand index in the matmul. - auto rank = tensorType.getRank(); - return getEncodingInfoForMatmul(encoding, rank, chosenTileMxNxK); + return IREE::Codegen::getEncodingInfoForMatmul(encoding, chosenTileMxNxK); } static FailureOr diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp index 7d041d09a738..fd75e74a987e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp @@ -153,52 +153,4 @@ RankedTensorType dropEncoding(RankedTensorType type) { return RankedTensorType::get(type.getShape(), type.getElementType()); } -MaterializeEncodingInfo getEncodingInfoForMatmul(EncodingAttr encoding, - int64_t rank, - TileMxNxK tileMxNxK) { - MaterializeEncodingInfo encodingInfo; - auto cDims = getEncodingContractionDims(encoding); - // The following expects M, N, K, and Batch sizes of at most 1 for now - assert(cDims->m.size() <= 1 && cDims->n.size() <= 1 && cDims->k.size() == 1 && - cDims->batch.size() <= 1 && - "Expected at most one M, N, K, and Batch dimension"); - std::optional batchDim = - cDims->batch.empty() ? 
std::nullopt - : encoding.mapDimToOperandIndex(cDims->batch[0]); - std::optional mDim = - cDims->m.empty() ? std::nullopt - : encoding.mapDimToOperandIndex(cDims->m[0]); - std::optional nDim = - cDims->n.empty() ? std::nullopt - : encoding.mapDimToOperandIndex(cDims->n[0]); - std::optional kDim = encoding.mapDimToOperandIndex(cDims->k[0]); - if (batchDim.has_value()) { - encodingInfo.outerDimsPerm.push_back(batchDim.value()); - } - if (mDim.has_value()) { - encodingInfo.outerDimsPerm.push_back(mDim.value()); - encodingInfo.innerDimsPos.push_back(mDim.value()); - encodingInfo.innerTileSizes.push_back(tileMxNxK.M); - } - if (nDim.has_value()) { - encodingInfo.outerDimsPerm.push_back(nDim.value()); - encodingInfo.innerDimsPos.push_back(nDim.value()); - encodingInfo.innerTileSizes.push_back(tileMxNxK.N); - } - if (kDim.has_value()) { - encodingInfo.outerDimsPerm.push_back(kDim.value()); - encodingInfo.innerDimsPos.push_back(kDim.value()); - encodingInfo.innerTileSizes.push_back(tileMxNxK.K); - } - return encodingInfo; -} - -bool isNarrowNResult(EncodingAttr encoding) { - if (encoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RESULT) { - return false; - } - - return IREE::Encoding::getMatmulNarrowDim(encoding).isN(); -} - } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h index 1c9d0860c5d8..7077fb6a05f1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h +++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h @@ -85,16 +85,6 @@ class OpMaterializeEncodingPattern : public OpConversionPattern { /// Returns the RankedTensorType without encodings. 
RankedTensorType dropEncoding(RankedTensorType type); -struct TileMxNxK { - int64_t M = 1; - int64_t N = 1; - int64_t K = 1; -}; - -IREE::Codegen::MaterializeEncodingInfo -getEncodingInfoForMatmul(IREE::Encoding::EncodingAttr encoding, int64_t rank, - TileMxNxK tileMxNxK); - /// Utility method to convert from `set_encoding` op to `pack` operation. /// For now this takes a `paddingValue` as input. The source is also taken /// as input so that these could be used with `OpConversionPatterns`. @@ -126,10 +116,6 @@ void populateShapeIndependentMaterializeEncodingPatterns( MaterializeEncodingTypeConverter &typeConverter, MaterializeEncodingValueFn materializeEncodingValueFn); -// Returns true if `encoding` represents a narrow-N matmul RESULT, e.g. the -// result of a matvec. -bool isNarrowNResult(IREE::Encoding::EncodingAttr encoding); - } // namespace mlir::iree_compiler #endif // IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGUTILS_H_ diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp index e32760b44215..6debc2a8ffbc 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp @@ -43,6 +43,7 @@ namespace mlir::iree_compiler { #include "iree/compiler/Codegen/Common/GPU/Passes.h.inc" using IREE::Codegen::MaterializeEncodingInfo; +using IREE::Codegen::TileMxNxK; using IREE::Codegen::TileSwizzle; static IREE::GPU::MMAAttr chooseIntrinsicMMAAttr(TypeRange eTypes, @@ -245,10 +246,10 @@ materializeEncodingForTarget(RankedTensorType tensorType, // Map the matmul TileMxNxK to an actual tile shape for the tensor at hand, // based on its operand index in the matmul. 
- auto rank = tensorType.getRank(); TileMxNxK innerTile; std::tie(innerTile.M, innerTile.N, innerTile.K) = mma.getMNKShape(); - auto encodingInfo = getEncodingInfoForMatmul(encoding, rank, innerTile); + auto encodingInfo = + IREE::Codegen::getEncodingInfoForMatmul(encoding, innerTile); auto fragment = static_cast(encoding.getOperandIndex().getInt()); encodingInfo.swizzle = getSwizzle(mma, fragment); diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp index 57a990b78bfc..fc3bb45c8be6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp @@ -61,19 +61,6 @@ getSwizzledShape(ArrayRef packedShape, return newShape; } -static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op, - ValueRange convertedInputOperands, - ValueRange convertedOutputOperands) { - SmallVector operands; - operands.append(convertedInputOperands.begin(), convertedInputOperands.end()); - operands.append(convertedOutputOperands.begin(), - convertedOutputOperands.end()); - return mlir::clone(builder, op, - {dropEncoding(cast( - convertedOutputOperands[0].getType()))}, - operands); -} - static FailureOr> getInnerTileSizesOfr(OpBuilder &rewriter, Location loc, RankedTensorType tensorType, @@ -111,91 +98,6 @@ getInnerTileSizesOfr(OpBuilder &rewriter, Location loc, return result; } -RankedTensorType getExpandedType(RankedTensorType type, bool isBatched, - bool isTransposed, - SmallVectorImpl &ri) { - if (!isBatched) { - ri.assign({{0, 1}, {2, 3}}); - if (!isTransposed) { - return RankedTensorType::get( - {1, type.getDimSize(0), 1, type.getDimSize(1)}, - type.getElementType()); - } - return RankedTensorType::get({type.getDimSize(0), 1, type.getDimSize(1), 1}, - type.getElementType()); - } - - ri.assign({{0}, {1, 2}, {3, 4}}); - if (!isTransposed) { 
- return RankedTensorType::get( - {type.getDimSize(0), 1, type.getDimSize(1), 1, type.getDimSize(2)}, - type.getElementType()); - } - return RankedTensorType::get( - {type.getDimSize(0), type.getDimSize(1), 1, type.getDimSize(2), 1}, - type.getElementType()); -} - -/// Given an input Value and a desired output element type, create and return -/// an element-wise linalg::GenericOp that extends the input Value to the -/// output element type. -static Value createElementWiseExtUIOp(RewriterBase &rewriter, Value input, - Location loc, Type outElemType) { - auto inputType = cast(input.getType()); - SmallVector maps( - 2, rewriter.getMultiDimIdentityMap(inputType.getRank())); - SmallVector iteratorTypes(inputType.getRank(), - utils::IteratorType::parallel); - auto castedType = inputType.clone(outElemType); - SmallVector inputMixedSizes = - tensor::getMixedSizes(rewriter, loc, input); - Value init = - rewriter.create(loc, inputMixedSizes, outElemType); - return rewriter - .create( - loc, castedType, input, init, maps, iteratorTypes, - [&](OpBuilder &b, Location nestedLoc, ValueRange args) { - Value castRes = - b.create(nestedLoc, outElemType, args[0]) - ->getResult(0); - b.create(nestedLoc, castRes); - }) - .getResult(0); -} - -/// If needed, expand and the input Value, and return the resulting input with -/// the canonical mmt4d input shape. If the input element type is unsigned, -/// create a producer Linalg::GenericOp on the input that unsigned extends the -/// input to the output element type. This extension is required to keep the -/// unsignedness information on the input for ukernels. If `transpose` is true, -/// the `linalgOp`'s indexing maps are transposed. 
-static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp, - bool transpose, RewriterBase &rewriter, - SmallVectorImpl &ri, - ArrayRef elemTypes, int operandIdx) { - assert(linalgOp.getNumDpsInputs() == 2); - assert(linalgOp.getNumDpsInits() == 1); - auto cDims = linalg::inferContractionDims(linalgOp); - Location loc = linalgOp->getLoc(); - Value expandedValue = value; - // If vecmat with non-rhs operandIdx or matvec with non-lhs operandIdx, the - // operand is a vector and must be extended - if ((cDims->m.empty() && operandIdx != 1) || - (cDims->n.empty() && operandIdx != 0)) { - auto type = cast(value.getType()); - RankedTensorType newType = getExpandedType( - type, /*isBatched=*/!cDims->batch.empty(), - /*isTransposed=*/operandIdx == 2 && (transpose ^ cDims->n.empty()), ri); - expandedValue = - rewriter.create(loc, newType, value, ri); - } - if (elemTypes[operandIdx].isUnsignedInteger()) { - return createElementWiseExtUIOp(rewriter, expandedValue, loc, - elemTypes.back()); - } - return expandedValue; -} - static void transposeInPlace(MaterializeEncodingInfo &info) { // Vector cases: nothing to do. 
if (info.innerTileSizes.size() < 2) { @@ -297,75 +199,6 @@ FailureOr lowerUnsetEncodingToUnpackOp( encodingInfo->outerDimsPerm); } -static FailureOr lowerContractionOpWithEncoding( - RewriterBase &rewriter, linalg::LinalgOp linalgOp, ValueRange operands, - const MaterializeEncodingTypeConverter &typeConverter) { - if (!linalgOp.hasPureTensorSemantics()) - return failure(); - - auto inputs = linalgOp.getDpsInputOperands(); - auto outputs = linalgOp.getDpsInits(); - - auto lhsType = cast(inputs[0]->get().getType()); - auto rhsType = cast(inputs[1]->get().getType()); - auto resultType = cast(outputs[0].getType()); - auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType); - auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType); - auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType); - if (!lhsEncoding || !rhsEncoding || !resultEncoding) { - return failure(); - } - - if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS || - rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS || - resultEncoding.getOperandIndex().getValue() != - IREE::Encoding::MATMUL_RESULT) { - return failure(); - } - - FailureOr encodingInfo = - typeConverter.getEncodingInfo( - cast(linalgOp->getResultTypes()[0])); - - Operation *result; - if (failed(encodingInfo)) { - result = dropEncodingAndCloneOp(rewriter, linalgOp, - operands.take_front(inputs.size()), - operands.drop_front(inputs.size())); - } else { - bool transpose = - typeConverter.getTransposeNarrowN() && isNarrowNResult(resultEncoding); - SmallVector elemTypes = lhsEncoding.getElementTypesArray(); - SmallVector ri; - Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, rewriter, - ri, elemTypes, /*operandIdx=*/0); - Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, rewriter, - ri, elemTypes, /*operandIdx=*/1); - Value newResult = - getMmt4dOperand(operands[2], linalgOp, transpose, rewriter, ri, - elemTypes, /*operandIdx=*/2); - if (transpose) { - 
std::swap(newLhs, newRhs); - } - Type newResultType = newResult.getType(); - auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding); - if (cDims->batch.empty()) { - result = rewriter.create( - linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs}, - ValueRange{newResult}); - } else { - result = rewriter.create( - linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs}, - ValueRange{newResult}); - } - if (!ri.empty()) { - result = rewriter.create( - linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri); - } - } - return result; -} - /// Utility method to convert `tensor.empty` with encoding to a `tensor.empty` /// of the materialized type. static FailureOr @@ -901,8 +734,17 @@ class MaterializeContractionOp auto converter = static_cast( this->getTypeConverter()); + // TODO(hanchung): This is a transition state for moving the implementation + // details to backend attributes. We won't need the function type argument + // after all the backends that support encodings implement the attribute. 
+ auto getEncodingInfoWrapper = + [&](RankedTensorType type) -> FailureOr { + return converter->getEncodingInfo(type); + }; FailureOr convertedOp = - lowerContractionOpWithEncoding(rewriter, op, operands, *converter); + IREE::Codegen::lowerContractionOpWithEncoding( + rewriter, op, operands, converter->getTransposeNarrowN(), + getEncodingInfoWrapper); if (failed(convertedOp)) { return failure(); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h index c41d581bf765..c4ad11e3da73 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h @@ -10,6 +10,7 @@ #include #include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LLVM.h" namespace mlir::iree_compiler::IREE::Codegen { @@ -89,5 +90,8 @@ struct MaterializeEncodingInfo { std::optional swizzle; }; +using ResolveEncodingInfoFn = + std::function(RankedTensorType type)>; + } // namespace mlir::iree_compiler::IREE::Codegen #endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_IR_IREECODEGENTYPES_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/BUILD.bazel index 2cca489b7102..a155423dcf95 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/BUILD.bazel @@ -22,8 +22,12 @@ iree_compiler_cc_library( ], deps = [ "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", + "//compiler/src/iree/compiler/Dialect/Encoding/IR", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:TensorDialect", ], ) diff --git 
a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/CMakeLists.txt index cbfa9245cc88..bf4a0ed7f073 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/CMakeLists.txt @@ -19,8 +19,12 @@ iree_cc_library( "Utils.cpp" DEPS LLVMSupport + MLIRArithDialect MLIRIR + MLIRLinalgDialect + MLIRTensorDialect iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect + iree::compiler::Dialect::Encoding::IR PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp index 7b1e57480a20..4a12f7013417 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp @@ -6,6 +6,10 @@ #include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h" #include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -251,4 +255,213 @@ getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape) { return result; } +MaterializeEncodingInfo +getEncodingInfoForMatmul(Encoding::EncodingAttr encoding, TileMxNxK tileMxNxK) { + MaterializeEncodingInfo encodingInfo; + auto cDims = getEncodingContractionDims(encoding); + // The following expects M, N, K, and Batch sizes of at most 1 for now + assert(cDims->m.size() <= 1 && cDims->n.size() <= 1 && cDims->k.size() == 1 && + cDims->batch.size() <= 1 && + "Expected at most one M, N, K, and Batch dimension"); + std::optional batchDim = + cDims->batch.empty() ? 
std::nullopt + : encoding.mapDimToOperandIndex(cDims->batch[0]); + std::optional mDim = + cDims->m.empty() ? std::nullopt + : encoding.mapDimToOperandIndex(cDims->m[0]); + std::optional nDim = + cDims->n.empty() ? std::nullopt + : encoding.mapDimToOperandIndex(cDims->n[0]); + std::optional kDim = encoding.mapDimToOperandIndex(cDims->k[0]); + if (batchDim.has_value()) { + encodingInfo.outerDimsPerm.push_back(batchDim.value()); + } + if (mDim.has_value()) { + encodingInfo.outerDimsPerm.push_back(mDim.value()); + encodingInfo.innerDimsPos.push_back(mDim.value()); + encodingInfo.innerTileSizes.push_back(tileMxNxK.M); + } + if (nDim.has_value()) { + encodingInfo.outerDimsPerm.push_back(nDim.value()); + encodingInfo.innerDimsPos.push_back(nDim.value()); + encodingInfo.innerTileSizes.push_back(tileMxNxK.N); + } + if (kDim.has_value()) { + encodingInfo.outerDimsPerm.push_back(kDim.value()); + encodingInfo.innerDimsPos.push_back(kDim.value()); + encodingInfo.innerTileSizes.push_back(tileMxNxK.K); + } + return encodingInfo; +} + +static RankedTensorType dropEncoding(RankedTensorType type) { + return RankedTensorType::get(type.getShape(), type.getElementType()); +} + +static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op, + ValueRange convertedInputOperands, + ValueRange convertedOutputOperands) { + SmallVector operands; + operands.append(convertedInputOperands.begin(), convertedInputOperands.end()); + operands.append(convertedOutputOperands.begin(), + convertedOutputOperands.end()); + return mlir::clone(builder, op, + {dropEncoding(cast( + convertedOutputOperands[0].getType()))}, + operands); +} + +static RankedTensorType +getExpandedType(RankedTensorType type, bool isBatched, bool isTransposed, + SmallVectorImpl &ri) { + if (!isBatched) { + ri.assign({{0, 1}, {2, 3}}); + if (!isTransposed) { + return RankedTensorType::get( + {1, type.getDimSize(0), 1, type.getDimSize(1)}, + type.getElementType()); + } + return RankedTensorType::get({type.getDimSize(0), 
1, type.getDimSize(1), 1}, + type.getElementType()); + } + + ri.assign({{0}, {1, 2}, {3, 4}}); + if (!isTransposed) { + return RankedTensorType::get( + {type.getDimSize(0), 1, type.getDimSize(1), 1, type.getDimSize(2)}, + type.getElementType()); + } + return RankedTensorType::get( + {type.getDimSize(0), type.getDimSize(1), 1, type.getDimSize(2), 1}, + type.getElementType()); +} + +/// Given an input Value and a desired output element type, create and return +/// an element-wise linalg::GenericOp that extends the input Value to the +/// output element type. +static Value createElementWiseExtUIOp(OpBuilder &builder, Value input, + Location loc, Type outElemType) { + auto inputType = cast(input.getType()); + SmallVector maps( + 2, builder.getMultiDimIdentityMap(inputType.getRank())); + SmallVector iteratorTypes(inputType.getRank(), + utils::IteratorType::parallel); + auto castedType = inputType.clone(outElemType); + SmallVector inputMixedSizes = + tensor::getMixedSizes(builder, loc, input); + Value init = + builder.create(loc, inputMixedSizes, outElemType); + return builder + .create( + loc, castedType, input, init, maps, iteratorTypes, + [&](OpBuilder &b, Location nestedLoc, ValueRange args) { + Value castRes = + b.create(nestedLoc, outElemType, args[0]) + ->getResult(0); + b.create(nestedLoc, castRes); + }) + .getResult(0); +} + +/// If needed, expand the input Value, and return the resulting input with +/// the canonical mmt4d input shape. If the input element type is unsigned, +/// create a producer Linalg::GenericOp on the input that unsigned extends the +/// input to the output element type. This extension is required to keep the +/// unsignedness information on the input for ukernels. If `transpose` is true, +/// the `linalgOp`'s indexing maps are transposed. 
+static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp, + bool transpose, OpBuilder &builder, + SmallVectorImpl &ri, + ArrayRef elemTypes, int operandIdx) { + assert(linalgOp.getNumDpsInputs() == 2); + assert(linalgOp.getNumDpsInits() == 1); + auto cDims = linalg::inferContractionDims(linalgOp); + Location loc = linalgOp->getLoc(); + Value expandedValue = value; + // If vecmat with non-rhs operandIdx or matvec with non-lhs operandIdx, the + // operand is a vector and must be extended + if ((cDims->m.empty() && operandIdx != 1) || + (cDims->n.empty() && operandIdx != 0)) { + auto type = cast(value.getType()); + RankedTensorType newType = getExpandedType( + type, /*isBatched=*/!cDims->batch.empty(), + /*isTransposed=*/operandIdx == 2 && (transpose ^ cDims->n.empty()), ri); + expandedValue = + builder.create(loc, newType, value, ri); + } + if (elemTypes[operandIdx].isUnsignedInteger()) { + return createElementWiseExtUIOp(builder, expandedValue, loc, + elemTypes.back()); + } + return expandedValue; +} + +FailureOr +lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp, + ValueRange operands, bool transposeNarrowN, + ResolveEncodingInfoFn getEncodingInfo) { + if (!linalgOp.hasPureTensorSemantics()) { + return failure(); + } + + auto inputs = linalgOp.getDpsInputOperands(); + auto outputs = linalgOp.getDpsInits(); + + auto lhsType = cast(inputs[0]->get().getType()); + auto rhsType = cast(inputs[1]->get().getType()); + auto resultType = cast(outputs[0].getType()); + auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType); + auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType); + auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType); + if (!lhsEncoding || !rhsEncoding || !resultEncoding) { + return failure(); + } + + if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS || + rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS || + resultEncoding.getOperandIndex().getValue() 
!= + IREE::Encoding::MATMUL_RESULT) { + return failure(); + } + + FailureOr encodingInfo = + getEncodingInfo(cast(linalgOp->getResultTypes()[0])); + + Operation *result; + if (failed(encodingInfo)) { + result = dropEncodingAndCloneOp(builder, linalgOp, + operands.take_front(inputs.size()), + operands.drop_front(inputs.size())); + } else { + bool transpose = transposeNarrowN && isNarrowNResult(resultEncoding); + SmallVector elemTypes = lhsEncoding.getElementTypesArray(); + SmallVector ri; + Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, builder, + ri, elemTypes, /*operandIdx=*/0); + Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, builder, + ri, elemTypes, /*operandIdx=*/1); + Value newResult = getMmt4dOperand(operands[2], linalgOp, transpose, builder, + ri, elemTypes, /*operandIdx=*/2); + if (transpose) { + std::swap(newLhs, newRhs); + } + Type newResultType = newResult.getType(); + auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding); + if (cDims->batch.empty()) { + result = builder.create(linalgOp.getLoc(), newResultType, + ValueRange{newLhs, newRhs}, + ValueRange{newResult}); + } else { + result = builder.create( + linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs}, + ValueRange{newResult}); + } + if (!ri.empty()) { + result = builder.create( + linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri); + } + } + return result; +} + } // namespace mlir::iree_compiler::IREE::Codegen diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h index d19096ec41f7..b1997c1b91fd 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h @@ -8,6 +8,7 @@ #define IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_UTILS_H_ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h" +#include 
"iree/compiler/Dialect/Encoding/IR/EncodingOps.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" @@ -60,6 +61,24 @@ deserializeEncodingInfo(DictionaryAttr attr); SmallVector getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape); +struct TileMxNxK { + int64_t M = 1; + int64_t N = 1; + int64_t K = 1; +}; + +MaterializeEncodingInfo +getEncodingInfoForMatmul(Encoding::EncodingAttr encoding, TileMxNxK tileMxNxK); + +//===----------------------------------------------------------------------===// +// Operation Lowering Utilities. +//===----------------------------------------------------------------------===// + +FailureOr +lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp, + ValueRange operands, bool transposeNarrowN, + ResolveEncodingInfoFn getEncodingInfo); + } // namespace mlir::iree_compiler::IREE::Codegen #endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_UTILS_H_ diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp index 333145d0e8c3..cd023d0ec92b 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp @@ -130,6 +130,14 @@ MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) { return {}; } +bool isNarrowNResult(EncodingAttr encoding) { + if (encoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RESULT) { + return false; + } + + return IREE::Encoding::getMatmulNarrowDim(encoding).isN(); +} + EncodingAttr getEncodingAttr(RankedTensorType type) { return dyn_cast_or_null(type.getEncoding()); } diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h index 62a3efdcb98a..e4354c51a298 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h +++ 
b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h @@ -86,6 +86,10 @@ MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp, /// value. MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding); +// Returns true if `encoding` represents a narrow-N matmul RESULT, e.g. the +// result of a matvec. +bool isNarrowNResult(EncodingAttr encoding); + } // namespace mlir::iree_compiler::IREE::Encoding #endif // IREE_COMPILER_DIALECT_ENCODING_IR_ENCODINGTYPES_H_