From d4852884a6ec77b0bbb7167d4cad889399c8180e Mon Sep 17 00:00:00 2001 From: TB Schardl Date: Fri, 19 Jan 2024 10:40:00 -0700 Subject: [PATCH] [TailRecursionElimination] Update the set of return blocks before which to insert a sync if TRE eliminates that block. --- .../Scalar/TailRecursionElimination.cpp | 15 ++++-- .../Transforms/Tapir/tre-remove-return.ll | 53 +++++++++++++++++++ 2 files changed, 65 insertions(+), 3 deletions(-) create mode 100644 llvm/test/Transforms/Tapir/tre-remove-return.ll diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index 09d0c272175f..6a53213c120d 100644 --- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -427,7 +427,7 @@ class TailRecursionEliminator { Instruction *AccumulatorRecursionInstr = nullptr; // Map from sync region to return blocks to sync for that sync region. - DenseMap> ReturnBlocksToSync; + DenseMap> ReturnBlocksToSync; TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI, AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, @@ -442,6 +442,8 @@ class TailRecursionEliminator { bool eliminateCall(CallInst *CI); + void RemoveReturnBlockToSync(BasicBlock *RetBlock); + void InsertSyncsIntoReturnBlocks(); void cleanupAndFinalize(); @@ -847,6 +849,11 @@ getReturnBlocksToSync(BasicBlock *Entry, SyncInst *Sync, } } +void TailRecursionEliminator::RemoveReturnBlockToSync(BasicBlock *RetBlock) { + for (auto &ReturnsToSync : ReturnBlocksToSync) + ReturnsToSync.second.erase(RetBlock); +} + static bool hasPrecedingSync(SyncInst *SI) { // TODO: Save the results from previous calls to hasPrecedingSync, in order to // speed up multiple calls to this routine for different sync instructions.
@@ -941,8 +948,10 @@ bool TailRecursionEliminator::processBlock(BasicBlock &BB) { // because the ret instruction in there is still using a value which // eliminateCall will attempt to remove. This block can only contain // instructions that can't have uses, therefore it is safe to remove. - if (pred_empty(Succ)) + if (pred_empty(Succ)) { + RemoveReturnBlockToSync(Succ); DTU.deleteBB(Succ); + } eliminateCall(CI); return true; @@ -1065,7 +1074,7 @@ bool TailRecursionEliminator::processBlock(BasicBlock &BB) { // We defer the restoration of syncs at relevant return blocks until after // all blocks are processed. This approach simplifies the logic for // eliminating multiple tail calls that are only separated from the return - // by a sync, since the CFG won't be perturbed unnecessarily. + // by a sync, since the CFG won't be changed unnecessarily. } else { // Restore the sync that was eliminated. BasicBlock *RetBlock = Ret->getParent(); diff --git a/llvm/test/Transforms/Tapir/tre-remove-return.ll b/llvm/test/Transforms/Tapir/tre-remove-return.ll new file mode 100644 index 000000000000..ea01fd4aa742 --- /dev/null +++ b/llvm/test/Transforms/Tapir/tre-remove-return.ll @@ -0,0 +1,53 @@ +; Check that tail-call elimination handles deletion of return blocks +; when it attempts to insert sync instructions before returns. 
+; +; RUN: opt < %s -passes="cgscc(devirt<4>(function(tailcallelim)))" -S | FileCheck %s +target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128" +target triple = "arm64-apple-macosx14.0.0" + +define void @_Z3dacPfPKfS1_xxxxxx() personality ptr null { +entry: + %syncreg = call token @llvm.syncregion.start() + %0 = call token @llvm.tapir.runtime.start() + br i1 false, label %if.end10, label %if.then7 + +if.then7: ; preds = %entry + call void @_Z3dacPfPKfS1_xxxxxx() + sync within %syncreg, label %cleanup + +if.end10: ; preds = %entry + call void @_Z3dacPfPKfS1_xxxxxx() + br label %cleanup + +cleanup: ; preds = %if.end10, %if.then7 + ret void +} + +; CHECK: define void @_Z3dacPfPKfS1_xxxxxx() + +; CHECK: entry: +; CHECK-NEXT: %syncreg = {{.*}}call token @llvm.syncregion.start() +; CHECK-NEXT: br label %[[TAILRECURSE:.+]] + +; CHECK: [[TAILRECURSE]]: +; CHECK: br i1 false, label %if.end10, label %if.then7 + +; CHECK: if.then7: +; CHECK-NOT: sync +; CHECK-NEXT: br label %[[TAILRECURSE]] + +; CHECK: if.end10: +; CHECK-NEXT: br label %[[TAILRECURSE]] + +; CHECK-NOT: ret void + +; Function Attrs: nounwind willreturn memory(argmem: readwrite) +declare token @llvm.tapir.runtime.start() #0 + +; Function Attrs: nounwind willreturn memory(argmem: readwrite) +declare token @llvm.syncregion.start() #0 + +; uselistorder directives +uselistorder ptr null, { 1, 2, 0 } + +attributes #0 = { nounwind willreturn memory(argmem: readwrite) }