Skip to content

Commit

Permalink
[DT][CPU] Implement CPUEncodingLayoutAttribute. (iree-org#19419)
Browse files Browse the repository at this point in the history
The commit implements the CPU encoding layout attribute. To reuse the
utilities for querying CPU features and arch, it adapts the methods to
take Attribute type.

On the materialization pass, it creates the CPUEncodingLayoutAttribute
when the target is llvmcpu.

Now the fallback solution returns identity encoding info container,
which behaves like dropping encodings. Thus, the lowering of contraction
needs to take it into account.

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
  • Loading branch information
hanhanW authored Dec 10, 2024
1 parent 5dba933 commit 415a0f6
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 299 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
Expand All @@ -30,6 +31,8 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "cpu-materialize-encoding"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace mlir::iree_compiler {

Expand All @@ -40,275 +43,10 @@ using IREE::Codegen::TileMxNxK;
#define GEN_PASS_DEF_CPUMATERIALIZEHOSTENCODINGPASS
#include "iree/compiler/Codegen/Common/CPU/Passes.h.inc"

// Enumerate tile sizes to choose from on riscv32.
// For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
// are handled by transposition in IREE::Codegen::chooseMatmulTile.
static SmallVector<TileMxNxK>
enumerateMatmulTileRiscv32(IREE::HAL::ExecutableTargetAttr target) {
if (hasUkernel(target)) {
return {
TileMxNxK{8, 8, 4}, // Some reasonable tile shape.
TileMxNxK{4, 8, 4}, // Truncation of the above.
TileMxNxK{2, 8, 4}, // Truncation of the above.
TileMxNxK{1, 8, 4}, // Truncation of the above.
};
}
// Fallback - no architecture-optimized tile size for this case.
return {};
}

// Enumerate tile sizes to choose from on arm64.
// For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
// are handled by transposition in IREE::Codegen::chooseMatmulTile.
static SmallVector<TileMxNxK>
enumerateMatmulTileArm64(TypeRange elementTypes,
IREE::HAL::ExecutableTargetAttr target) {
// Data-tiling for SVE is not implemented yet.
if (hasFeature(target, "+sve") || hasFeature(target, "+sve2")) {
return {};
}

assert(elementTypes.size() == 3);
Type lhs = elementTypes[0];
Type rhs = elementTypes[1];
Type out = elementTypes[2];

if (out.isF32() || out.isF16() || out.isBF16()) {
if (lhs.isBF16() && rhs.isBF16() && (out.isBF16() || out.isF32()) &&
hasFeature(target, "+bf16")) {
return {
TileMxNxK{8, 8, 4}, // Aim to use BFMMLA.
TileMxNxK{4, 8, 4}, // Truncation of the above.
TileMxNxK{2, 8, 4}, // Truncation of the above.
TileMxNxK{1, 8, 4}, // Truncation of the above.
};
}
if (isa<FloatType>(lhs) && isa<FloatType>(rhs)) {
// Note: 16-bit floating point types currently use the same tile size as
// f32. This makes sense when either (1) the accumulator is f32, or (2)
// the arithmetic will have to expand f16 to f32 in registers. We may
// reconsider when taking advantage of native f16/bf16 arithmetic when the
// accumulator itself is f16/bf16, as we could typically have a 2x wider
// tile in that case. However, on current CPUs, the existing tiles seem
// wide enough already to approach peak performance.
return {
TileMxNxK{8, 8, 1}, // Aim to use FMLA or FMLAL.
TileMxNxK{4, 8, 1}, // Truncation of the above.
TileMxNxK{2, 8, 1}, // Truncation of the above.
TileMxNxK{1, 8, 1}, // Truncation of the above.
};
}
}

if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8) &&
out.isSignlessInteger(32)) {
if (hasFeature(target, "+i8mm")) {
return {
TileMxNxK{8, 8, 8}, // Aim to use SMMLA.
TileMxNxK{4, 8, 8}, // Truncation of the above.
TileMxNxK{2, 8, 8}, // Truncation of the above.
TileMxNxK{1, 8, 8}, // Truncation of the above.
};
}
if (hasFeature(target, "+dotprod")) {
return {
TileMxNxK{8, 8, 4}, // Aim to use SDOT.
TileMxNxK{4, 8, 4}, // Truncation of the above.
TileMxNxK{2, 8, 4}, // Truncation of the above.
TileMxNxK{1, 8, 4}, // Truncation of the above.
};
}
}

if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(4) &&
out.isSignlessInteger(32)) {
if (hasFeature(target, "+i8mm")) {
return {
TileMxNxK{4, 8, 16},
TileMxNxK{2, 8, 16},
TileMxNxK{1, 8, 16},
};
}
if (hasFeature(target, "+dotprod")) {
return {
TileMxNxK{8, 8, 8},
TileMxNxK{4, 8, 8},
TileMxNxK{2, 8, 8},
TileMxNxK{1, 8, 8},
};
}
return {
TileMxNxK{4, 16, 2},
TileMxNxK{2, 16, 2},
TileMxNxK{1, 16, 2},
};
}

