Skip to content

Commit

Permalink
Merge branch 'main' into add-tmp-ppl-test
Browse files Browse the repository at this point in the history
  • Loading branch information
archana-ramalingam authored Jan 17, 2025
2 parents 7e7dd83 + 1f50538 commit 6627382
Show file tree
Hide file tree
Showing 12 changed files with 41 additions and 35 deletions.
6 changes: 3 additions & 3 deletions shortfin/python/lib_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ void BindLocal(py::module_ &m) {
// Methods not on System but on child objects, taking System as an arg.
// Emitted here for convenience.
.def("load_module", &local::ProgramModule::Load, py::arg("path"),
py::arg("mmap") = true);
py::arg("mmap") = false);

// Support classes.
py::class_<local::Node>(m, "Node")
Expand Down Expand Up @@ -731,7 +731,7 @@ void BindLocal(py::module_ &m) {
.def_prop_ro("exports", &local::ProgramModule::exports)
.def("__repr__", &local::ProgramModule::to_s)
.def_static("load", &local::ProgramModule::Load, py::arg("system"),
py::arg("path"), py::arg("mmap") = true)
py::arg("path"), py::arg("mmap") = false)
.def_static(
"parameter_provider",
[](local::System &system, py::args params) {
Expand Down Expand Up @@ -817,7 +817,7 @@ void BindLocal(py::module_ &m) {
},
py::arg("file_path"), py::arg("format") = std::string_view(),
py::arg("readable") = true, py::arg("writable") = false,
py::arg("mmap") = true);
py::arg("mmap") = false);

struct DevicesSet {
DevicesSet(py::object fiber_obj, std::optional<size_t> index = {})
Expand Down
10 changes: 0 additions & 10 deletions shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ class PageInfo:

index: int
pool: PagePool
token_offset: int # Offset within the page
token_count: int # Number of tokens stored in this page
writing: bool = False
read_ref_count: int = 0 # Number of threads that still need to read this page. When this reaches 0, page is eligible for release


@dataclass
Expand Down Expand Up @@ -80,8 +76,6 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig
PageInfo(
index=i,
pool=self,
token_offset=0,
token_count=0,
)
for i in range(self.config.alloc_page_count)
]
Expand Down Expand Up @@ -127,7 +121,6 @@ def copy_page(self, src_page: PageInfo) -> PageInfo:
Args:
src_page: Source page to copy from
token_count: Optional number of tokens to copy. If None, copies all tokens.
Returns:
New PageInfo containing the copied data
Expand All @@ -145,9 +138,6 @@ def copy_page(self, src_page: PageInfo) -> PageInfo:
# Copy the data
dst_view.copy_from(src_view)

# Setup destination page metadata
dst_page.token_offset = 0 # Always start at beginning of new page

return dst_page

def __repr__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,6 @@ def __init__(self, page_pool: PagePool, tokens_per_page: int):
dummy_page = PageInfo(
index=0, # Root uses reserved index 0
pool=self.page_pool,
token_offset=0,
token_count=0,
)
self.root = TrieNode(tokens=tuple(), page=dummy_page)
self.leaves: Set[TrieNode] = set()
Expand Down
1 change: 1 addition & 0 deletions shortfin/python/shortfin_apps/sd/components/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
"steps",
"guidance_scale",
"seed",
"input_ids",
]
rec_inputs = {}
for item in gen_inputs:
Expand Down
15 changes: 9 additions & 6 deletions shortfin/python/shortfin_apps/sd/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,15 @@ async def _prepare(self, device, requests):
# Tokenize prompts and negative prompts. We tokenize in bs1 for now and join later.
input_ids_list = []
neg_ids_list = []
for tokenizer in self.service.tokenizers:
input_ids = tokenizer.encode(request.prompt)
input_ids_list.append(input_ids)
neg_ids = tokenizer.encode(request.neg_prompt)
neg_ids_list.append(neg_ids)
ids_list = [*input_ids_list, *neg_ids_list]
ids_list = request.input_ids
# Tokenize the prompts if the request does not hold input_ids.
if ids_list is None:
for tokenizer in self.service.tokenizers:
input_ids = tokenizer.encode(request.prompt)
input_ids_list.append(input_ids)
neg_ids = tokenizer.encode(request.neg_prompt)
neg_ids_list.append(neg_ids)
ids_list = [*input_ids_list, *neg_ids_list]

request.input_ids = ids_list

