Skip to content

Commit

Permalink
Fast server startup, using file_handle to load model param files (#840)
Browse files Browse the repository at this point in the history
# Description

Currently, when we load model parameter files, we load the entire
contents of the files into memory, then mmap that data to the devices.
This causes very long server startup times. For example, 70b (~130
GB) took me 10 minutes to start the server, and 405b (~750 GB) took me
over `5 hours` to start the server.

As an alternative, this PR uses `iree_io_file_handle_open` to obtain a
handle to the parameter files, then streams that data to the devices
instead of mmapping it. After this change, we are able to start the
server for `70b` and `405b` within seconds.

We default to the new method and add a private function `LoadMmap` for
cases where `mmap == true`. This should improve the startup time for
both `LLM` and `SDXL`, especially when loading large files.
  • Loading branch information
stbaione authored Jan 17, 2025
1 parent d049625 commit 0c6d061
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
6 changes: 3 additions & 3 deletions shortfin/python/lib_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ void BindLocal(py::module_ &m) {
// Methods not on System but on child objects, taking System as an arg.
// Emitted here for convenience.
.def("load_module", &local::ProgramModule::Load, py::arg("path"),
py::arg("mmap") = true);
py::arg("mmap") = false);

// Support classes.
py::class_<local::Node>(m, "Node")
Expand Down Expand Up @@ -731,7 +731,7 @@ void BindLocal(py::module_ &m) {
.def_prop_ro("exports", &local::ProgramModule::exports)
.def("__repr__", &local::ProgramModule::to_s)
.def_static("load", &local::ProgramModule::Load, py::arg("system"),
py::arg("path"), py::arg("mmap") = true)
py::arg("path"), py::arg("mmap") = false)
.def_static(
"parameter_provider",
[](local::System &system, py::args params) {
Expand Down Expand Up @@ -817,7 +817,7 @@ void BindLocal(py::module_ &m) {
},
py::arg("file_path"), py::arg("format") = std::string_view(),
py::arg("readable") = true, py::arg("writable") = false,
py::arg("mmap") = true);
py::arg("mmap") = false);

struct DevicesSet {
DevicesSet(py::object fiber_obj, std::optional<size_t> index = {})
Expand Down
26 changes: 21 additions & 5 deletions shortfin/src/shortfin/local/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -633,13 +633,29 @@ void StaticProgramParameters::Load(std::filesystem::path file_path,
options.format = file_path.extension().string();
}

// Open file.
iree_file_read_flags_t read_flags = IREE_FILE_READ_FLAG_DEFAULT;
if (options.mmap) {
read_flags = IREE_FILE_READ_FLAG_MMAP;
} else {
read_flags = IREE_FILE_READ_FLAG_PRELOAD;
this->LoadMmap(file_path, options);
return;
}

auto file_path_string = file_path.string();
const iree_string_view_t path =
iree_make_cstring_view(file_path_string.c_str());
iree_io_file_handle_t *file_handle = NULL;
SHORTFIN_THROW_IF_ERROR(iree_io_file_handle_open(
IREE_IO_FILE_MODE_READ, path, host_allocator_, &file_handle));

// Parse.
SHORTFIN_THROW_IF_ERROR(
iree_io_parse_file_index(to_iree_string_view(options.format), file_handle,
index_.get(), host_allocator_));
}

void StaticProgramParameters::LoadMmap(std::filesystem::path file_path,
LoadOptions options) {
SHORTFIN_TRACE_SCOPE_NAMED("StaticProgramParameters::LoadMmap");

iree_file_read_flags_t read_flags = IREE_FILE_READ_FLAG_MMAP;
iree_file_contents_t *file_contents = nullptr;
SHORTFIN_THROW_IF_ERROR(iree_file_read_contents(
file_path.string().c_str(), read_flags, host_allocator_, &file_contents));
Expand Down
6 changes: 4 additions & 2 deletions shortfin/src/shortfin/local/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ class SHORTFIN_API ProgramModule {

// Loads a dynamic bytecode module (VMFB) from a path on the file system.
static ProgramModule Load(System &system, const std::filesystem::path &path,
bool mmap = true);
bool mmap = false);

// Creates a ProgramModule that will provide the given list of parameters
// to modules loaded after it. In IREE parlance, this produces an
Expand Down Expand Up @@ -386,7 +386,7 @@ class SHORTFIN_API StaticProgramParameters : public BaseProgramParameters {
// Whether the backing file can be written.
bool writable = false;
// Whether to mmap the file.
bool mmap = true;
bool mmap = false;
};
// Load parameters from a supported file format, applying no name
// transformation.
Expand All @@ -396,6 +396,8 @@ class SHORTFIN_API StaticProgramParameters : public BaseProgramParameters {
private:
iree_allocator_t host_allocator_;
iree::io_parameter_index_ptr index_;

void LoadMmap(std::filesystem::path file_path, LoadOptions options);
};

namespace detail {
Expand Down

0 comments on commit 0c6d061

Please sign in to comment.