// Fallback - no architecture-optimized tile size for this case.
return {};
}

// Enumerate tile sizes to choose from on x86-64.
// For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
// are handled by transposition in IREE::Codegen::chooseMatmulTile.
static SmallVector<TileMxNxK>
enumerateMatmulTileX86_64(TypeRange elementTypes,
IREE::HAL::ExecutableTargetAttr target) {
assert(elementTypes.size() == 3);
Type lhs = elementTypes[0];
Type rhs = elementTypes[1];
Type out = elementTypes[2];

if (out.isF32() || out.isF16() || out.isBF16()) {
if (lhs.isBF16() && rhs.isBF16() && (out.isBF16() || out.isF32())) {
if (hasFeature(target, "+avx512bf16")) {
return {
TileMxNxK{16, 16, 2}, // Aim to use VDPBF16PS (zmm).
TileMxNxK{8, 16, 2}, // Truncation of the above.
TileMxNxK{4, 16, 2}, // Truncation of the above.
TileMxNxK{2, 16, 2}, // Truncation of the above.
TileMxNxK{1, 16, 2}, // Truncation of the above.
};
}
}
if (isa<FloatType>(lhs) && isa<FloatType>(rhs)) {
// Note: 16-bit floating point types currently use the same tile size as
// f32. This makes sense when either (1) the accumulator is f32, or (2)
// the arithmetic will have to expand f16 to f32 in registers. We may
// reconsider when taking advantage of native f16/bf16 arithmetic when the
// accumulator itself is f16/bf16.
if (hasFeature(target, "+avx512f")) {
return {
TileMxNxK{16, 16, 1}, // Aim to use VFMADD* (zmm).
TileMxNxK{8, 16, 1}, // Truncation of the above.
TileMxNxK{4, 16, 1}, // Truncation of the above.
TileMxNxK{2, 16, 1}, // Truncation of the above.
TileMxNxK{1, 16, 1}, // Truncation of the above.
};
}
if (hasFeature(target, "+avx")) {
// Note: for good performance, most +avx users will also want to add
// +fma, but that's a local instruction selection detail and the tile
// layout is unaffected, as there are enough registers even with the
// need for intermediate product registers when +fma is not used.
return {
TileMxNxK{8, 8, 1}, // Aim to use VFMADD* (ymm).
TileMxNxK{4, 8, 1}, // Truncation of the above.
TileMxNxK{2, 8, 1}, // Truncation of the above.
TileMxNxK{1, 8, 1}, // Truncation of the above.
};
}
// SSE fallback.
return {
TileMxNxK{8, 4, 1}, // Aim to use MULPS/ADDPS (xmm).
TileMxNxK{4, 4, 1}, // Truncation of the above.
TileMxNxK{2, 4, 1}, // Truncation of the above.
TileMxNxK{1, 4, 1}, // Truncation of the above.
};
}
}

if (out.isSignlessInteger(32) &&
((lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8)) ||
(lhs.isSignlessInteger(16) && rhs.isSignlessInteger(16)))) {
if (hasFeature(target, "+avx512vnni")) {
// This is the same tile size as with VPMADDWD as the only difference
// is that VPDPWSSD accumulates. VPDPBUSD would call for {16, 16, 4} but
// we can't easily use it because of its unsigned*signed semantics.
return {
TileMxNxK{16, 16, 2}, // Aim to use VPDPWSSD (zmm).
TileMxNxK{8, 16, 2}, // Truncation of the above.
TileMxNxK{4, 16, 2}, // Truncation of the above.
TileMxNxK{2, 16, 2}, // Truncation of the above.
TileMxNxK{1, 16, 2}, // Truncation of the above.
};
}
if (hasFeature(target, "+avx512bw")) {
return {
TileMxNxK{16, 16, 2}, // Aim to use VPMADDWD (zmm).
TileMxNxK{8, 16, 2}, // Truncation of the above.
TileMxNxK{4, 16, 2}, // Truncation of the above.
TileMxNxK{2, 16, 2}, // Truncation of the above.
TileMxNxK{1, 16, 2}, // Truncation of the above.
};
}
if (hasFeature(target, "+avx2")) {
return {
TileMxNxK{8, 8, 2}, // Aim to use VPMADDWD (ymm).
TileMxNxK{4, 8, 2}, // Truncation of the above.
TileMxNxK{2, 8, 2}, // Truncation of the above.
TileMxNxK{1, 8, 2}, // Truncation of the above.
};
}
// SSE fallback.
return {
TileMxNxK{8, 4, 2}, // Aim to use PMADDWD (xmm).
TileMxNxK{4, 4, 2}, // Truncation of the above.
TileMxNxK{2, 4, 2}, // Truncation of the above.
TileMxNxK{1, 4, 2}, // Truncation of the above.
};
}

