diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index 23ba79c2f665..b4e3d2e027ed 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -90,28 +90,19 @@ static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap,
         auto srcTy = cast<FloatType>(args[0].getType());
         auto dstTy = cast<FloatType>(args[1].getType());
 
-        // We clamp to the min / max of the floating point representation
-        double mnDbl =
-            APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/true)
-                .convertToDouble();
         double mxDbl =
             APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false)
                 .convertToDouble();
 
         // Truncate to the `fp8` range so avoid nan values.
-        Value mn = builder.create<arith::ConstantOp>(
-            loc, builder.getFloatAttr(srcTy, mnDbl));
         Value mx = builder.create<arith::ConstantOp>(
             loc, builder.getFloatAttr(srcTy, mxDbl));
         Value gt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                            args[0], mx);
-        Value lt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
-                                           args[0], mn);
         Value sel0 = b.create<arith::SelectOp>(loc, gt, mx, args[0]);
-        Value sel1 = b.create<arith::SelectOp>(loc, lt, mn, sel0);
 
         // Convert scale to the same datatype as input.
-        Value trunc = convertScalarToDtype(b, loc, sel1, dstTy,
+        Value trunc = convertScalarToDtype(b, loc, sel0, dstTy,
                                            /*isUnsignedCast=*/false);
         b.create<linalg::YieldOp>(loc, trunc);
       });
@@ -302,6 +293,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
 
   // SMap = QMap @ KMap
   Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
   Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
+  Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
   s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
@@ -323,11 +315,6 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   AffineMap maxMap = getMaxMap();
   Value newMax = reduce(b, loc, sMap, maxMap, s, oldMax);
 
-  // P = exp2(S - newMax)
-  // PMap = SMap
-  AffineMap pMap = sMap;
-  Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);
-
   // norm = exp2(oldMax - newMax)
   // normMap = maxMap
   AffineMap normMap = getMaxMap();
@@ -337,6 +324,27 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   AffineMap sumMap = getSumMap();
   Value normSum = scaleValueInPlace(b, loc, sumMap, normMap, oldSum, norm);
 
+  // P = exp2(S - newMax)
+  // PMap = SMap
+  AffineMap pMap = sMap;
+  Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);
+
+  // If we need to truncate to fp8 post softmax we apply a scaling to use the
+  // full fp8 range. We can do this with an offset as post `exp2` this equates
+  // to multiplying by a static value. We are able to do this as `max` and `sum`
+  // are scaled by the same value so the end result is the same.
+  if (isa<FloatType>(qETy) && qETy.getIntOrFloatBitWidth() == 8) {
+    auto fpTy = cast<FloatType>(qETy);
+    double mx =
+        APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
+            .convertToDouble();
+    Value offset =
+        b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, mx));
+    AffineMap scaleMap = AffineMap::get(/*dimCount=*/pMap.getNumInputs(),
+                                        /*symbolCount=*/0, getContext());
+    p = scaleValueInPlace(b, loc, pMap, scaleMap, p, offset);
+  }
+
   // newSum = normSum + rowSum(P)
   Value newSum = reduce(b, loc, pMap, sumMap, p, normSum);
 
@@ -344,36 +352,8 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   AffineMap accMap = getOutputMap();
 
   // ---- Scale and truncate LHS to match RHS ----
-  Value pScale;
   auto pETy = getElementTypeOrSelf(p.getType());
   if (pETy != vETy && isa<FloatType>(vETy)) {
-    if (vETy.getIntOrFloatBitWidth() <= 8) {
-      SmallVector<OpFoldResult> mSizes(
-          llvm::map_range(maxMap.getResults(), [&](AffineExpr dimExpr) {
-            return sizes[cast<AffineDimExpr>(dimExpr).getPosition()];
-          }));
-
-      auto fpTy = cast<FloatType>(vETy);
-      double largestDbl =
-          APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
-              .convertToDouble();
-
-      // We normalize p from [0, max] to [0, fp8.max] to guarantee we
-      // use the full `fp8` range, then renormlize post Softmax@V matmul
-      // to correct.
-      pScale = b.create<arith::ConstantOp>(
-          loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / largestDbl));
-
-      // Compute the pre matmul scale to handle fp8 quantization:
-      Value pScaleInv = b.create<arith::ConstantOp>(
-          loc, b.getFloatAttr(elementType, largestDbl / clAttentionSoftmaxMax));
-
-      AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
-                                          /*symbolCount=*/0, getContext());
-      p = scaleValueInPlace(b, loc, pMap, scaleMap, p, pScaleInv);
-      norm = scaleValueInPlace(b, loc, normMap, scaleMap, norm, pScaleInv);
-    }
-
     Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
     p = truncateFloat(b, loc, pMap, pMap, p, convertP);
   }
@@ -384,14 +364,6 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   // newAcc = P @ V + newAcc
   newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value,
                          newAcc);
-
-  // Update for for the FP8 dynamic scale:
-  if (pScale) {
-    AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
-                                        /*symbolCount=*/0, getContext());
-    newAcc = scaleValueInPlace(b, loc, accMap, scaleMap, newAcc, pScale);
-  }
-
   return SmallVector<Value>{newAcc, newMax, newSum};
 }
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
index 3e323eda0c10..51fcc6dc2608 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
@@ -43,15 +43,15 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>,
 // CHECK: arith.subf
 // CHECK: math.exp2
 // CHECK: linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// CHECK: linalg.yield
 // norm = exp2(oldMax - newMax)
 // CHECK: linalg.generic
 // CHECK: arith.subf
 // CHECK: math.exp2
 // CHECK: linalg.yield
-// normSum = norm * oldSum
-// CHECK: linalg.generic
-// CHECK: arith.mulf
-// CHECK: linalg.yield
 // newSum = normSum + rowMax(P)
 // CHECK: linalg.generic
 // CHECK: arith.addf
@@ -107,11 +107,6 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
 // CHECK: linalg.generic
 // CHECK: arith.maximumf
 // CHECK: linalg.yield
-// P = exp2(S - newMax)
-// CHECK: linalg.generic
-// CHECK: arith.subf
-// CHECK: math.exp2
-// CHECK: linalg.yield
 // norm = exp2(oldMax - newMax)
 // CHECK: linalg.generic
 // CHECK: arith.subf
@@ -121,6 +116,11 @@
 // CHECK: linalg.generic
 // CHECK: arith.mulf
 // CHECK: linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
 // newSum = normSum + rowMax(P)
 // CHECK: linalg.generic
 // CHECK: arith.addf
@@ -128,8 +128,6 @@
 // clamp = clamp(norm)
 // CHECK: linalg.generic
 // CHECK: arith.cmpf ogt
-// CHECK: arith.cmpf olt
-// CHECK: arith.select
 // CHECK: arith.select
 // CHECK: arith.truncf
 // newAcc = norm * oldAcc
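
Note on the fp8 change (commentary, not part of the patch): the new code multiplies `P` by the largest finite fp8 value after `exp2`, and the patch comment argues this is result-preserving because the row sum is scaled by the same value. Below is a minimal standalone C++ sketch of the two identities involved; it assumes 240.0 as the largest finite f8E4M3FNUZ value, standing in for the `mx` computed via `APFloat::getLargest` in the patch.

```cpp
// Standalone sketch, NOT part of the patch. Checks with plain doubles:
//   (1) exp2(x + log2(c)) == exp2(x) * c  -- an additive offset before exp2
//       equals a multiplicative scale after it.
//   (2) (c * P) / (c * sum) == P / sum    -- a uniform scale on P cancels
//       once rows are normalized by the (equally scaled) row sum.
// c = 240.0 is assumed to be the largest finite f8E4M3FNUZ value.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<double> s = {-3.0, -1.5, 0.0}; // S - newMax, always <= 0
  const double c = 240.0;

  std::vector<double> p, scaledP;
  double sum = 0.0, scaledSum = 0.0;
  for (double x : s) {
    const double e = std::exp2(x);
    // Identity (1): offset pre-exp2 equals scale post-exp2.
    assert(std::fabs(std::exp2(x + std::log2(c)) - e * c) < 1e-9);
    p.push_back(e);
    scaledP.push_back(e * c);
    sum += e;
    scaledSum += e * c;
  }

  // Identity (2): the scale drops out of the normalized result.
  for (std::size_t i = 0; i < p.size(); ++i)
    assert(std::fabs(scaledP[i] / scaledSum - p[i] / sum) < 1e-12);

  std::printf("scaled and unscaled softmax rows agree\n");
  return 0;
}
```

Under this reading, dropping the old `pScale`/`pScaleInv` pair is sound: the single post-`exp2` scale needs no inverse correction after the `P @ V` matmul, since `newSum` accumulates the scaled `P` and the eventual normalization cancels the factor, which is what the patch comment claims.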