Skip to content

Commit

Permalink
[DT][NFC] Internalize transposeNarrowN logic to LayoutAttrInterface I…
Browse files Browse the repository at this point in the history
…mpl (iree-org#19453)

Whether to transpose narrow-N into narrow-M is a backend implementation
detail, and we do not need to expose it to the type converter. The
encoding itself has enough information, like indexing maps, narrow
dimensions, etc., to infer the shapes and encoding info. Instead of
updating the RankedTensorType and the attached encoding in the type
converter, we can just bake the logic into the `getEncodingInfo` methods.
From the encoding, we know whether it is the narrow-N case, and we can
update the MaterializeEncodingInfo correspondingly. The type converter
can infer the transposed tensor type from it. Thus, we can simplify the
logic in the type conversion.

The documentation of `transposeNarrowN` is moved to
`[CPU|GPU]EncodingExternalModels.cpp` because all of the implementation
now lives in those files.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
  • Loading branch information
hanhanW authored Dec 16, 2024
1 parent dc29ee7 commit 67a05a4
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,6 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp,
IREE::HAL::ExecutableTargetAttr targetAttr) {
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet materializeEncodingPattern(ctx);
// On CPU, we use transposeNarrowN=true for a combination of reasons:
// 1. As linalg.matmul materializes into linalg.mmt4d, which has a transposed
// RHS and therefore LHS<->RHS symmetry, transposeNarrowN is easy to
// implement at that level.
// 2. We use ukernels, and this allows writing 2x fewer narrow ukernels.
// 3. Heuristics for cache-friendly dispatch tiling can get complex on CPU,
// so it is nice that they have fewer narrow cases to consider.
DictionaryAttr targetConfig = targetAttr.getConfiguration();
IREE::Codegen::LayoutAttrInterface layoutAttr;
if (isVMVXBackend(targetAttr)) {
Expand All @@ -85,8 +78,7 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp,
layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
IREE::CPU::CPUEncodingLayoutAttr::get(ctx, targetConfig));
}
MaterializeEncodingTypeConverter typeConverter(
/*transposeNarrowN=*/true, layoutAttr);
MaterializeEncodingTypeConverter typeConverter(layoutAttr);
MaterializeEncodingConversionTarget target(*ctx);
auto materializeEncodingValueFn = getMaterializeEncodingValueFn(targetAttr);
populateMaterializeEncodingIntoPackUnPackPatterns(
Expand Down
77 changes: 4 additions & 73 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,76 +20,9 @@ using IREE::Encoding::EncodingAttr;
using IREE::Encoding::getEncodingAttr;
using IREE::Encoding::getEncodingContractionDims;

// If tensorType has the encoding of a matmul RESULT with narrow N, returns
// the transposed type. Otherwise, just returns tensorType.
//
// The returned type differs from the input in three ways:
//   * Matrix case (exactly one M and one N dimension): the tensor extents
//     that the M and N iteration dimensions map to are swapped.
//   * The user indexing maps are rewritten so that M and N trade places. In
//     the vector case (no N dimension) the first two maps are swapped instead.
//   * If round_dims_to is present, its third-from-last and second-from-last
//     entries are swapped.
static RankedTensorType transposeIfNarrowNResult(RankedTensorType tensorType) {
  auto encoding =
      llvm::dyn_cast_or_null<EncodingAttr>(tensorType.getEncoding());
  if (!encoding || !isNarrowNResult(encoding)) {
    return tensorType;
  }

  SmallVector<AffineMap> maps;
  for (auto a : encoding.getUserIndexingMaps()) {
    maps.push_back(cast<AffineMapAttr>(a).getAffineMap());
  }
  auto cDims = linalg::inferContractionDims(maps);
  // `cDims` is dereferenced unconditionally below, so make the expectation
  // that the maps describe a contraction explicit instead of silently relying
  // on it.
  assert(succeeded(cDims) &&
         "expected user indexing maps of a narrow-N result to form a "
         "contraction");

  SmallVector<int64_t> newShape(tensorType.getShape());
  // Identity permutation over the iteration space; M and N get swapped in the
  // matrix case below.
  SmallVector<int64_t> permIndices(maps[0].getNumDims());
  std::iota(std::begin(permIndices), std::end(permIndices), 0);
  // Matrix case: there are both M and N dimensions. Transposing means swapping
  // them.
  if (cDims->m.size() == 1 && cDims->n.size() == 1) {
    int m = cDims->m[0];
    int n = cDims->n[0];
    std::swap(permIndices[m], permIndices[n]);
    std::optional<unsigned> mDim = encoding.mapDimToOperandIndex(m);
    std::optional<unsigned> nDim = encoding.mapDimToOperandIndex(n);
    // Only swap the extents when both iteration dimensions map to a dimension
    // of this operand.
    if (mDim.has_value() && nDim.has_value()) {
      std::swap(newShape[mDim.value()], newShape[nDim.value()]);
    }
  }
  // Vector case: there is no N dimension to swap the M dimension with. We
  // swap the maps themselves.
  if (cDims->n.empty()) {
    std::swap(maps[0], maps[1]);
  }

  SmallVector<int64_t> newRoundDimsTo(encoding.getRoundDimsToArray());
  assert(newRoundDimsTo.size() == 0 || newRoundDimsTo.size() >= 3);
  if (newRoundDimsTo.size() != 0) {
    std::swap(newRoundDimsTo[newRoundDimsTo.size() - 3],
              newRoundDimsTo[newRoundDimsTo.size() - 2]);
  }

  // Apply the (possibly identity) M<->N permutation to every indexing map.
  auto context = tensorType.getContext();
  AffineMap permutation = AffineMap::getPermutationMap(permIndices, context);
  for (auto &map : maps) {
    map = map.compose(permutation);
  }
  auto elemType = tensorType.getElementType();
  auto operandIndex = encoding.getOperandIndex().getInt();

  // TODO(#17718): Handle the broadcast map for transpose cases. It is on the
  // experimental path, so it is not clear what needs to be done here. For now
  // just use the original map for the new encoding.
  std::optional<AffineMap> newBcastMap;
  if (encoding.getBcastMap()) {
    newBcastMap = encoding.getBcastMap().getValue();
  }
  auto newEncoding = IREE::Encoding::EncodingAttr::get(
      context, operandIndex, encoding.getOpType().getValue(),
      encoding.getElementTypesArray(), maps, newBcastMap, newRoundDimsTo);
  return RankedTensorType::get(newShape, elemType, newEncoding);
}

MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter(
bool transposeNarrowN, IREE::Codegen::LayoutAttrInterface layoutAttr)
: transposeNarrowN(transposeNarrowN), layoutAttr(layoutAttr) {
IREE::Codegen::LayoutAttrInterface layoutAttr)
: layoutAttr(layoutAttr) {
addConversion([](IntegerType intType) { return intType; });
addConversion([](IndexType indexType) { return indexType; });
addConversion([](FloatType floatType) { return floatType; });
Expand All @@ -98,14 +31,12 @@ MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter(
// For a given tensor type with an encoding, return the materialized
// type to use for it. If no encoding is set, then return the tensor type
// itself.
RankedTensorType tensorType =
transposeNarrowN ? transposeIfNarrowNResult(type) : type;
MaterializeEncodingInfo encodingInfo = getEncodingInfo(tensorType);
MaterializeEncodingInfo encodingInfo = getEncodingInfo(type);
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return dropEncoding(type);
}
auto packedType = cast<RankedTensorType>(tensor::PackOp::inferPackedType(
tensorType, encodingInfo.innerTileSizes, encodingInfo.innerDimsPos,
type, encodingInfo.innerTileSizes, encodingInfo.innerDimsPos,
encodingInfo.outerDimsPerm));

// There is no swizzle, we are already done. Typically the case on CPU.
Expand Down
7 changes: 1 addition & 6 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ using MaterializeEncodingValueFn =
class MaterializeEncodingTypeConverter : public TypeConverter {
public:
MaterializeEncodingTypeConverter(
bool transposeNarrowN, IREE::Codegen::LayoutAttrInterface layoutAttr);
IREE::Codegen::LayoutAttrInterface layoutAttr);

const IREE::Codegen::LayoutAttrInterface &getLayoutAttr() const {
return layoutAttr;
Expand All @@ -47,12 +47,7 @@ class MaterializeEncodingTypeConverter : public TypeConverter {
return layoutAttr.getEncodingInfo(type);
}

bool getTransposeNarrowN() const { return transposeNarrowN; }

private:
bool transposeNarrowN = false;
// TODO(hanchung): Move the logic that takes `transposeNarrowN` into account
// to their own attribute implementation.
const IREE::Codegen::LayoutAttrInterface layoutAttr;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,22 +271,13 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp,
MLIRContext *ctx = funcOp.getContext();
{
RewritePatternSet patterns(ctx);
// On GPU, we use transposeNarrowN=false for a combination of reasons:
// 1. As linalg.matmul materializes into iree_gpu.multi_mma, which inherits
// its semantics from the wrapped intrinsic, we can't rely on any kind of
// LHS<->RHS symmetry.
// 2. We do not currently use ukernels, which would be one of the main areas
// to benefit from transposeNarrowN.
// 3. Heuristics for cache-friendly dispatch tiling are internal to the GPU
// runtime, so we don't need a simplification at that level either.
IREE::GPU::TargetAttr gpuTargetAttr;
if (targetAttr) {
gpuTargetAttr = getGPUTargetAttr(targetAttr);
} else {
gpuTargetAttr = getCLGPUTarget(ctx);
}
MaterializeEncodingTypeConverter typeConverter(
/*transposeNarrowN=*/false,
cast<IREE::Codegen::LayoutAttrInterface>(
IREE::GPU::GPUEncodingLayoutAttr::get(ctx, gpuTargetAttr)));
MaterializeEncodingConversionTarget target(*ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ struct MaterializeEncodingIntoNopPass final

RewritePatternSet materializeEncodingPattern(context);
MaterializeEncodingTypeConverter typeConverter(
/*transposeNarrowN=*/false,
IREE::Codegen::EncodingNopLayoutAttr::get(context));
MaterializeEncodingConversionTarget target(*context);
populateMaterializeEncodingIntoPackUnPackPatterns(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,22 +101,6 @@ getInnerTileSizesOfr(OpBuilder &rewriter, Location loc,
return result;
}

// Swaps the trailing two entries of innerDimsPos, innerTileSizes and
// outerDimsPerm of `info` in place. Does nothing when `info` carries fewer
// than two inner tile sizes (the vector cases).
static void transposeInPlace(MaterializeEncodingInfo &info) {
  const bool isVectorCase = info.innerTileSizes.size() < 2;
  if (isVectorCase) {
    return;
  }
  // Past this point all three arrays in `info` hold at least 2 entries.
  // outerDimsPerm may hold 3 when there is a batch dimension, but the last 2
  // entries of each array are always M and N, never batch.
  auto swapLastTwo = [](SmallVector<int64_t> &dims) {
    std::swap(dims[dims.size() - 2], dims.back());
  };
  swapLastTwo(info.innerDimsPos);
  swapLastTwo(info.innerTileSizes);
  swapLastTwo(info.outerDimsPerm);
}

//===---------------------------------------------------------------------===//
// Methods to convert `set_encoding` and `unset_encoding` operations
// to `pack` and `unpack` operations respectively.
Expand All @@ -139,9 +123,6 @@ FailureOr<Value> lowerSetEncodingOpToPackOp(
if (!encoding) {
return failure();
}
if (typeConverter.getTransposeNarrowN() && isNarrowNResult(encoding)) {
transposeInPlace(encodingInfo);
}

// Create `tensor.empty` operation for the result of the pack operation.
Location loc = encodingOp.getLoc();
Expand Down Expand Up @@ -180,10 +161,6 @@ FailureOr<Value> lowerUnsetEncodingToUnpackOp(
return packedValue;
}

auto encoding = IREE::Encoding::getEncodingAttr(sourceType);
if (typeConverter.getTransposeNarrowN() && isNarrowNResult(encoding)) {
transposeInPlace(encodingInfo);
}
// Create an `tensor.empty` for the result of the unpack operation.
Location loc = encodingOp.getLoc();
SmallVector<OpFoldResult> resultDims =
Expand Down Expand Up @@ -222,11 +199,6 @@ lowerOpWithEncoding(RewriterBase &rewriter, tensor::EmptyOp emptyOp,
.getOperation();
}

if (typeConverter.getTransposeNarrowN() &&
isNarrowNResult(IREE::Encoding::getEncodingAttr(emptyType))) {
transposeInPlace(encodingInfo);
}

FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
rewriter, loc, emptyType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
Expand Down Expand Up @@ -389,10 +361,6 @@ static FailureOr<SmallVector<OpFoldResult>> getPackedDimsForDispatchTensor(
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return failure();
}
if (typeConverter.getTransposeNarrowN() &&
isNarrowNResult(IREE::Encoding::getEncodingAttr(boundTensorType))) {
transposeInPlace(encodingInfo);
}

SmallVector<OpFoldResult> targetShape =
getMixedValues(boundTensorType.getShape(), dynamicDims, builder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,30 @@
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- CPUEncodingExternalModels.cpp --------------------------------------===//
//
// This file implements the IREE::Codegen::LayoutAttrInterface for CPU backends
// and the VMVX backend. In these backends, we transpose narrow-N into narrow-M
// for a combination of reasons:
//
// 1. As linalg.matmul materializes into linalg.mmt4d, which has a transposed
// RHS and therefore LHS<->RHS symmetry, transposeNarrowN is easy to
// implement at that level.
// 2. We use ukernels, and this allows writing 2x fewer narrow ukernels.
// 3. Heuristics for cache-friendly dispatch tiling can get complex on CPU,
// so it is nice that they have fewer narrow cases to consider.
//
// This transposition is made easier by (and was all along part of the idea in)
// the RHS-transposition in mmt4d (the t in mmt4d), as generally with matrix
// multiplication
//
// B * Transpose(A) == Transpose( A * Transpose(B) )
//
// so in mmt4d terms
//
// mmt4d(B, A) == Transpose(mmt4d(A, B))
//
//===---------------------------------------------------------------------===//

#include "iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.h"

Expand All @@ -12,6 +36,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

Expand All @@ -28,6 +53,22 @@ namespace {
// Utilities.
//===----------------------------------------------------------------------===//

// Swaps the trailing two entries of innerDimsPos, innerTileSizes and
// outerDimsPerm of `info` in place. Does nothing when `info` carries fewer
// than two inner tile sizes (the vector cases).
static void transposeInPlace(MaterializeEncodingInfo &info) {
  const bool isVectorCase = info.innerTileSizes.size() < 2;
  if (isVectorCase) {
    return;
  }
  // Past this point all three arrays in `info` hold at least 2 entries.
  // outerDimsPerm may hold 3 when there is a batch dimension, but the last 2
  // entries of each array are always M and N, never batch.
  auto swapLastTwo = [](SmallVector<int64_t> &dims) {
    std::swap(dims[dims.size() - 2], dims.back());
  };
  swapLastTwo(info.innerDimsPos);
  swapLastTwo(info.innerTileSizes);
  swapLastTwo(info.outerDimsPerm);
}

// Returns a ranked tensor type with the same shape and element type as `type`,
// built without passing an encoding attribute.
static RankedTensorType dropEncoding(RankedTensorType type) {
  auto shape = type.getShape();
  auto elementType = type.getElementType();
  return RankedTensorType::get(shape, elementType);
}
Expand Down Expand Up @@ -576,7 +617,11 @@ struct CPUDeviceEncodingLayoutAttrInterface
// taking narrow dimensions into account.
TileMxNxK chosenTileMxNxK = chooseMatmulTile(
enumeratedTileMxNxK, narrowDim, encoding.getRoundDimsToArray());
return getEncodingInfoForMatmul(encoding, chosenTileMxNxK);
info = getEncodingInfoForMatmul(encoding, chosenTileMxNxK);
if (Encoding::isNarrowNResult(encoding)) {
transposeInPlace(info);
}
return info;
}

Operation *lowerOp(Attribute attr, OpBuilder &b, Operation *op,
Expand Down Expand Up @@ -660,7 +705,11 @@ struct VMVXDeviceEncodingLayoutAttrInterface
// taking narrow dimensions into account.
TileMxNxK chosenTileMxNxK = chooseMatmulTile(
enumeratedTileMxNxK, narrowDim, encoding.getRoundDimsToArray());
return getEncodingInfoForMatmul(encoding, chosenTileMxNxK);
info = getEncodingInfoForMatmul(encoding, chosenTileMxNxK);
if (Encoding::isNarrowNResult(encoding)) {
transposeInPlace(info);
}
return info;
}

Operation *lowerOp(Attribute attr, OpBuilder &b, Operation *op,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,21 @@
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- GPUEncodingExternalModels.cpp --------------------------------------===//
//
// This file implements the IREE::Codegen::LayoutAttrInterface for GPU backends.
// Unlike the CPU backends, we do not transpose narrow-N to narrow-M for a
// combination of reasons:
//
// 1. As linalg.matmul materializes into iree_gpu.multi_mma, which inherits
// its semantics from the wrapped intrinsic, we can't rely on any kind of
// LHS<->RHS symmetry.
// 2. We do not currently use ukernels, which would be one of the main areas
// to benefit from transposeNarrowN.
// 3. Heuristics for cache-friendly dispatch tiling are internal to the GPU
// runtime, so we don't need a simplification at that level either.
//
//===---------------------------------------------------------------------===//

#include "iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.h"

Expand Down

0 comments on commit 67a05a4

Please sign in to comment.