From 0c6d061a6a3524b9b1703bf53bf0df068fe283d0 Mon Sep 17 00:00:00 2001 From: Stephen Baione <109226581+stbaione@users.noreply.github.com> Date: Fri, 17 Jan 2025 12:46:30 -0600 Subject: [PATCH] Fast server startup, using file_handle to load model param files (#840) # Description Currently, when we load model parameter files, we load the entire contents of the files into memory, then mmap that data to the devices. This would cause very long server startup times. For example, 70b (~130 GB) took me 10 minutes to start the server, and 405b (~750 GB) took me over `5 hours` to start the server. As an alternative, this PR uses `iree_io_file_handle_open` to obtain a handle to the parameter files, then streams that data to the devices, instead of mmapping it. After this change, we are able to start the server for `70b` and `405b` within seconds. We default to the new method and add a private function `LoadMmap` for cases where `mmap == true`. This should improve the startup time for both `LLM` and `SDXL`, especially when loading large files. --- shortfin/python/lib_ext.cc | 6 +++--- shortfin/src/shortfin/local/program.cc | 26 +++++++++++++++++++++----- shortfin/src/shortfin/local/program.h | 6 ++++-- 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/shortfin/python/lib_ext.cc b/shortfin/python/lib_ext.cc index c668e6a8b..af7aba767 100644 --- a/shortfin/python/lib_ext.cc +++ b/shortfin/python/lib_ext.cc @@ -672,7 +672,7 @@ void BindLocal(py::module_ &m) { // Methods not on System but on child objects, taking System as an arg. // Emitted here for convenience. .def("load_module", &local::ProgramModule::Load, py::arg("path"), - py::arg("mmap") = true); + py::arg("mmap") = false); // Support classes. 
py::class_(m, "Node") @@ -731,7 +731,7 @@ void BindLocal(py::module_ &m) { .def_prop_ro("exports", &local::ProgramModule::exports) .def("__repr__", &local::ProgramModule::to_s) .def_static("load", &local::ProgramModule::Load, py::arg("system"), - py::arg("path"), py::arg("mmap") = true) + py::arg("path"), py::arg("mmap") = false) .def_static( "parameter_provider", [](local::System &system, py::args params) { @@ -817,7 +817,7 @@ void BindLocal(py::module_ &m) { }, py::arg("file_path"), py::arg("format") = std::string_view(), py::arg("readable") = true, py::arg("writable") = false, - py::arg("mmap") = true); + py::arg("mmap") = false); struct DevicesSet { DevicesSet(py::object fiber_obj, std::optional index = {}) diff --git a/shortfin/src/shortfin/local/program.cc b/shortfin/src/shortfin/local/program.cc index 0eefb572f..9744c1691 100644 --- a/shortfin/src/shortfin/local/program.cc +++ b/shortfin/src/shortfin/local/program.cc @@ -633,13 +633,29 @@ void StaticProgramParameters::Load(std::filesystem::path file_path, options.format = file_path.extension().string(); } - // Open file. - iree_file_read_flags_t read_flags = IREE_FILE_READ_FLAG_DEFAULT; if (options.mmap) { - read_flags = IREE_FILE_READ_FLAG_MMAP; - } else { - read_flags = IREE_FILE_READ_FLAG_PRELOAD; + this->LoadMmap(file_path, options); + return; } + + auto file_path_string = file_path.string(); + const iree_string_view_t path = + iree_make_cstring_view(file_path_string.c_str()); + iree_io_file_handle_t *file_handle = NULL; + SHORTFIN_THROW_IF_ERROR(iree_io_file_handle_open( + IREE_IO_FILE_MODE_READ, path, host_allocator_, &file_handle)); + + // Parse. 
+ SHORTFIN_THROW_IF_ERROR( + iree_io_parse_file_index(to_iree_string_view(options.format), file_handle, + index_.get(), host_allocator_)); +} + +void StaticProgramParameters::LoadMmap(std::filesystem::path file_path, + LoadOptions options) { + SHORTFIN_TRACE_SCOPE_NAMED("StaticProgramParameters::LoadMmap"); + + iree_file_read_flags_t read_flags = IREE_FILE_READ_FLAG_MMAP; iree_file_contents_t *file_contents = nullptr; SHORTFIN_THROW_IF_ERROR(iree_file_read_contents( file_path.string().c_str(), read_flags, host_allocator_, &file_contents)); diff --git a/shortfin/src/shortfin/local/program.h b/shortfin/src/shortfin/local/program.h index 450b29736..f83b83cbd 100644 --- a/shortfin/src/shortfin/local/program.h +++ b/shortfin/src/shortfin/local/program.h @@ -265,7 +265,7 @@ class SHORTFIN_API ProgramModule { // Loads a dynamic bytecode module (VMFB) from a path on the file system. static ProgramModule Load(System &system, const std::filesystem::path &path, - bool mmap = true); + bool mmap = false); // Creates a ProgramModule that will provide the given list of parameters // to modules loaded after it. In IREE parlance, this produces an @@ -386,7 +386,7 @@ class SHORTFIN_API StaticProgramParameters : public BaseProgramParameters { // Whether the backing file can be written. bool writable = false; // Whether to mmap the file. - bool mmap = true; + bool mmap = false; }; // Load parameters from a supported file format, applying no name // transformation. @@ -396,6 +396,8 @@ class SHORTFIN_API StaticProgramParameters : public BaseProgramParameters { private: iree_allocator_t host_allocator_; iree::io_parameter_index_ptr index_; + + void LoadMmap(std::filesystem::path file_path, LoadOptions options); }; namespace detail {