Skip to content

Commit

Permalink
Wire through coarse-fences invocation model.
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaraccident committed Aug 29, 2024
1 parent b8b86bc commit 5938ff9
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 23 deletions.
122 changes: 102 additions & 20 deletions libshortfin/src/shortfin/local/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,31 @@ void GetVmModuleExports(iree_vm_module_t *vm_module,
// ProgramFunction
// -------------------------------------------------------------------------- //

// Constructs a reference to `vm_function` as resolved against `vm_context`.
//
// If `invocation_model` is supplied it is used verbatim; otherwise the model
// is derived from the function itself via GetInvocationModelFromFunction().
// The fallback is only evaluated when no explicit model was given.
ProgramFunction::ProgramFunction(
    iree::vm_context_ptr vm_context, iree_vm_function_t vm_function,
    std::optional<ProgramInvocationModel> invocation_model)
    : vm_context_(std::move(vm_context)),
      vm_function_(vm_function),
      invocation_model_(invocation_model.has_value()
                            ? *invocation_model
                            : GetInvocationModelFromFunction(vm_function)) {}

// Determines the invocation model of `f` from its "iree.abi.model" attribute.
//
// Returns:
//   COARSE_FENCES if the attribute equals "coarse-fences".
//   NONE          if the attribute lookup yields an empty string.
//   UNKNOWN       for any other value (a warning naming the function and the
//                 unrecognized value is logged).
ProgramInvocationModel ProgramFunction::GetInvocationModelFromFunction(
    iree_vm_function_t &f) {
  iree_string_view_t model_attr =
      iree_vm_function_lookup_attr_by_name(&f, IREE_SV("iree.abi.model"));

  if (iree_string_view_equal(model_attr, IREE_SV("coarse-fences"))) {
    return ProgramInvocationModel::COARSE_FENCES;
  }
  if (model_attr.size == 0) {
    return ProgramInvocationModel::NONE;
  }

  // Annotated, but with a value this code does not recognize.
  logging::warn("Unknown function invocation model '{}': '{}'",
                to_string_view(iree_vm_function_name(&f)),
                to_string_view(model_attr));
  return ProgramInvocationModel::UNKNOWN;
}

std::string_view ProgramFunction::name() const {
if (!*this) return {};
return to_string_view(iree_vm_function_name(&vm_function_));
Expand All @@ -45,7 +70,8 @@ std::string_view ProgramFunction::calling_convention() const {

ProgramInvocation::Ptr ProgramFunction::CreateInvocation(
std::shared_ptr<Scope> scope) {
return ProgramInvocation::New(std::move(scope), vm_context_, vm_function_);
return ProgramInvocation::New(std::move(scope), vm_context_, vm_function_,
invocation_model_);
}

std::string ProgramFunction::to_s() const {
Expand Down Expand Up @@ -108,7 +134,29 @@ std::vector<std::string> ProgramModule::exports() const {
// -------------------------------------------------------------------------- //

std::optional<ProgramFunction> Program::LookupFunction(std::string_view name) {
// By convention, we currently name our coarse-fences function variants
// as ending in "$async". These are the ones we want but it is inconvenient.
// Therefore, we probe for that first.
// TODO: We should add attributes to the function that better describe this
// relationship.
iree_vm_function_t f;
if (!name.ends_with("$async")) {
std::string async_name(name);
async_name.append("$async");
iree_status_t status = iree_vm_context_resolve_function(
vm_context_, to_iree_string_view(async_name), &f);
if (iree_status_is_ok(status)) {
// TODO: Torch import is not setting the coarse-fences abi.model on
// its functions. Get it from there instead of just assuming based on
// name.
return ProgramFunction(vm_context_, f,
ProgramInvocationModel::COARSE_FENCES);
} else if (!iree_status_is_not_found(status)) {
SHORTFIN_THROW_IF_ERROR(status);
}
}

// Resolve the exactly named function.
iree_status_t status = iree_vm_context_resolve_function(
vm_context_, to_iree_string_view(name), &f);
if (iree_status_is_not_found(status)) return {};
Expand Down Expand Up @@ -173,9 +221,9 @@ ProgramInvocation::~ProgramInvocation() {
}
}

ProgramInvocation::Ptr ProgramInvocation::New(std::shared_ptr<Scope> scope,
iree::vm_context_ptr vm_context,
iree_vm_function_t &vm_function) {
ProgramInvocation::Ptr ProgramInvocation::New(
std::shared_ptr<Scope> scope, iree::vm_context_ptr vm_context,
iree_vm_function_t &vm_function, ProgramInvocationModel invocation_model) {
auto sig = iree_vm_function_signature(&vm_function);
iree_host_size_t arg_count;
iree_host_size_t result_count;
Expand Down Expand Up @@ -217,6 +265,7 @@ ProgramInvocation::Ptr ProgramInvocation::New(std::shared_ptr<Scope> scope,
inst->state.params.context =
vm_context.release(); // Ref transfer to ProgramInvocation.
inst->state.params.function = vm_function;
inst->state.params.invocation_model = invocation_model;
inst->state.params.arg_list = arg_list;
inst->result_list_ = result_list;
return inst;
Expand All @@ -243,11 +292,30 @@ void ProgramInvocation::AddArg(iree_vm_ref_t *ref) {
ProgramInvocation::Future ProgramInvocation::Invoke(
ProgramInvocation::Ptr invocation) {
invocation->CheckNotScheduled();

Worker &worker = invocation->scope_->worker();
// We're about to overwrite the instance level storage for params, so move
// it to the stack and access there.
Params params = invocation->state.params;

// Handle post-processing invocation model setup.
if (params.invocation_model == ProgramInvocationModel::COARSE_FENCES) {
iree_vm_ref_t wait_ref =
iree_hal_fence_retain_ref(invocation->wait_fence());
SHORTFIN_THROW_IF_ERROR(
iree_vm_list_push_ref_move(params.arg_list, &wait_ref));
iree_vm_ref_t signal_ref =
iree_hal_fence_retain_ref(invocation->signal_fence());
SHORTFIN_THROW_IF_ERROR(
iree_vm_list_push_ref_move(params.arg_list, &signal_ref));
} else {
logging::warn(
"Invoking function '{}' with unknown or synchronous invocation model "
"is "
"not fully supported",
to_string_view(iree_vm_function_name(&params.function)));
}

auto schedule = [](ProgramInvocation *raw_invocation, Worker *worker,
iree_vm_context_t *owned_context,
iree_vm_function_t function, iree_vm_list_t *arg_list,
Expand All @@ -256,11 +324,11 @@ ProgramInvocation::Future ProgramInvocation::Invoke(
[](void *user_data, iree_loop_t loop, iree_status_t status,
iree_vm_list_t *outputs) noexcept -> iree_status_t {
// Async invocation helpfully gives us a retained reference to the
// outputs, but we already have one statically on the ProgramInvocation.
// So release this one, which makes it safe to deallocate the
// ProgramInvocation at any point after this (there must be no live
// references to inputs/outputs when the ProgramInvocation::Ptr deleter is
// invoked).
// outputs, but we already have one statically on the
// ProgramInvocation. So release this one, which makes it safe to
// deallocate the ProgramInvocation at any point after this (there
// must be no live references to inputs/outputs when the
// ProgramInvocation::Ptr deleter is invoked).
iree::vm_list_ptr::steal_reference(outputs);

// Repatriate the ProgramInvocation.
Expand All @@ -273,17 +341,17 @@ ProgramInvocation::Future ProgramInvocation::Invoke(
raw_invocation->future_->set_failure(status);
}

// Must release the future from the invocation to break the circular
// reference (we are setting the invocation as the result of the
// future).
// Must release the future from the invocation to break the
// circular reference (we are setting the invocation as the result
// of the future).
raw_invocation->future_.reset();

return iree_ok_status();
};

ProgramInvocation::Ptr invocation(raw_invocation);
// TODO: Need to fork based on whether on the current worker. If not, then
// do cross thread scheduling.
// TODO: Need to fork based on whether on the current worker. If
// not, then do cross thread scheduling.
iree_status_t status = iree_vm_async_invoke(
worker->loop(), &invocation->state.async_invoke_state, owned_context,
function,
Expand All @@ -294,14 +362,14 @@ ProgramInvocation::Future ProgramInvocation::Invoke(
+complete_callback,
/*user_data=*/invocation.get());

// Regardless of status, the context reference we were holding is no longer
// needed. Drop it on the floor.
// Regardless of status, the context reference we were holding is no
// longer needed. Drop it on the floor.
iree::vm_context_ptr::steal_reference(owned_context);

// On success, then the complete callback takes ownership of the invocation,
// so we release it here and return. We have to treat the invocation as
// possibly deallocated at this point, since the async invocation may have
// finished already.
// On success, then the complete callback takes ownership of the
// invocation, so we release it here and return. We have to treat
// the invocation as possibly deallocated at this point, since the
// async invocation may have finished already.
if (iree_status_is_ok(status)) {
invocation.release();
} else if (failure_future) {
Expand Down Expand Up @@ -346,4 +414,18 @@ iree::vm_opaque_ref ProgramInvocation::result_ref(iree_host_size_t i) {
return out_value;
}

// Returns the wait fence for this invocation. The fence is created through
// the scope's scheduler the first time it is requested and reused afterwards.
// The returned pointer is borrowed from `wait_fence_`.
iree_hal_fence_t *ProgramInvocation::wait_fence() {
  if (!wait_fence_) wait_fence_ = scope_->scheduler().NewFence();
  return wait_fence_.get();
}

// Returns the signal fence for this invocation. The fence is created through
// the scope's scheduler the first time it is requested and reused afterwards.
// The returned pointer is borrowed from `signal_fence_`.
iree_hal_fence_t *ProgramInvocation::signal_fence() {
  if (!signal_fence_) signal_fence_ = scope_->scheduler().NewFence();
  return signal_fence_.get();
}

} // namespace shortfin::local
31 changes: 28 additions & 3 deletions libshortfin/src/shortfin/local/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ namespace shortfin::local {
class SHORTFIN_API Scope;
class SHORTFIN_API System;

// Describes how a program function expects to be invoked. ProgramFunction
// either receives this explicitly or derives it from the function's
// "iree.abi.model" attribute.
enum class ProgramInvocationModel {
// Uses the coarse-fences invocation model. In this model, the last two
// arguments are a wait and signal fence, which are used for function-level
// scheduling.
COARSE_FENCES,
// The function was not annotated with an invocation model (the
// "iree.abi.model" attribute lookup yielded an empty string).
NONE,
// The function is annotated with an invocation model, but its value is not
// one that is recognized (i.e. it is non-empty and not "coarse-fences").
UNKNOWN,
};

// State related to making an invocation of a function on a program.
//
// Since ownership of this object is transferred to the loop/callback and
Expand Down Expand Up @@ -52,7 +63,8 @@ class SHORTFIN_API ProgramInvocation {
using Future = TypedFuture<ProgramInvocation::Ptr>;

static Ptr New(std::shared_ptr<Scope> scope, iree::vm_context_ptr vm_context,
iree_vm_function_t &vm_function);
iree_vm_function_t &vm_function,
ProgramInvocationModel invocation_model);
ProgramInvocation(const ProgramInvocation &) = delete;
ProgramInvocation &operator=(const ProgramInvocation &) = delete;
ProgramInvocation &operator=(ProgramInvocation &&) = delete;
Expand All @@ -67,6 +79,11 @@ class SHORTFIN_API ProgramInvocation {
// The scope this invocation was scheduled against.
Scope *scope() const { return scope_.get(); }

// Access the wait and signal fences for the invocation. These are created
// on the fly as needed.
iree_hal_fence_t *wait_fence();
iree_hal_fence_t *signal_fence();

// Adds a marshalable argument with a configurable concurrency barrier.
void AddArg(ProgramInvocationMarshalable &marshalable,
ProgramResourceBarrier barrier = ProgramResourceBarrier::READ);
Expand Down Expand Up @@ -101,6 +118,7 @@ class SHORTFIN_API ProgramInvocation {
// Context is retained upon construction and released when scheduled.
iree_vm_context_t *context;
iree_vm_function_t function;
ProgramInvocationModel invocation_model;
iree_vm_list_t *arg_list = nullptr;
};
union State {
Expand All @@ -113,20 +131,23 @@ class SHORTFIN_API ProgramInvocation {
std::shared_ptr<Scope> scope_;
iree_vm_list_t *result_list_ = nullptr;
std::optional<Future> future_;
iree::hal_fence_ptr wait_fence_;
iree::hal_fence_ptr signal_fence_;
bool scheduled_ = false;
};

// References a function in a Program.
class SHORTFIN_API ProgramFunction {
public:
ProgramFunction(iree::vm_context_ptr vm_context,
iree_vm_function_t vm_function)
: vm_context_(std::move(vm_context)), vm_function_(vm_function) {}
iree_vm_function_t vm_function,
std::optional<ProgramInvocationModel> invocation_model = {});

operator bool() const { return vm_context_; }

std::string_view name() const;
std::string_view calling_convention() const;
ProgramInvocationModel invocation_model() const { return invocation_model_; }

ProgramInvocation::Ptr CreateInvocation(std::shared_ptr<Scope> scope);

Expand All @@ -136,9 +157,13 @@ class SHORTFIN_API ProgramFunction {
operator iree_vm_function_t &() { return vm_function_; }

private:
static ProgramInvocationModel GetInvocationModelFromFunction(
iree_vm_function_t &f);

// The context that this function was resolved against.
iree::vm_context_ptr vm_context_;
iree_vm_function_t vm_function_;
ProgramInvocationModel invocation_model_;
friend class Program;
};

Expand Down
7 changes: 7 additions & 0 deletions libshortfin/src/shortfin/local/scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,4 +265,11 @@ void Scheduler::Flush() {
}
}

// Creates a new fence whose capacity is the scheduler's current
// `semaphore_count_`, allocated from the system's host allocator.
//
// Throws (via SHORTFIN_THROW_IF_ERROR) if the fence cannot be created. The
// status returned by iree_hal_fence_create must be consumed: ignoring it
// would hand a null fence back to callers on failure and drop the status
// without ever checking it.
iree::hal_fence_ptr Scheduler::NewFence() {
  iree::hal_fence_ptr fence;
  SHORTFIN_THROW_IF_ERROR(iree_hal_fence_create(
      semaphore_count_, system_.host_allocator(), fence.for_output()));
  return fence;
}

} // namespace shortfin::local::detail
4 changes: 4 additions & 0 deletions libshortfin/src/shortfin/local/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ class SHORTFIN_API Scheduler {
new TimelineResource(std::move(scope), semaphore_count_));
}

// Creates a new fence with capacity for all semaphores that are extant at
// the point of the call.
iree::hal_fence_ptr NewFence();

System &system() { return system_; }

private:
Expand Down

0 comments on commit 5938ff9

Please sign in to comment.