Expand Down
26 changes: 21 additions & 5 deletions shortfin/src/shortfin/local/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -633,13 +633,29 @@ void StaticProgramParameters::Load(std::filesystem::path file_path,
options.format = file_path.extension().string();
}

// Open file.
iree_file_read_flags_t read_flags = IREE_FILE_READ_FLAG_DEFAULT;
if (options.mmap) {
read_flags = IREE_FILE_READ_FLAG_MMAP;
} else {
read_flags = IREE_FILE_READ_FLAG_PRELOAD;
this->LoadMmap(file_path, options);
return;
}

auto file_path_string = file_path.string();
const iree_string_view_t path =
iree_make_cstring_view(file_path_string.c_str());
iree_io_file_handle_t *file_handle = NULL;
SHORTFIN_THROW_IF_ERROR(iree_io_file_handle_open(
IREE_IO_FILE_MODE_READ, path, host_allocator_, &file_handle));

// Parse.
SHORTFIN_THROW_IF_ERROR(
iree_io_parse_file_index(to_iree_string_view(options.format), file_handle,
index_.get(), host_allocator_));
}

void StaticProgramParameters::LoadMmap(std::filesystem::path file_path,
LoadOptions options) {
SHORTFIN_TRACE_SCOPE_NAMED("StaticProgramParameters::LoadMmap");

iree_file_read_flags_t read_flags = IREE_FILE_READ_FLAG_MMAP;
iree_file_contents_t *file_contents = nullptr;
SHORTFIN_THROW_IF_ERROR(iree_file_read_contents(
file_path.string().c_str(), read_flags, host_allocator_, &file_contents));
Expand Down
6 changes: 4 additions & 2 deletions shortfin/src/shortfin/local/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ class SHORTFIN_API ProgramModule {

// Loads a dynamic bytecode module (VMFB) from a path on the file system.
static ProgramModule Load(System &system, const std::filesystem::path &path,
bool mmap = true);
bool mmap = false);

// Creates a ProgramModule that will provide the given list of parameters
// to modules loaded after it. In IREE parlance, this produces an
Expand Down Expand Up @@ -386,7 +386,7 @@ class SHORTFIN_API StaticProgramParameters : public BaseProgramParameters {
// Whether the backing file can be written.
bool writable = false;
// Whether to mmap the file.
bool mmap = true;
bool mmap = false;
};
// Load parameters from a supported file format, applying no name
// transformation.
Expand All @@ -396,6 +396,8 @@ class SHORTFIN_API StaticProgramParameters : public BaseProgramParameters {
private:
iree_allocator_t host_allocator_;
iree::io_parameter_index_ptr index_;

void LoadMmap(std::filesystem::path file_path, LoadOptions options);
};

namespace detail {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class MockPagePool(PagePool):
def __init__(self, total_pages: int):
self._queue = queue.Queue()
for i in range(total_pages):
page = PageInfo(index=i, pool=self, token_offset=0, token_count=0)
page = PageInfo(index=i, pool=self)
self._queue.put(page)

def acquire_free_pages(self, count: int) -> List[PageInfo]:
Expand Down
2 changes: 0 additions & 2 deletions tuner/examples/simple/simple_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def main():
simple_tuner.benchmark_flags = ["--input=1", "--benchmark_repetitions=3"]
top_candidates = libtuner.benchmark(
args,
path_config,
compiled_candidates,
candidate_trackers,
simple_tuner,
Expand Down Expand Up @@ -159,7 +158,6 @@ def main():
simple_tuner.benchmark_timeout = 60
top_model_candidates = libtuner.benchmark(
args,
path_config,
compiled_model_candidates,
candidate_trackers,
simple_tuner,
Expand Down
4 changes: 2 additions & 2 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def strip_compilation_info(input_path: Path) -> str:
return result.process_res.stdout


def main():
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("input", help="Input mlir file", type=str)
parser.add_argument(
Expand Down Expand Up @@ -369,4 +369,4 @@ def main():


if __name__ == "__main__":
args = main()
main()
1 change: 0 additions & 1 deletion tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,6 @@ def get_speedup(result: BenchmarkResult) -> float:

def benchmark(
args: argparse.Namespace,
path_config: PathConfig,
compiled_candidates: list[int],
candidate_trackers: list[CandidateTracker],
tuning_client: TuningClient,
Expand Down
1 change: 0 additions & 1 deletion tuner/tuner/libtuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import argparse
import math
from subprocess import CompletedProcess
from unittest.mock import call, patch, MagicMock
from . import libtuner

Expand Down

0 comments on commit 6627382

Please sign in to comment.