From db7311bb45a1a12b37fc1e0ed5bc18001d7cb9fc Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Fri, 10 Nov 2023 14:09:51 +0530 Subject: [PATCH] Use rewriter based methods for replacing in TileAndDecomposeAttention (#15514) --- .../LinalgExt/Passes/TileAndDecomposeAttention.cpp | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp index 37fc2553fe3a..4054f93f4ec5 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp @@ -425,7 +425,7 @@ tileAttention(IREE::LinalgExt::AttentionOp attnOp, RewriterBase &rewriter) { insertOutputSlice(loopNest.results[0], output, sequenceTileLength, headDimension, loc, rewriter); - attnOp.getResults()[0].replaceAllUsesWith(loopNest.results[0]); + rewriter.replaceOp(attnOp, loopNest.results[0]); ops.push_back(tiledAttentionOp); return ops; } @@ -458,12 +458,7 @@ static void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp, keySlice, valueSlice, querySlice, tiledResult, max, sum, sequenceTileLength, headDimension, elementType, ops, loc, rewriter); - tiledAttnOp.getResults()[0].replaceAllUsesWith(result); - tiledAttnOp.getResults()[1].replaceAllUsesWith(newMax); - tiledAttnOp.getResults()[2].replaceAllUsesWith(newSum); - - OpBuilder::InsertionGuard afterScfLoop(rewriter); - rewriter.setInsertionPointAfter(tiledAttnOp->getParentOp()); + rewriter.replaceOp(tiledAttnOp, ValueRange{result, newMax, newSum}); } /// Utility function which tiles and then decomposes attention op via @@ -477,7 +472,7 @@ tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp, auto tiledAttnOp = cast(ops[ops.size() - 1]); ops.pop_back(); Operation *truncateToF16 = NULL; - Type elementType = attnOp.getQueryType().getElementType(); + Type elementType = tiledAttnOp.getQueryType().getElementType(); if (elementType.isF16()) { truncateToF16 = ops[ops.size() - 1]; ops.pop_back();