Skip to content

Commit

Permalink
Make Program close over Scope.
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaraccident committed Sep 3, 2024
1 parent beca75b commit c1b03e8
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 109 deletions.
2 changes: 1 addition & 1 deletion libshortfin/bindings/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ set_target_properties(shortfin_python_extension
target_link_libraries(shortfin_python_extension
# TODO: This should be configurable as to whether we link to the static
# or dynamic version.
PRIVATE shortfin
PRIVATE shortfin-static
)

nanobind_add_stub(
Expand Down
30 changes: 13 additions & 17 deletions libshortfin/bindings/python/lib_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,8 @@ void PyAddProgramInvocationArg(py::capsule &inv_capsule, py::handle arg) {
}

local::ProgramInvocation::Future PyFunctionCall(local::ProgramFunction &self,
py::args args,
local::Scope &scope) {
auto inv = self.CreateInvocation(scope.shared_from_this());
py::args args) {
auto inv = self.CreateInvocation();
py::capsule inv_capsule(inv.get());
for (py::handle arg : args) {
PyAddProgramInvocationArg(inv_capsule, arg);
Expand Down Expand Up @@ -443,6 +442,15 @@ void BindLocal(py::module_ &m) {
.def("__repr__", &local::DeviceAffinity::to_s);

py::class_<local::Program>(m, "Program")
.def(py::new_([](std::span<const local::ProgramModule> modules,
local::Scope &scope, bool trace_execution) {
local::Program::Options options;
options.trace_execution = trace_execution;
return local::Program::Load(scope.shared_from_this(), modules,
std::move(options));
}),
py::arg("modules"), py::arg("scope"), py::kw_only(),
py::arg("trace_execution") = false)
.def_prop_ro("exports", &local::Program::exports)
.def("lookup_function", &local::Program::LookupRequiredFunction)
.def("__getitem__", &local::Program::LookupRequiredFunction);
Expand All @@ -451,10 +459,8 @@ void BindLocal(py::module_ &m) {
.def_prop_ro("calling_convention",
&local::ProgramFunction::calling_convention)
.def("invocation", &local::ProgramFunction::CreateInvocation,
py::kw_only(), py::arg("scope"),
DOCSTRING_PROGRAM_FUNCTION_INVOCATION)
.def("__call__", PyFunctionCall, py::arg("args"), py::kw_only(),
py::arg("scope"))
.def("__call__", PyFunctionCall, py::arg("args"))
.def("__repr__", &local::ProgramFunction::to_s);
py::class_<local::ProgramModule>(m, "ProgramModule")
.def_prop_ro("exports", &local::ProgramModule::exports)
Expand Down Expand Up @@ -538,17 +544,7 @@ void BindLocal(py::module_ &m) {
[](local::Scope &self, py::args args) {
return CastDeviceAffinity(self, args);
},
py::rv_policy::reference_internal)
.def(
"load_unbound_program",
[](local::Scope &scope, std::span<const local::ProgramModule> modules,
bool trace_execution) {
local::Program::Options options;
options.trace_execution = trace_execution;
return scope.LoadUnboundProgram(modules, std::move(options));
},
py::arg("modules"), py::kw_only(),
py::arg("trace_execution") = false);
py::rv_policy::reference_internal);

py::class_<local::ScopedDevice>(m, "ScopedDevice")
.def_prop_ro("scope", &local::ScopedDevice::scope,
Expand Down
22 changes: 10 additions & 12 deletions libshortfin/examples/python/mobilenet_server/inference_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,24 @@ async def run(self):
print("host_staging =", self.host_staging)
self.device_input.copy_from(self.host_staging)

# Explicit invocation object.
# inv = self.main_function.invocation(scope=self.scope)
# inv.add_arg(self.device_input)
# results = await inv.invoke()
# print("results:", results)

# Simple call. Note that the await here is merely awaiting the
# result being *available* (i.e. that the VM coroutine has
# completed) but does not indicate that the result is ready.
(result1,) = await self.main_function(self.device_input, scope=self.scope)
(result2,) = await self.main_function(self.device_input, scope=self.scope)
(result1,) = await self.main_function(self.device_input)
(result2,) = await self.main_function(self.device_input)

# TODO: Implement await on individual results. The accounting is
# there but currently we can only await on the device itself.
await self.device
print("Result 1:", result1)
print("Result 2:", result2)

# Explicit invocation object.
# inv = self.main_function.invocation(scope=self.scope)
# inv.add_arg(self.device_input)
# results = await inv.invoke()
# print("results:", results)

# Multiple invocations in parallel.
# all_results = await asyncio.gather(
# self.main_function(self.device_input, scope=self.scope),
Expand All @@ -80,7 +80,7 @@ async def run(self):

class Main:
def __init__(self, lsys: sf.System, home_dir: Path):
self.processes_per_worker = 1
self.processes_per_worker = 2
self.lsys = lsys
self.home_dir = home_dir
self.request_queue = lsys.create_queue("request")
Expand All @@ -92,10 +92,8 @@ async def start_scope(self, scope):
# Note that currently, program load is synchronous. But we do it
# in a task so we can await it in the future and let program loads
# overlap.
program = scope.load_unbound_program(
[self.program_module], trace_execution=False
)
for _ in range(self.processes_per_worker):
program = sf.Program([self.program_module], scope=scope)
self.processes.append(
InferenceProcess(program, self.request_queue, scope=scope).launch()
)
Expand Down
59 changes: 52 additions & 7 deletions libshortfin/src/shortfin/local/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "fmt/core.h"
#include "fmt/std.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/bytecode/module.h"
#include "shortfin/local/scope.h"
#include "shortfin/local/system.h"
Expand All @@ -33,9 +34,11 @@ void GetVmModuleExports(iree_vm_module_t *vm_module,
// -------------------------------------------------------------------------- //

ProgramFunction::ProgramFunction(
iree::vm_context_ptr vm_context, iree_vm_function_t vm_function,
std::shared_ptr<Scope> scope, iree::vm_context_ptr vm_context,
iree_vm_function_t vm_function,
std::optional<ProgramInvocationModel> invocation_model)
: vm_context_(std::move(vm_context)),
: scope_(std::move(scope)),
vm_context_(std::move(vm_context)),
vm_function_(vm_function),
invocation_model_(invocation_model
? *invocation_model
Expand Down Expand Up @@ -68,9 +71,8 @@ std::string_view ProgramFunction::calling_convention() const {
iree_vm_function_signature(&vm_function_).calling_convention);
}

ProgramInvocation::Ptr ProgramFunction::CreateInvocation(
std::shared_ptr<Scope> scope) {
return ProgramInvocation::New(std::move(scope), vm_context_, vm_function_,
ProgramInvocation::Ptr ProgramFunction::CreateInvocation() {
return ProgramInvocation::New(scope_, vm_context_, vm_function_,
invocation_model_);
}

Expand Down Expand Up @@ -133,6 +135,49 @@ std::vector<std::string> ProgramModule::exports() const {
// Program
// -------------------------------------------------------------------------- //

// Loads a Program attached to `scope` from the user provided `modules`.
//
// A HAL module covering the scope's devices is created and placed ahead of
// the user modules, and all of them are instantiated together into one VM
// context. `options.trace_execution` additionally enables the VM's execution
// tracing flag on that context. Both IREE calls are checked with
// SHORTFIN_THROW_IF_ERROR.
Program Program::Load(std::shared_ptr<Scope> scope,
                      std::span<const ProgramModule> modules, Options options) {
  auto &system = scope->system();

  // By default, every device in the scope is bound to the program, in the
  // order the scope reports them.
  std::vector<iree_hal_device_t *> hal_devices;
  for (Device *device : scope->raw_devices()) {
    hal_devices.push_back(device->hal_device());
  }

  // The HAL module is always added.
  // TODO: at some point may want to change this to something similar to
  // what the tooling does in iree_tooling_resolve_modules - it uses
  // iree_vm_module_enumerate_dependencies to walk the dependencies and add the
  // required modules only as needed. To start, it could be used just to see
  // whether the HAL is used at all; later it also helps when adding other
  // module types for exposing library functionality, or module versions
  // (iree_vm_module_dependency_t carries the minimum required version, so it
  // is possible to switch between them, and whether they are
  // optional/required).
  iree::vm_module_ptr hal_module;
  SHORTFIN_THROW_IF_ERROR(iree_hal_module_create(
      system.vm_instance(), hal_devices.size(), hal_devices.data(),
      IREE_HAL_MODULE_FLAG_NONE, system.host_allocator(),
      hal_module.for_output()));

  // Module list handed to the context: the HAL module first, followed by the
  // explicitly provided modules in caller order.
  std::vector<iree_vm_module_t *> vm_modules;
  vm_modules.push_back(hal_module);
  for (const ProgramModule &module : modules) {
    vm_modules.push_back(module.vm_module());
  }

  // Create the context, optionally with execution tracing turned on.
  iree_vm_context_flags_t context_flags = IREE_VM_CONTEXT_FLAG_CONCURRENT;
  if (options.trace_execution) {
    context_flags |= IREE_VM_CONTEXT_FLAG_TRACE_EXECUTION;
  }
  iree::vm_context_ptr vm_context;
  SHORTFIN_THROW_IF_ERROR(iree_vm_context_create_with_modules(
      system.vm_instance(), context_flags, vm_modules.size(),
      vm_modules.data(), system.host_allocator(), vm_context.for_output()));

  return Program(std::move(scope), std::move(vm_context));
}

std::optional<ProgramFunction> Program::LookupFunction(std::string_view name) {
// By convention, we currently name our coarse-fences function variants
// as ending in "$async". These are the ones we want but it is inconvenient.
Expand All @@ -149,7 +194,7 @@ std::optional<ProgramFunction> Program::LookupFunction(std::string_view name) {
// TODO: Torch import is not setting the coarse-fences abi.model on
// its functions. Get it from there instead of just assuming based on
// name.
return ProgramFunction(vm_context_, f,
return ProgramFunction(scope_, vm_context_, f,
ProgramInvocationModel::COARSE_FENCES);
} else if (!iree_status_is_not_found(status)) {
SHORTFIN_THROW_IF_ERROR(status);
Expand All @@ -161,7 +206,7 @@ std::optional<ProgramFunction> Program::LookupFunction(std::string_view name) {
vm_context_, to_iree_string_view(name), &f);
if (iree_status_is_not_found(status)) return {};
SHORTFIN_THROW_IF_ERROR(status);
return ProgramFunction(vm_context_, f);
return ProgramFunction(scope_, vm_context_, f);
}

ProgramFunction Program::LookupRequiredFunction(std::string_view name) {
Expand Down
45 changes: 28 additions & 17 deletions libshortfin/src/shortfin/local/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <filesystem>
#include <optional>
#include <span>
#include <string_view>
#include <vector>

Expand Down Expand Up @@ -175,28 +176,29 @@ class SHORTFIN_API ProgramInvocation {
// References a function in a Program.
class SHORTFIN_API ProgramFunction {
public:
ProgramFunction(iree::vm_context_ptr vm_context,
iree_vm_function_t vm_function,
std::optional<ProgramInvocationModel> invocation_model = {});

operator bool() const { return vm_context_; }

std::string_view name() const;
std::string_view calling_convention() const;
ProgramInvocationModel invocation_model() const { return invocation_model_; }

ProgramInvocation::Ptr CreateInvocation(std::shared_ptr<Scope> scope);
ProgramInvocation::Ptr CreateInvocation();

std::string to_s() const;

operator iree_vm_context_t *() { return vm_context_.get(); }
operator iree_vm_function_t &() { return vm_function_; }

private:
ProgramFunction(std::shared_ptr<Scope> scope, iree::vm_context_ptr vm_context,
iree_vm_function_t vm_function,
std::optional<ProgramInvocationModel> invocation_model = {});

static ProgramInvocationModel GetInvocationModelFromFunction(
iree_vm_function_t &f);

// The context that this function was resolved against.
std::shared_ptr<Scope> scope_;
iree::vm_context_ptr vm_context_;
iree_vm_function_t vm_function_;
ProgramInvocationModel invocation_model_;
Expand Down Expand Up @@ -241,24 +243,31 @@ class SHORTFIN_API ProgramModule {
};

// Programs consist of ProgramModules instantiated together and capable of
// having functions invoked on them. While it is possible to construct
// programs that do not depend on device-associated state, the dominant
// use case is for programs that are compiled to operate against the device
// HAL with a list of concrete devices. Such programs are constructed from
// a Scope.
// having functions invoked on them. While the underlying programming model
// is a bit broader and can be exploited in various advanced ways, a program
// should generally be thought of as a fiber, and it is therefore bound to
// a Scope, which provides a logical thread of execution. By default, all
// invocations will take place in logical order (there are certain ways to
// violate this constraint safely that are provided for separately).
//
// While the concurrency model for programs is technically a bit broader, the
// intended use is for them to be interacted with on a single Worker in a
// non-blocking fashion. There are many advanced ways that programs can be
// constructed to straddle devices, scopes, and workers, but that is left as
// an advanced use case.
// The program will source any needed parameters from the System and it will
// make an effort to cache them for proper locality on individual devices
// (TODO: make this actually true).
class SHORTFIN_API Program {
public:
struct Options {
Options() {}

// Enables program-wide execution tracing (to stderr).
bool trace_execution = false;
};

// Loads a program attached to a scope with a list of user provided modules
// and options.
static Program Load(std::shared_ptr<Scope> scope,
std::span<const ProgramModule> modules,
Options options = {});

// Looks up a public function by fully qualified name (i.e. module.function).
// Returns nothing if not found.
std::optional<ProgramFunction> LookupFunction(std::string_view name);
Expand All @@ -271,8 +280,10 @@ class SHORTFIN_API Program {
std::vector<std::string> exports() const;

private:
explicit Program(iree::vm_context_ptr vm_context)
: vm_context_(std::move(vm_context)) {}
explicit Program(std::shared_ptr<Scope> scope,
iree::vm_context_ptr vm_context)
: scope_(std::move(scope)), vm_context_(std::move(vm_context)) {}
std::shared_ptr<Scope> scope_;
iree::vm_context_ptr vm_context_;
friend class Scope;
};
Expand Down
2 changes: 1 addition & 1 deletion libshortfin/src/shortfin/local/scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace shortfin::local::detail {

namespace {

std::string SummarizeFence(iree_hal_fence_t *fence) {
[[maybe_unused]] std::string SummarizeFence(iree_hal_fence_t *fence) {
if (!SHORTFIN_SCHED_LOG_ENABLED) {
return std::string();
}
Expand Down
43 changes: 0 additions & 43 deletions libshortfin/src/shortfin/local/scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <fmt/core.h>
#include <fmt/ranges.h>

#include "iree/modules/hal/module.h"
#include "shortfin/local/system.h"
#include "shortfin/support/logging.h"

Expand Down Expand Up @@ -90,48 +89,6 @@ std::vector<std::string_view> Scope::device_names() const {
return names;
}

// Builds a Program from `modules` using this scope's devices.
//
// A HAL module covering the scope's devices is created and placed ahead of
// the user modules, and everything is instantiated into a single VM context.
// The returned Program is constructed from that context alone. Both IREE
// calls are checked with SHORTFIN_THROW_IF_ERROR.
Program Scope::LoadUnboundProgram(std::span<const ProgramModule> modules,
                                  Program::Options options) {
  // By default, bind all devices in the scope in order to the program.
  std::vector<iree_hal_device_t *> hal_devices;
  for (Device *device : devices_) {
    hal_devices.push_back(device->hal_device());
  }

  // Add a HAL module.
  // TODO: at some point may want to change this to something similar to
  // what the tooling does in iree_tooling_resolve_modules - it uses
  // iree_vm_module_enumerate_dependencies to walk the dependencies and add the
  // required modules only as needed. To start, it could be used just to see
  // whether the HAL is used at all; later it also helps when adding other
  // module types for exposing library functionality, or module versions
  // (iree_vm_module_dependency_t carries the minimum required version, so it
  // is possible to switch between them, and whether they are
  // optional/required).
  iree::vm_module_ptr hal_module;
  SHORTFIN_THROW_IF_ERROR(iree_hal_module_create(
      system().vm_instance(), hal_devices.size(), hal_devices.data(),
      IREE_HAL_MODULE_FLAG_NONE, system().host_allocator(),
      hal_module.for_output()));

  // Assemble the module list: HAL module first, then the explicit modules in
  // the order the caller supplied them.
  std::vector<iree_vm_module_t *> vm_modules;
  vm_modules.push_back(hal_module);
  for (const ProgramModule &module : modules) {
    vm_modules.push_back(module.vm_module());
  }

  // Create the context, enabling execution tracing only when requested.
  iree_vm_context_flags_t context_flags = IREE_VM_CONTEXT_FLAG_CONCURRENT;
  if (options.trace_execution) {
    context_flags |= IREE_VM_CONTEXT_FLAG_TRACE_EXECUTION;
  }
  iree::vm_context_ptr vm_context;
  SHORTFIN_THROW_IF_ERROR(iree_vm_context_create_with_modules(
      system().vm_instance(), context_flags, vm_modules.size(),
      vm_modules.data(), system().host_allocator(), vm_context.for_output()));

  return Program(std::move(vm_context));
}

// -------------------------------------------------------------------------- //
// ScopedDevice
// -------------------------------------------------------------------------- //
Expand Down
Loading

0 comments on commit c1b03e8

Please sign in to comment.