if (out.isSignlessInteger(32) && lhs.isSignlessInteger(16) &&
rhs.isUnsignedInteger(4)) {
// Experimental s16u4s32 case. Focusing only on the vecmat case for now.
if (hasFeature(target, "+avx512vnni")) {
return {
TileMxNxK{1, 32, 8}, // Aim to use VPDPBUSD (zmm).
};
}
}

// Fallback - no architecture-optimized tile size for this case.
return {};
}

static SmallVector<TileMxNxK>
enumerateMatmulTileMxNxK(IREE::Encoding::EncodingAttr encoding,
IREE::HAL::ExecutableTargetAttr target) {
// We only know about contractions with {Batch, M, N, K} <= 1 at the moment.
auto cDims = getEncodingContractionDims(encoding);
if (failed(cDims) || cDims->batch.size() > 1 || cDims->m.size() > 1 ||
cDims->n.size() > 1 || cDims->k.size() > 1) {
return {};
}
// Enumerate available tile shapes for the given encoding and target.
SmallVector<Type> elementTypes = encoding.getElementTypesArray();
if (isAArch64(target)) {
return enumerateMatmulTileArm64(elementTypes, target);
}
if (isX86_64(target)) {
return enumerateMatmulTileX86_64(elementTypes, target);
}
if (isRISCV32(target)) {
return enumerateMatmulTileRiscv32(target);
}
return {};
}

