Optimize fp8 linalg_ext.attention by reworking Q@K scaling (iree-org…#18031)

Moving where the Q@K scaling is applied, and allowing the max scale to carry over into the softmax sum, achieves the same outcome for the softmax while removing the rescaling. This does not affect precision the way the original plan of applying a Q@K offset did.
rsuderman authored Jul 30, 2024
1 parent 6c45bef commit 2c53b4a
Showing 2 changed files with 32 additions and 62 deletions.
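
For intuition, here is a minimal standalone sketch of the equivalence the commit message relies on; it is not IREE code, and the logits, values, and fp8 max constant are illustrative. Multiplying P by a constant after exp2 inflates the row sum and the P@V accumulation by the same factor, so the normalized result is unchanged and the old rescale step can be dropped.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // One attention row: base-2 softmax logits S (already-scaled Q@K) and values V.
  const std::vector<double> S = {0.3, 1.7, -0.9, 2.2};
  const std::vector<double> V = {1.0, 2.0, 3.0, 4.0};
  const double kFp8Max = 240.0;  // largest finite f8E4M3FNUZ value (illustrative)

  double rowMax = S[0];
  for (double s : S) rowMax = std::fmax(rowMax, s);

  // Reference: base-2 softmax contracted with V.
  double sumRef = 0.0, accRef = 0.0;
  for (std::size_t i = 0; i < S.size(); ++i) {
    const double p = std::exp2(S[i] - rowMax);
    sumRef += p;
    accRef += p * V[i];
  }

  // fp8 variant: scale P by kFp8Max post-exp2 so it spans the fp8 range before
  // truncation; the same scale lands in the sum, so acc / sum is unchanged and
  // no post-matmul rescale is needed.
  double sumScaled = 0.0, accScaled = 0.0;
  for (std::size_t i = 0; i < S.size(); ++i) {
    const double p = std::exp2(S[i] - rowMax) * kFp8Max;
    sumScaled += p;
    accScaled += p * V[i];
  }

  std::printf("reference: %.17g\n", accRef / sumRef);
  std::printf("scaled:    %.17g\n", accScaled / sumScaled);
  return 0;
}

Both printed values agree up to floating-point rounding, which is exactly the property the reworked scaling depends on.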
@@ -90,28 +90,19 @@ static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap,
auto srcTy = cast<FloatType>(args[0].getType());
auto dstTy = cast<FloatType>(args[1].getType());

// We clamp to the min / max of the floating point representation
double mnDbl =
APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/true)
.convertToDouble();
double mxDbl =
APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false)
.convertToDouble();

// Truncate to the `fp8` range so avoid nan values.
Value mn = builder.create<arith::ConstantOp>(
loc, builder.getFloatAttr(srcTy, mnDbl));
Value mx = builder.create<arith::ConstantOp>(
loc, builder.getFloatAttr(srcTy, mxDbl));
Value gt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
args[0], mx);
Value lt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
args[0], mn);
Value sel0 = b.create<arith::SelectOp>(loc, gt, mx, args[0]);
Value sel1 = b.create<arith::SelectOp>(loc, lt, mn, sel0);

// Convert scale to the same datatype as input.
Value trunc = convertScalarToDtype(b, loc, sel1, dstTy,
Value trunc = convertScalarToDtype(b, loc, sel0, dstTy,
/*isUnsignedCast=*/false);
b.create<linalg::YieldOp>(loc, trunc);
});
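
As an aside, here is a minimal standalone sketch (not IREE code) of the clamp-before-truncate pattern built by the hunk above; kFp8Max is an illustrative stand-in for APFloat::getLargest on the destination type, and the real code emits arith.cmpf / arith.select and a narrowing conversion rather than calling std::clamp. The diff appears to keep only the upper clamp, presumably because P is non-negative once the new scaling is applied, but the sketch clamps both ends for generality.

#include <algorithm>
#include <cstdio>

float clampToFp8Range(float x) {
  const float kFp8Max = 240.0f;  // illustrative largest finite fp8 value
  // Saturate so the later narrowing conversion cannot produce inf/nan.
  return std::clamp(x, -kFp8Max, kFp8Max);
}

int main() {
  std::printf("%g\n", clampToFp8Range(1.0e6f));  // saturates to 240
  std::printf("%g\n", clampToFp8Range(-3.5f));   // in range, unchanged
  return 0;
}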
@@ -302,6 +293,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
// SMap = QMap @ KMap
Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));

Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);

@@ -323,11 +315,6 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
AffineMap maxMap = getMaxMap();
Value newMax = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, oldMax);

// P = exp2(S - newMax)
// PMap = SMap
AffineMap pMap = sMap;
Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);

// norm = exp2(oldMax - newMax)
// normMap = maxMap
AffineMap normMap = getMaxMap();
@@ -337,43 +324,36 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
AffineMap sumMap = getSumMap();
Value normSum = scaleValueInPlace(b, loc, sumMap, normMap, oldSum, norm);

// P = exp2(S - newMax)
// PMap = SMap
AffineMap pMap = sMap;
Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);

// If we need to truncate to fp8 post softmax we apply a scaling to use the
// full fp8 range. We can do this with a offset as post `exp2` this equates
// to multiplying by a static value. We are able to do this as `max` and `sum`
// are scaled by the same value so the end result is the same.
if (isa<FloatType>(qETy) && qETy.getIntOrFloatBitWidth() == 8) {
auto fpTy = cast<FloatType>(qETy);
double mx =
APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
.convertToDouble();
Value offset =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, mx));
AffineMap scaleMap = AffineMap::get(/*dimCount=*/pMap.getNumInputs(),
/*symbolCount=*/0, getContext());
p = scaleValueInPlace(b, loc, pMap, scaleMap, p, offset);
}

// newSum = normSum + rowSum(P)
Value newSum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, normSum);

// newAcc = norm * oldAcc
AffineMap accMap = getOutputMap();

// ---- Scale and truncate LHS to match RHS ----
Value pScale;
auto pETy = getElementTypeOrSelf(p.getType());
if (pETy != vETy && isa<FloatType>(vETy)) {
if (vETy.getIntOrFloatBitWidth() <= 8) {
SmallVector<OpFoldResult> mSizes(
llvm::map_range(maxMap.getResults(), [&](AffineExpr dimExpr) {
return sizes[cast<AffineDimExpr>(dimExpr).getPosition()];
}));

auto fpTy = cast<FloatType>(vETy);
double largestDbl =
APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
.convertToDouble();

// We normalize p from [0, max] to [0, fp8.max] to guarantee we
// use the full `fp8` range, then renormlize post Softmax@V matmul
// to correct.
pScale = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / largestDbl));

// Compute the pre matmul scale to handle fp8 quantization:
Value pScaleInv = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(elementType, largestDbl / clAttentionSoftmaxMax));

AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
/*symbolCount=*/0, getContext());
p = scaleValueInPlace(b, loc, pMap, scaleMap, p, pScaleInv);
norm = scaleValueInPlace(b, loc, normMap, scaleMap, norm, pScaleInv);
}

Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
p = truncateFloat(b, loc, pMap, pMap, p, convertP);
}
@@ -384,14 +364,6 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {

// newAcc = P @ V + newAcc
newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);

// Update for for the FP8 dynamic scale:
if (pScale) {
AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
/*symbolCount=*/0, getContext());
newAcc = scaleValueInPlace(b, loc, accMap, scaleMap, newAcc, pScale);
}

return SmallVector<Value>{newAcc, newMax, newSum};
}
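
To summarize the decomposition above, here is the per-block online update it emits, written with an explicit carried scale c (c is the fp8 max when P is truncated to fp8, and 1 otherwise); the final division by the running sum is assumed to happen after the last block, outside this op:

\begin{aligned}
m_{\text{new}}          &= \max\bigl(m_{\text{old}}, \operatorname{rowmax}(S)\bigr) \\
\text{norm}             &= 2^{\,m_{\text{old}} - m_{\text{new}}} \\
P                       &= c \cdot 2^{\,S - m_{\text{new}}} \\
\ell_{\text{new}}       &= \text{norm} \cdot \ell_{\text{old}} + \operatorname{rowsum}(P) \\
\text{acc}_{\text{new}} &= P \cdot V + \text{norm} \cdot \text{acc}_{\text{old}}
\end{aligned}

Because every block scales P by the same c, both the accumulator and the running sum carry the factor c, so the final acc / sum is unchanged; this is why the post-matmul pScale rescale of newAcc that this diff removes is no longer needed.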

@@ -43,15 +43,15 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>,
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// normSum = norm * oldSum
// CHECK: linalg.generic
// CHECK: arith.mulf
// CHECK: linalg.yield
// norm = exp2(oldMax - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// normSum = norm * oldSum
// CHECK: linalg.generic
// CHECK: arith.mulf
// CHECK: linalg.yield
// newSum = normSum + rowMax(P)
// CHECK: linalg.generic
// CHECK: arith.addf
@@ -107,11 +107,6 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
// CHECK: linalg.generic
// CHECK: arith.maximumf
// CHECK: linalg.yield
// P = exp2(S - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// norm = exp2(oldMax - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
@@ -121,15 +116,18 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
// CHECK: linalg.generic
// CHECK: arith.mulf
// CHECK: linalg.yield
// P = exp2(S - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// newSum = normSum + rowMax(P)
// CHECK: linalg.generic
// CHECK: arith.addf
// CHECK: linalg.yield
// clamp = clamp(norm)
// CHECK: linalg.generic
// CHECK: arith.cmpf ogt
// CHECK: arith.cmpf olt
// CHECK: arith.select
// CHECK: arith.select
// CHECK: arith.truncf
// newAcc = norm * oldAcc
