From 76c3e61d563dbc22b74aa9d3d79c11c24a799697 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Thu, 26 Sep 2024 16:42:15 -0700 Subject: [PATCH] [CodeGen] Fix the argument replacements in scf.forall op lowering. (#18613) They should be scaled by tile sizes. Otherwise, we always access the same memory chunk. Signed-off-by: hanhanW --- .../Common/ReconcileTranslationInfo.cpp | 8 +++++ .../test/reconcile_translation_info.mlir | 32 ++++++++++++++++--- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp index d7a81fca0da1..f09d3a693b24 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp @@ -211,6 +211,14 @@ static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter, Block *parentBlock = forallOp->getBlock(); Block *remainingBlock = rewriter.splitBlock(parentBlock, Block::iterator(forallOp)); + for (auto [id, step] : llvm::zip_equal(procId, mixedStep)) { + rewriter.setInsertionPointToEnd(parentBlock); + AffineExpr s0, s1; + bindSymbols(rewriter.getContext(), s0, s1); + AffineExpr expr = s1 * s0; + id = affine::makeComposedFoldedAffineApply(rewriter, forallOp.getLoc(), + expr, {id, step}); + } auto argReplacements = getValueOrCreateConstantIndexOp(rewriter, forallOp.getLoc(), procId); Block *loopBody = forallOp.getBody(); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir index cc8aa9de94f7..fa56c6d55c94 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir @@ -191,6 +191,8 @@ hal.executable private @scf_forall_2D { } // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64) // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 32) +// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 64)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 32)> // CHECK: hal.executable.export public @scf_forall_2D layout // CHECK-NEXT: %[[ARG1:[a-zA-z0-9]+]]: index // CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index @@ -203,7 +205,9 @@ hal.executable private @scf_forall_2D { // CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1] // CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0] // CHECK-NOT: scf.forall -// CHECK: "use"(%[[WG_ID_Y]], %[[WG_ID_X]]) +// CHECK: %[[I:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]]] +// CHECK: %[[J:.+]] = affine.apply #[[MAP3]]()[%[[WG_ID_X]]] +// CHECK: "use"(%[[I]], %[[J]]) // ----- @@ -236,6 +240,7 @@ hal.executable private @scf_forall_2D_dynamic_tile_size { } } // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1) +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0)> // CHECK: hal.executable.export public @scf_forall_2D_dynamic_tile_size layout // CHECK-NEXT: %[[ARG1:[a-zA-z0-9]+]]: index // CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index @@ -246,10 +251,14 @@ hal.executable private @scf_forall_2D_dynamic_tile_size { // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK: hal.return %[[WG_X]], %[[WG_Y]], %[[C1]] // CHECK: func @scf_forall_2D_dynamic_tile_size() +// CHECK-DAG: %[[STEP_Y:.+]] = hal.interface.constant.load {{.+}} ordinal(2) +// CHECK-DAG: %[[STEP_X:.+]] = hal.interface.constant.load {{.+}} ordinal(3) // CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1] // CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0] // CHECK-NOT: scf.forall -// CHECK: "use"(%[[WG_ID_Y]], %[[WG_ID_X]]) +// CHECK: %[[I:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_Y]], %[[STEP_Y]]] +// CHECK: %[[J:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]], %[[STEP_X]]] +// CHECK: "use"(%[[I]], %[[J]]) // ----- @@ -305,6 +314,7 @@ hal.executable private @scf_forall_4D { } // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((-s0 + s1) ceildiv s2)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> (((-s0 + s1) ceildiv s2) * ((-s3 + s4) ceildiv s5))> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1] -> (s1 * s0)> // CHECK: hal.executable.export public @scf_forall_4D layout // CHECK-NEXT: %[[ARG1:[a-zA-z0-9]+]]: index // CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index @@ -329,6 +339,8 @@ hal.executable private @scf_forall_4D { // CHECK-DAG: %[[UB1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(5) // CHECK-DAG: %[[STEP0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(8) // CHECK-DAG: %[[STEP1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(9) +// CHECK-DAG: %[[STEP2:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(10) +// CHECK-DAG: %[[STEP3:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(11) // CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0] // CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1] // CHECK-DAG: %[[NITERS1:.+]] = affine.apply #[[MAP0]]()[%[[LB1]], %[[UB1]], %[[STEP1]]] @@ -336,7 +348,11 @@ hal.executable private @scf_forall_4D { // CHECK-DAG: %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2] // CHECK-NOT: scf.forall // CHECK: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[WG_ID_Z]] into (%[[NITERS0]], %[[NITERS1]]) -// CHECK: "use"(%[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[WG_ID_Y]], %[[WG_ID_X]]) +// CHECK: %[[I:.+]] = affine.apply #[[MAP2]]()[%[[DELINEARIZE]]#0, %[[STEP0]]] +// CHECK: %[[J:.+]] = affine.apply #[[MAP2]]()[%[[DELINEARIZE]]#1, %[[STEP1]]] +// CHECK: %[[K:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]], %[[STEP2]]] +// CHECK: %[[L:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_X]], %[[STEP3]]] +// CHECK: "use"(%[[I]], %[[J]], %[[K]], %[[L]]) // ----- @@ -364,6 +380,10 @@ hal.executable private @scf_forall_4D_static_interchange { } } } +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 5)> // CHECK: hal.executable.export public @scf_forall_4D_static_interchange layout // CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index // CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index @@ -378,7 +398,11 @@ hal.executable private @scf_forall_4D_static_interchange { // CHECK-DAG: %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2] // CHECK-NOT: scf.forall // CHECK: %[[DELINEARIZE:.+]]:3 = affine.delinearize_index %[[WG_ID_Z]] into (%[[C5]], %[[C8]], %[[C4]]) -// CHECK: "use"(%[[DELINEARIZE]]#2, %[[DELINEARIZE]]#0, %[[WG_ID_X]], %[[WG_ID_Y]], %[[DELINEARIZE]]#1) +// CHECK: %[[I:.+]] = affine.apply #[[MAP0]]()[%[[DELINEARIZE]]#0] +// CHECK: %[[J:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]]] +// CHECK: %[[K:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]]] +// CHECK: %[[L:.+]] = affine.apply #[[MAP3]]()[%[[DELINEARIZE]]#1] +// CHECK: "use"(%[[DELINEARIZE]]#2, %[[I]], %[[J]], %[[K]], %[[L]]) // -----