static FailureOr<MaterializeEncodingInfo>
materializeEncodingForTarget(RankedTensorType tensorType,
IREE::HAL::ExecutableTargetAttr targetAttr) {
auto encoding =
dyn_cast_or_null<IREE::Encoding::EncodingAttr>(tensorType.getEncoding());
if (!encoding) {
return failure();
}

SmallVector<TileMxNxK> enumeratedTileMxNxK =
enumerateMatmulTileMxNxK(encoding, targetAttr);
if (enumeratedTileMxNxK.empty()) {
return failure();
}
auto narrowDim = IREE::Encoding::getMatmulNarrowDim(encoding);
// Choose a final matmul TileMxNxK from the above-enumarated tile shapes,
// taking narrow dimensions into account.
TileMxNxK chosenTileMxNxK = IREE::Codegen::chooseMatmulTile(
enumeratedTileMxNxK, narrowDim, encoding.getRoundDimsToArray());

// Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
// based on its operand index in the matmul.
return IREE::Codegen::getEncodingInfoForMatmul(encoding, chosenTileMxNxK);
return failure();
}

static FailureOr<MaterializeEncodingValueInfo>
Expand Down Expand Up @@ -342,11 +80,16 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp,
// 2. We use ukernels, and this allows writing 2x fewer narrow ukernels.
// 3. Heuristics for cache-friendly dispatch tiling can get complex on CPU,
// so it is nice that they have fewer narrow cases to consider.
DictionaryAttr targetConfig = targetAttr.getConfiguration();
IREE::Codegen::LayoutAttrInterface layoutAttr;
if (isVMVXBackend(targetAttr)) {
LDBG("Select VMVXEncodingLayoutAttr attribute as the layout attribute.");
layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
IREE::CPU::VMVXEncodingLayoutAttr::get(ctx, targetConfig));
} else {
LDBG("Select CPUEncodingLayoutAttr attribute as the layout attribute.");
layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
IREE::CPU::VMVXEncodingLayoutAttr::get(ctx,
targetAttr.getConfiguration()));
IREE::CPU::CPUEncodingLayoutAttr::get(ctx, targetConfig));
}
MaterializeEncodingTypeConverter typeConverter(
materializeEncodingForTarget, targetAttr, /*transposeNarrowN=*/true,
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter(
transposeNarrowN ? transposeIfNarrowNResult(type) : type;
FailureOr<MaterializeEncodingInfo> maybeEncodingInfo =
getEncodingInfo(tensorType);
if (failed(maybeEncodingInfo)) {
if (failed(maybeEncodingInfo) ||
IREE::Codegen::isIdentityLayout(maybeEncodingInfo.value())) {
return dropEncoding(type);
}
auto encodingInfo = *maybeEncodingInfo;
Expand Down
19 changes: 19 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,25 @@ include "mlir/IR/AttrTypeBase.td"
// iree_cpu.encoding_layout_attr
//===----------------------------------------------------------------------===//

def IREECPU_CPUEncodingLayoutAttr :
AttrDef<IREECPU_Dialect, "CPUEncodingLayout"> {
let mnemonic = "cpu_encoding_layout";
let summary = "The encoding layout attribute for CPU backends.";
let description = [{
This attribute can implement any layout interface methods for data-tiling,
e.g., Codegen::LayoutAttrInterface, etc. They are implemented through
external model mechanism See the implementation in
compiler/Codegen/ExternalInterfaces/*.
}];

let assemblyFormat = "`<` struct(params) `>`";

let parameters = (ins
OptionalParameter<"DictionaryAttr", "Executable target configuration. It is "
"expected to be used in a pass scope, but not the final IR output.">:$configuration
);
}

def IREECPU_VMVXEncodingLayoutAttr :
AttrDef<IREECPU_Dialect, "VMVXEncodingLayout"> {
let mnemonic = "vmvx_encoding_layout";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
getEncodingInfo(cast<RankedTensorType>(linalgOp->getResultTypes()[0]));

Operation *result;
if (failed(encodingInfo)) {
if (failed(encodingInfo) || isIdentityLayout(encodingInfo.value())) {
result = dropEncodingAndCloneOp(builder, linalgOp,
operands.take_front(inputs.size()),
operands.drop_front(inputs.size()));
Expand Down
Loading

0 comments on commit 415a0f6

Please sign in to comment.