Skip to content

Commit

Permalink
Use rewriter based methods for replacing in TileAndDecomposeAttention
Browse files Browse the repository at this point in the history
  • Loading branch information
Groverkss authored Nov 10, 2023
1 parent f66f28f commit db7311b
Showing 1 changed file with 3 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ tileAttention(IREE::LinalgExt::AttentionOp attnOp, RewriterBase &rewriter) {
insertOutputSlice(loopNest.results[0], output, sequenceTileLength,
headDimension, loc, rewriter);

- attnOp.getResults()[0].replaceAllUsesWith(loopNest.results[0]);
+ rewriter.replaceOp(attnOp, loopNest.results[0]);
ops.push_back(tiledAttentionOp);
return ops;
}
Expand Down Expand Up @@ -458,12 +458,7 @@ static void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
keySlice, valueSlice, querySlice, tiledResult, max, sum,
sequenceTileLength, headDimension, elementType, ops, loc, rewriter);

- tiledAttnOp.getResults()[0].replaceAllUsesWith(result);
- tiledAttnOp.getResults()[1].replaceAllUsesWith(newMax);
- tiledAttnOp.getResults()[2].replaceAllUsesWith(newSum);
-
- OpBuilder::InsertionGuard afterScfLoop(rewriter);
- rewriter.setInsertionPointAfter(tiledAttnOp->getParentOp());
+ rewriter.replaceOp(tiledAttnOp, ValueRange{result, newMax, newSum});
}

/// Utility function which tiles and then decomposes attention op via
Expand All @@ -477,7 +472,7 @@ tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
auto tiledAttnOp = cast<IREE::LinalgExt::AttentionOp>(ops[ops.size() - 1]);
ops.pop_back();
Operation *truncateToF16 = NULL;
- Type elementType = attnOp.getQueryType().getElementType();
+ Type elementType = tiledAttnOp.getQueryType().getElementType();
if (elementType.isF16()) {
truncateToF16 = ops[ops.size() - 1];
ops.pop_back();
Expand Down

0 comments on commit db7311b

Please sign in to comment.