diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/PrefetchSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/PrefetchSharedMemoryCopy.cpp
index da5f5430be5e..ed96a1329db3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/PrefetchSharedMemoryCopy.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/PrefetchSharedMemoryCopy.cpp
@@ -250,46 +250,15 @@ class LoopPrefetcher {
     return success();
   }
 
-  /// Clones |op| and call |callback| on the cloned op's operands as well as any
-  /// operands of nested ops that 1) aren't defined within the new op or 2) are
-  /// block arguments.
-  static Operation *
-  cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
-                         function_ref<void(OpOperand *)> callback) {
-    Operation *clone = rewriter.clone(*op);
-    for (OpOperand &operand : clone->getOpOperands())
-      callback(&operand);
-    return clone;
-  }
-
   /// Creates all read stage ops for a loop iteration with |rewriter| and maps
   /// the original loop induction variable to |iv| in |mapping|.
-  SmallVector<Value> emitRead(IRMapping &mapping, RewriterBase &rewriter,
-                              Value iv) {
+  void emitRead(IRMapping &mapping, RewriterBase &rewriter, Value iv) {
     // Map the original loop induction variable to |iv| for later op rewrites.
     mapping.map(forOp.getInductionVar(), iv);
 
-    SmallVector<Value> results;
     for (Operation *op : readStage) {
-      // Clone the current read stage op and updates all its operands to
-      // reference newly created ops.
-      Operation *newOp =
-          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
-            if (mapping.contains(newOperand->get())) {
-              newOperand->set(mapping.lookup(newOperand->get()));
-            }
-          });
-
-      if (isa<vector::TransferReadOp>(newOp)) {
-        llvm::append_range(results, newOp->getResults());
-      }
-
-      // Update read stage op results mapping.
-      for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
-        mapping.map(op->getResult(i), newOp->getResult(i));
-      }
+      rewriter.clone(*op, mapping);
     }
-    return results;
   }
 
   /// Creates all write stage ops for a loop iteration with |rewriter| and maps
@@ -299,22 +268,7 @@ class LoopPrefetcher {
     mapping.map(forOp.getInductionVar(), iv);
 
     for (Operation *op : writeStage) {
-      // Clone the current read stage op and updates all its operands to
-      // reference newly created ops.
-      Operation *newOp =
-          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
-            if (mapping.contains(newOperand->get())) {
-              newOperand->set(mapping.lookup(newOperand->get()));
-            }
-          });
-
-      // If a mapping for any results already exists, move on, otherwise,
-      // add a new mapping.
-      for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
-        if (!mapping.contains(op->getResult(i))) {
-          mapping.map(op->getResult(i), newOp->getResult(i));
-        }
-      }
+      rewriter.clone(*op, mapping);
     }
   }
 
@@ -341,18 +295,7 @@ class LoopPrefetcher {
         break;
      }
 
-      Operation *newOp =
-          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
-            if (mapping.contains(newOperand->get())) {
-              newOperand->set(mapping.lookup(newOperand->get()));
-            }
-          });
-      results = newOp->getResults();
-
-      // Map compute operations to new compute operations.
-      for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
-        mapping.map(op->getResult(i), newOp->getResult(i));
-      }
+      rewriter.clone(*op, mapping);
    }
    return results;
  }
