Skip to content

Commit

Permalink
[Codegen][llvmgpu] Refactor op cloning in prefetch shared memory pass (
Browse files Browse the repository at this point in the history
…iree-org#19196)

I refactored the prefetch shared memory pass by using `rewriter.clone()`
with an `IRMapping`. With this, the pass can now handle ops with a region
(like `scf.if`), which would otherwise produce invalid IR when there is an
`scf.if` in the k loop.

---------

Signed-off-by: jerryyin <zhuoryin@amd.com>
  • Loading branch information
jerryyin authored Dec 2, 2024
1 parent ecd87d8 commit 886f801
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -250,46 +250,15 @@ class LoopPrefetcher {
return success();
}

/// Clones |op| and call |callback| on the cloned op's operands as well as any
/// operands of nested ops that 1) aren't defined within the new op or 2) are
/// block arguments.
static Operation *
cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
function_ref<void(OpOperand *newOperand)> callback) {
Operation *clone = rewriter.clone(*op);
for (OpOperand &operand : clone->getOpOperands())
callback(&operand);
return clone;
}

/// Creates all read stage ops for a loop iteration with |rewriter| and maps
/// the original loop induction variable to |iv| in |mapping|.
SmallVector<Value> emitRead(IRMapping &mapping, RewriterBase &rewriter,
Value iv) {
void emitRead(IRMapping &mapping, RewriterBase &rewriter, Value iv) {
// Map the original loop induction variable to |iv| for later op rewrites.
mapping.map(forOp.getInductionVar(), iv);

SmallVector<Value> results;
for (Operation *op : readStage) {
// Clone the current read stage op and updates all its operands to
// reference newly created ops.
Operation *newOp =
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
if (mapping.contains(newOperand->get())) {
newOperand->set(mapping.lookup(newOperand->get()));
}
});

if (isa<vector::TransferReadOp>(newOp)) {
llvm::append_range(results, newOp->getResults());
}

// Update read stage op results mapping.
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
mapping.map(op->getResult(i), newOp->getResult(i));
}
rewriter.clone(*op, mapping);
}
return results;
}

/// Creates all write stage ops for a loop iteration with |rewriter| and maps
Expand All @@ -299,22 +268,7 @@ class LoopPrefetcher {
mapping.map(forOp.getInductionVar(), iv);

for (Operation *op : writeStage) {
// Clone the current read stage op and updates all its operands to
// reference newly created ops.
Operation *newOp =
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
if (mapping.contains(newOperand->get())) {
newOperand->set(mapping.lookup(newOperand->get()));
}
});

// If a mapping for any results already exists, move on, otherwise,
// add a new mapping.
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
if (!mapping.contains(op->getResult(i))) {
mapping.map(op->getResult(i), newOp->getResult(i));
}
}
rewriter.clone(*op, mapping);
}
}

Expand All @@ -341,18 +295,7 @@ class LoopPrefetcher {
break;
}

Operation *newOp =
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
if (mapping.contains(newOperand->get())) {
newOperand->set(mapping.lookup(newOperand->get()));
}
});
results = newOp->getResults();

// Map compute operations to new compute operations.
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
mapping.map(op->getResult(i), newOp->getResult(i));
}
rewriter.clone(*op, mapping);
}

