From 2c53b4a07ae56dd8e1a7c2772911da0bf86cfce5 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Tue, 30 Jul 2024 15:22:04 -0700
Subject: [PATCH] Optimize `fp8` `linalg_ext.attention` by reworking Q@K scaling (#18031)

Moving where the Q@K scaling is applied and allowing the max scale to
carry over into the softmax sum achieves the same outcome for softmax
while removing the rescaling. Unlike the original plan of applying a
Q@K offset, this does not affect precision.
---
 .../Transforms/AggregatedOpInterfaceImpl.cpp  | 74 ++++++------------
 .../test/decompose_online_attention.mlir      | 20 +++--
 2 files changed, 32 insertions(+), 62 deletions(-)

diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index 23ba79c2f665..b4e3d2e027ed 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -90,28 +90,19 @@ static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap,
         auto srcTy = cast<FloatType>(args[0].getType());
         auto dstTy = cast<FloatType>(args[1].getType());
 
-        // We clamp to the min / max of the floating point representation
-        double mnDbl =
-            APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/true)
-                .convertToDouble();
         double mxDbl =
             APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false)
                 .convertToDouble();
 
         // Truncate to the `fp8` range so avoid nan values.
-        Value mn = builder.create<arith::ConstantOp>(
-            loc, builder.getFloatAttr(srcTy, mnDbl));
         Value mx = builder.create<arith::ConstantOp>(
             loc, builder.getFloatAttr(srcTy, mxDbl));
         Value gt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                            args[0], mx);
-        Value lt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
-                                           args[0], mn);
         Value sel0 = b.create<arith::SelectOp>(loc, gt, mx, args[0]);
-        Value sel1 = b.create<arith::SelectOp>(loc, lt, mn, sel0);
 
         // Convert scale to the same datatype as input.
-        Value trunc = convertScalarToDtype(b, loc, sel1, dstTy,
+        Value trunc = convertScalarToDtype(b, loc, sel0, dstTy,
                                            /*isUnsignedCast=*/false);
         b.create<linalg::YieldOp>(loc, trunc);
       });
@@ -302,6 +293,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   // SMap = QMap @ KMap
   Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
   Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
+
   Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
   s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
@@ -323,11 +315,6 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   AffineMap maxMap = getMaxMap();
   Value newMax = reduce(b, loc, sMap, maxMap, s, oldMax);
 
-  // P = exp2(S - newMax)
-  // PMap = SMap
-  AffineMap pMap = sMap;
-  Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);
-
   // norm = exp2(oldMax - newMax)
   // normMap = maxMap
   AffineMap normMap = getMaxMap();
@@ -337,6 +324,27 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   AffineMap sumMap = getSumMap();
   Value normSum = scaleValueInPlace(b, loc, sumMap, normMap, oldSum, norm);
 
+  // P = exp2(S - newMax)
+  // PMap = SMap
+  AffineMap pMap = sMap;
+  Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);
+
+  // If we need to truncate to fp8 post softmax, we apply a scaling to use the
+  // full fp8 range. We can do this with an offset, as post `exp2` this
+  // equates to multiplying by a static value. We are able to do this as `max`
+  // and `sum` are scaled by the same value, so the end result is the same.
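+  //
+  // For example (one row, hypothetical values), with offset M = fp8.max:
+  //   P'      = exp2(S - newMax) * M
+  //   newSum' = norm * oldSum' + rowSum(P') = M * newSum
+  // The final normalization divides the accumulator by the sum, computing
+  // (M * acc) / (M * newSum), so the offset cancels with no explicit rescale.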
+  if (isa<FloatType>(qETy) && qETy.getIntOrFloatBitWidth() == 8) {
+    auto fpTy = cast<FloatType>(qETy);
+    double mx =
+        APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
+            .convertToDouble();
+    Value offset =
+        b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, mx));
+    AffineMap scaleMap = AffineMap::get(/*dimCount=*/pMap.getNumInputs(),
+                                        /*symbolCount=*/0, getContext());
+    p = scaleValueInPlace(b, loc, pMap, scaleMap, p, offset);
+  }
+
   // newSum = normSum + rowSum(P)
   Value newSum = reduce(b, loc, pMap, sumMap, p, normSum);
@@ -344,36 +352,8 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   AffineMap accMap = getOutputMap();
 
   // ---- Scale and truncate LHS to match RHS ----
-  Value pScale;
   auto pETy = getElementTypeOrSelf(p.getType());
   if (pETy != vETy && isa<FloatType>(vETy)) {
-    if (vETy.getIntOrFloatBitWidth() <= 8) {
-      SmallVector<OpFoldResult> mSizes(
-          llvm::map_range(maxMap.getResults(), [&](AffineExpr dimExpr) {
-            return sizes[cast<AffineDimExpr>(dimExpr).getPosition()];
-          }));
-
-      auto fpTy = cast<FloatType>(vETy);
-      double largestDbl =
-          APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
-              .convertToDouble();
-
-      // We normalize p from [0, max] to [0, fp8.max] to guarantee we
-      // use the full `fp8` range, then renormlize post Softmax@V matmul
-      // to correct.
-      pScale = b.create<arith::ConstantOp>(
-          loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / largestDbl));
-
-      // Compute the pre matmul scale to handle fp8 quantization:
-      Value pScaleInv = b.create<arith::ConstantOp>(
-          loc, b.getFloatAttr(elementType, largestDbl / clAttentionSoftmaxMax));
-
-      AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
-                                          /*symbolCount=*/0, getContext());
-      p = scaleValueInPlace(b, loc, pMap, scaleMap, p, pScaleInv);
-      norm = scaleValueInPlace(b, loc, normMap, scaleMap, norm, pScaleInv);
-    }
-
     Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
     p = truncateFloat(b, loc, pMap, pMap, p, convertP);
   }
@@ -384,14 +364,6 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
 
   // newAcc = P @ V + newAcc
   newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);
-
-  // Update for for the FP8 dynamic scale:
-  if (pScale) {
-    AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
-                                        /*symbolCount=*/0, getContext());
-    newAcc = scaleValueInPlace(b, loc, accMap, scaleMap, newAcc, pScale);
-  }
-
   return SmallVector<Value>{newAcc, newMax, newSum};
 }
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
index 3e323eda0c10..51fcc6dc2608 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
@@ -43,15 +43,15 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>,
 // CHECK: arith.subf
 // CHECK: math.exp2
 // CHECK: linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// CHECK: linalg.yield
 // norm = exp2(oldMax - newMax)
 // CHECK: linalg.generic
 // CHECK: arith.subf
 // CHECK: math.exp2
 // CHECK: linalg.yield
-// normSum = norm * oldSum
-// CHECK: linalg.generic
-// CHECK: arith.mulf
-// CHECK: linalg.yield
 // newSum = normSum + rowMax(P)
 // CHECK: linalg.generic
 // CHECK: arith.addf
@@ -107,11 +107,6 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
 // CHECK: linalg.generic
 // CHECK: arith.maximumf
 // CHECK: linalg.yield
-// P = exp2(S - newMax)
-// CHECK: linalg.generic
-// CHECK: arith.subf
-// CHECK: math.exp2
-// CHECK: linalg.yield
 // norm = exp2(oldMax - newMax)
 // CHECK: linalg.generic
 // CHECK: arith.subf
@@ -121,6 +116,11 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
 // CHECK: linalg.generic
 // CHECK: arith.mulf
 // CHECK: linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
 // newSum = normSum + rowMax(P)
 // CHECK: linalg.generic
 // CHECK: arith.addf
@@ -128,8 +128,6 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
 // clamp = clamp(norm)
 // CHECK: linalg.generic
 // CHECK: arith.cmpf ogt
-// CHECK: arith.cmpf olt
-// CHECK: arith.select
 // CHECK: arith.select
 // CHECK: arith.truncf
 // newAcc = norm * oldAcc
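
The equivalence the commit relies on can be checked numerically. Below is a
minimal standalone C++ sketch, not part of the patch: the row data is made up,
and 240.0 (the largest finite f8E4M3FNUZ value) stands in for the offset.
Because the offset multiplies P after the `exp2`, it scales the softmax sum
and the P @ V accumulator identically, so the final `acc / sum` normalization
cancels it.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // One attention row of scores S and values V (hypothetical data).
  const std::vector<double> s = {0.1, 0.7, -0.3, 0.5};
  const std::vector<double> v = {1.0, 2.0, 3.0, 4.0};
  const double fp8Max = 240.0; // largest finite f8E4M3FNUZ value

  double rowMax = s[0];
  for (double x : s)
    rowMax = std::fmax(rowMax, x);

  // Reference: exp2-based softmax followed by the V accumulation.
  double sum = 0.0, acc = 0.0;
  for (std::size_t i = 0; i < s.size(); ++i) {
    double p = std::exp2(s[i] - rowMax);
    sum += p;
    acc += p * v[i];
  }

  // Offset variant: P is scaled up by fp8Max so it occupies the full fp8
  // range, and the same factor carries into the sum and the accumulator.
  double sumScaled = 0.0, accScaled = 0.0;
  for (std::size_t i = 0; i < s.size(); ++i) {
    double p = std::exp2(s[i] - rowMax) * fp8Max;
    sumScaled += p;
    accScaled += p * v[i];
  }

  // Both lines print the same value: the offset cancels on normalization.
  std::printf("reference: %f\n", acc / sum);
  std::printf("offset   : %f\n", accScaled / sumScaled);
  return 0;
}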