From 02920e6b7d5d20d1eb2e2e0564f28bcd5d29f665 Mon Sep 17 00:00:00 2001 From: TB Schardl Date: Mon, 15 Jul 2024 16:24:53 -0400 Subject: [PATCH] [LoweringUtils,OpenCilkABI] Simplify logic for passing stack frames to spawn-helper functions, to simply pass the stackframe as the last argument. --- llvm/lib/Transforms/Tapir/LoweringUtils.cpp | 16 ++++---- llvm/lib/Transforms/Tapir/OpenCilkABI.cpp | 44 +++++---------------- 2 files changed, 19 insertions(+), 41 deletions(-) diff --git a/llvm/lib/Transforms/Tapir/LoweringUtils.cpp b/llvm/lib/Transforms/Tapir/LoweringUtils.cpp index 031e14ef43a2..124b9f57ed3a 100644 --- a/llvm/lib/Transforms/Tapir/LoweringUtils.cpp +++ b/llvm/lib/Transforms/Tapir/LoweringUtils.cpp @@ -471,7 +471,7 @@ llvm::createTaskArgsStruct(const ValueSet &Inputs, Task *T, /// Organize the set \p Inputs of values in \p F into a set \p Fixed of values /// that can be used as inputs to a helper function. void llvm::fixupInputSet(Function &F, const ValueSet &Inputs, ValueSet &Fixed) { - // Scan for any sret parameters in TaskInputs and add them first. These + // Scan for any sret parameters in Inputs and add them first. These // parameters must appear first or second in the prototype of the Helper // function. Value *SRetInput = nullptr; @@ -497,18 +497,20 @@ void llvm::fixupInputSet(Function &F, const ValueSet &Inputs, ValueSet &Fixed) { for (Value *V : Inputs) if (V != SRetInput) InputsToSort.push_back(V); - LLVM_DEBUG({ - dbgs() << "After sorting:\n"; - for (Value *V : InputsToSort) - dbgs() << "\t" << *V << "\n"; - }); + const DataLayout &DL = F.getParent()->getDataLayout(); std::sort(InputsToSort.begin(), InputsToSort.end(), [&DL](const Value *A, const Value *B) { return DL.getTypeSizeInBits(A->getType()) > - DL.getTypeSizeInBits(B->getType()); + DL.getTypeSizeInBits(B->getType()); }); + LLVM_DEBUG({ + dbgs() << "inputs after fixup:\n"; + for (Value *V : InputsToSort) + dbgs() << "\t" << *V << "\n"; + }); + // Add the remaining inputs. for (Value *V : InputsToSort) if (!Fixed.count(V)) diff --git a/llvm/lib/Transforms/Tapir/OpenCilkABI.cpp b/llvm/lib/Transforms/Tapir/OpenCilkABI.cpp index b5815cbc5e8f..7a3e11c3ca31 100644 --- a/llvm/lib/Transforms/Tapir/OpenCilkABI.cpp +++ b/llvm/lib/Transforms/Tapir/OpenCilkABI.cpp @@ -27,6 +27,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstIterator.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Verifier.h" #include "llvm/IRReader/IRReader.h" @@ -346,45 +347,18 @@ void OpenCilkABI::prepareModule() { } } -static bool isSRetInput(const Value *V, const Function &F) { - if (!isa(V)) - return false; - - const auto *ArgIter = F.arg_begin(); - if (F.hasParamAttribute(0, Attribute::StructRet) && V == &*ArgIter) - return true; - ++ArgIter; - if (F.hasParamAttribute(1, Attribute::StructRet) && V == &*ArgIter) - return true; - - return false; -} - void OpenCilkABI::setupTaskOutlineArgs(Function &F, ValueSet &HelperArgs, - SmallVectorImpl &HelperInputs, - const ValueSet &TaskHelperArgs) { - PointerType *SFPtrTy = PointerType::getUnqual(F.getContext()); - - // First add the sret task input, if it exists. - ValueSet::iterator TaskInputIter = TaskHelperArgs.begin(); - if ((TaskInputIter != TaskHelperArgs.end()) && isSRetInput(*TaskInputIter, F)) { - HelperArgs.insert(*TaskInputIter); - HelperInputs.push_back(*TaskInputIter); - ++TaskInputIter; - } + SmallVectorImpl &HelperInputs, + const ValueSet &TaskHelperArgs) { + TapirTarget::setupTaskOutlineArgs(F, HelperArgs, HelperInputs, + TaskHelperArgs); // Add a pointer for the parent stack frame. This pointer will be replaced // later in the call to the helper. + PointerType *SFPtrTy = PointerType::getUnqual(F.getContext()); Value *ParentSFArg = ConstantPointerNull::get(SFPtrTy); HelperArgs.insert(ParentSFArg); HelperInputs.push_back(ParentSFArg); - - // Add the remaining task input arguments. - while (TaskInputIter != TaskHelperArgs.end()) { - Value *V = *TaskInputIter++; - HelperArgs.insert(V); - HelperInputs.push_back(V); - } } void OpenCilkABI::addHelperAttributes(Function &Helper) { @@ -401,7 +375,7 @@ void OpenCilkABI::addHelperAttributes(Function &Helper) { Helper.getMemoryEffects() | MemoryEffects(MemoryEffects::Location::Other, ModRefInfo::ModRef)); } - // Note that the address of the helper is unimportant. + // Mark that the address of the helper is unimportant. Helper.setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // The helper is internal to this module. We use internal linkage, rather @@ -506,7 +480,7 @@ Value* OpenCilkABI::GetOrCreateCilkStackFrame(Function &F) { } static unsigned getParentSFArgNum(Function &H) { - return isSRetInput(H.getArg(0), H) ? 1 : 0; + return H.arg_size() - 1; } // Helper function to add a debug location to an IRBuilder if it otherwise lacks @@ -894,6 +868,8 @@ void OpenCilkABI::processSubTaskCall(TaskOutlineInfo &TOI, DominatorTree &DT) { Value *SF = DetachCtxToStackFrame[&F]; assert(SF && "No frame found for spawning task"); + // Find the helper argument for the parent __cilkrts_stack_frame and update + // the corresponding operand in the call. const unsigned ParentSFArgNum = getParentSFArgNum(*TOI.Outline); assert(ReplCall->getOperand(ParentSFArgNum) == ConstantPointerNull::get(PointerType::getUnqual(C)));