@@ -361,12 +304,7 @@ class LoopPrefetcher {
   void updateYield(IRMapping &mapping, RewriterBase &rewriter) {
     for (Operation *op : computeStage) {
       if (auto yield = dyn_cast<scf::YieldOp>(op)) {
-        cloneAndUpdateOperands(rewriter, yield, [&](OpOperand *newOperand) {
-          if (mapping.contains(newOperand->get())) {
-            newOperand->set(mapping.lookup(newOperand->get()));
-          }
-        });
-
+        rewriter.clone(*op, mapping);
         break;
       }
     }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/prefetch_shared_memory.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/prefetch_shared_memory.mlir
index 68003274de0b..87cfeb5004c4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/prefetch_shared_memory.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/prefetch_shared_memory.mlir
@@ -81,3 +81,56 @@ func.func @prefetch_multi_scf_return(%arg0: memref<128xf32>) -> (vector<1xf32>,
   // CHECK: return %[[EPI_COMPUTE]], %[[EPI_COMPUTE2]]
   return %0#0, %0#1 : vector<1xf32>, vector<1xf32>
 }
+
+// CHECK-LABEL: @prefetch_add_with_if
+// CHECK-SAME: (%[[GLOBAL:.*]]: memref<128xf32>)
+func.func @prefetch_add_with_if(%arg0: memref<128xf32>) {
+  // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+  %cst = arith.constant dense<0.000000e+00> : vector<1xf32>
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[C127:.*]] = arith.constant 127 : index
+  %c128 = arith.constant 128 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  %true = arith.constant true
+  // CHECK-DAG: %[[SHARED:.*]] = memref.alloc() : memref<1xf32, #gpu.address_space<workgroup>>
+  %alloc = memref.alloc() : memref<1xf32, #gpu.address_space<workgroup>>
+  // CHECK-DAG: %[[PRO_READ:.*]] = vector.transfer_read %[[GLOBAL]]
+  // CHECK: vector.transfer_write %[[PRO_READ]], %[[SHARED]]
+  // CHECK: %[[OUT:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C127]] step %[[C1]] iter_args(%[[ARG:.*]] = %[[CST]])
+  %0 = scf.for %arg1 = %c0 to %c128 step %c1 iter_args(%arg2 = %cst) -> (vector<1xf32>) {
+    %dummy = memref.load %arg0[%arg1] : memref<128xf32>
+    %5 = arith.cmpf "oeq", %cst_0, %dummy : f32
+    // CHECK: %[[BRANCH:.*]] = scf.if %[[COND:.*]] -> (index)
+    // CHECK: } else {
+    // CHECK: }
+    %updated = scf.if %5 -> (index) {
+      %override = arith.constant 5 : index
+      %add = arith.addi %arg1, %override : index
+      scf.yield %add : index
+    } else {
+      scf.yield %arg1 : index
+    }
+    // CHECK: %[[KER_READ:.*]] = vector.transfer_read %[[GLOBAL]][%[[UPDATED:.*]]]
+    //%1 = vector.transfer_read %arg0[%arg1], %cst_0 : memref<128xf32>, vector<1xf32>
+    %1 = vector.transfer_read %arg0[%updated], %cst_0 : memref<128xf32>, vector<1xf32>
+    vector.transfer_write %1, %alloc[%c0] {in_bounds = [true]} : vector<1xf32>, memref<1xf32, #gpu.address_space<workgroup>>
+    // CHECK: gpu.barrier
+    // CHECK: %[[COMPUTE_READ:.*]] = vector.transfer_read %[[SHARED]][%[[C0]]]
+    %2 = vector.transfer_read %alloc[%c0], %cst_0 : memref<1xf32, #gpu.address_space<workgroup>>, vector<1xf32>
+    // CHECK: %[[COMPUTE:.*]] = arith.addf %[[COMPUTE_READ]], %[[ARG]]
+    %3 = arith.addf %2, %arg2 : vector<1xf32>
+    // CHECK: gpu.barrier
+    // CHECK: vector.transfer_write %[[KER_READ]], %[[SHARED]]
+    // CHECK: scf.yield %[[COMPUTE]]
+    scf.yield %3 : vector<1xf32>
+  }
+  // CHECK: gpu.barrier
+  // CHECK: %[[EPI_READ:.*]] = vector.transfer_read %[[SHARED]][%[[C0]]]
+  // CHECK: %[[EPI_COMPUTE:.*]] = arith.addf %[[EPI_READ]], %[[OUT]]
+  // CHECK: vector.transfer_write %[[EPI_COMPUTE]], %[[GLOBAL]][%[[C0]]]
+  vector.transfer_write %0, %arg0[%c0] {in_bounds = [true]} : vector<1xf32>, memref<128xf32>
+  return
+}