return results;
Expand All @@ -361,12 +304,7 @@ class LoopPrefetcher {
void updateYield(IRMapping &mapping, RewriterBase &rewriter) {
for (Operation *op : computeStage) {
if (auto yield = dyn_cast<scf::YieldOp>(op)) {
cloneAndUpdateOperands(rewriter, yield, [&](OpOperand *newOperand) {
if (mapping.contains(newOperand->get())) {
newOperand->set(mapping.lookup(newOperand->get()));
}
});

rewriter.clone(*op, mapping);
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,56 @@ func.func @prefetch_multi_scf_return(%arg0: memref<128xf32>) -> (vector<1xf32>,
// CHECK: return %[[EPI_COMPUTE]], %[[EPI_COMPUTE2]]
return %0#0, %0#1 : vector<1xf32>, vector<1xf32>
}

// Regression test: the k loop contains an op with a region (`scf.if`) whose
// result feeds the read stage's index. Cloning via IRMapping must keep the
// cloned region's operands mapped so the pipelined IR stays valid.
// CHECK-LABEL: @prefetch_add_with_if
// CHECK-SAME: (%[[GLOBAL:.*]]: memref<128xf32>)
func.func @prefetch_add_with_if(%arg0: memref<128xf32>) {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
%cst = arith.constant dense<0.000000e+00> : vector<1xf32>
%cst_0 = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C127:.*]] = arith.constant 127 : index
%c128 = arith.constant 128 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
%true = arith.constant true
// CHECK-DAG: %[[SHARED:.*]] = memref.alloc() : memref<1xf32, #gpu.address_space<workgroup>>
%alloc = memref.alloc() : memref<1xf32, #gpu.address_space<workgroup>>
// CHECK-DAG: %[[PRO_READ:.*]] = vector.transfer_read %[[GLOBAL]]
// CHECK: vector.transfer_write %[[PRO_READ]], %[[SHARED]]
// CHECK: %[[OUT:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C127]] step %[[C1]] iter_args(%[[ARG:.*]] = %[[CST]])
%0 = scf.for %arg1 = %c0 to %c128 step %c1 iter_args(%arg2 = %cst) -> (vector<1xf32>) {
%dummy = memref.load %arg0[%arg1] : memref<128xf32>
%5 = arith.cmpf "oeq", %cst_0, %dummy : f32
// The scf.if below is the op-with-region this test exercises: its result
// (%updated) is consumed by the vector.transfer_read in the read stage.
// CHECK: %[[BRANCH:.*]] = scf.if %[[COND:.*]] -> (index)
// CHECK: } else {
// CHECK: }
%updated = scf.if %5 -> (index) {
%override = arith.constant 5 : index
%add = arith.addi %arg1, %override : index
scf.yield %add : index
} else {
scf.yield %arg1 : index
}
// CHECK: %[[KER_READ:.*]] = vector.transfer_read %[[GLOBAL]][%[[UPDATED:.*]]]
%1 = vector.transfer_read %arg0[%updated], %cst_0 : memref<128xf32>, vector<1xf32>
vector.transfer_write %1, %alloc[%c0] {in_bounds = [true]} : vector<1xf32>, memref<1xf32, #gpu.address_space<workgroup>>
// CHECK: gpu.barrier
// CHECK: %[[COMPUTE_READ:.*]] = vector.transfer_read %[[SHARED]][%[[C0]]]
%2 = vector.transfer_read %alloc[%c0], %cst_0 : memref<1xf32, #gpu.address_space<workgroup>>, vector<1xf32>
// CHECK: %[[COMPUTE:.*]] = arith.addf %[[COMPUTE_READ]], %[[ARG]]
%3 = arith.addf %2, %arg2 : vector<1xf32>
// CHECK: gpu.barrier
// CHECK: vector.transfer_write %[[KER_READ]], %[[SHARED]]
// CHECK: scf.yield %[[COMPUTE]]
scf.yield %3 : vector<1xf32>
}
// CHECK: gpu.barrier
// CHECK: %[[EPI_READ:.*]] = vector.transfer_read %[[SHARED]][%[[C0]]]
// CHECK: %[[EPI_COMPUTE:.*]] = arith.addf %[[EPI_READ]], %[[OUT]]
// CHECK: vector.transfer_write %[[EPI_COMPUTE]], %[[GLOBAL]][%[[C0]]]
vector.transfer_write %0, %arg0[%c0] {in_bounds = [true]} : vector<1xf32>, memref<128xf32>
return
}

0 comments on commit 886f801

Please sign in to comment.