Zero-fill start positions instead of one-filling (#826)
PRs in the history of this problem: #665, #723

#665 was supposed to fix a NaN cache corruption issue by 1-filling
seq_len instead of 0-filling it. It was supposed to 1-fill seq_len for
both decode and prefill, but I mistakenly 1-filled seq_len for decode
only, and also 1-filled start_positions for decode instead of the
prefill seq_len.

#723 adds 1-filling for prefill, and this PR removes the mistaken
start_positions 1-filling for decode.
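
The NaN mechanism behind the 1-fill can be sketched in plain Python. This is an illustration only: it assumes that a padded slot with seq_len 0 ends up with every attention score masked to -inf, which the source does not spell out (it only says that a division by 0 during softmax produces NaN). The helpers `ieee_div` and `masked_softmax` are hypothetical names and are not part of shortfin.

```python
import math


def ieee_div(num: float, den: float) -> float:
    """Float division with IEEE 754 results (plain Python raises on x / 0.0)."""
    if den != 0.0:
        return num / den
    if num == 0.0 or math.isnan(num):
        return math.nan  # 0 / 0 is NaN on hardware
    return math.copysign(math.inf, num)


def masked_softmax(scores: list[float], seq_len: int) -> list[float]:
    """Softmax over the first seq_len scores; the rest are masked to -inf."""
    masked = [s if i < seq_len else -math.inf for i, s in enumerate(scores)]
    exps = [math.exp(s) for s in masked]  # exp(-inf) == 0.0
    total = sum(exps)  # 0.0 when every position is masked
    return [ieee_div(e, total) for e in exps]


scores = [0.5, 1.0, -0.25]
zero_padded = masked_softmax(scores, seq_len=0)  # slot padded with seq_len 0
one_padded = masked_softmax(scores, seq_len=1)   # slot padded with seq_len 1
```

With seq_len 0 the denominator is 0, so every output of the padded slot is NaN. With seq_len 1 the denominator is nonzero and the outputs stay finite. Under this reading, only seq_len needs the nonzero pad, which is consistent with zero-filling start_positions.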

After this PR, the shortfin concurrent tests should work properly.

Up next: a failing trie kv sharing test case.
renxida authored Jan 15, 2025
1 parent cbcff3d commit e34ffec
Showing 2 changed files with 4 additions and 11 deletions.
@@ -85,10 +85,6 @@ def test_basic_generation(self, server: tuple[Any, int]) -> None:
         indirect=True,
     )
     @pytest.mark.parametrize("concurrent_requests", [2, 4, 8])
-    @pytest.mark.xfail(
-        raises=AccuracyValidationException,
-        reason="Concurreny issues in Shortfin batch processing",
-    )
     def test_concurrent_generation(
         self, server: tuple[Any, int], concurrent_requests: int
     ) -> None:
11 changes: 4 additions & 7 deletions shortfin/python/shortfin_apps/llm/components/service.py
@@ -388,22 +388,19 @@ async def run(self):
                 seq_lens_host.copy_to(seq_lens)
 
             # For decode, populate start_positions and seq_lens.
-            # paged_llm_v1 and export_paged_llm_v1 do some funky things with start_positions and seq_lens
-            # TODO: make them not so funky
             if self.phase == InferencePhase.DECODE:
                 start_positions_host = start_positions.for_transfer()
                 with start_positions_host.map(discard=True) as m:
-                    m.fill(
-                        1
-                    )  # Pad unused requests. Must pad with nonzero value because division by 0 floods clobber page (page 0) in cache with NaN values.
+                    m.fill(0)
                     m.items = [req.start_position for req in self.exec_requests]
                 start_positions_host.copy_to(start_positions)
 
                 seq_lens_host = seq_lens.for_transfer()
                 with seq_lens_host.map(discard=True) as m:
+                    # Pad unused requests.
                     m.fill(
-                        1
-                    )  # Pad unused requests. Must pad with nonzero value because division by 0 floods clobber page (page 0) in cache with NaN values.
+                        1  # Must pad with a nonzero value because a division by 0 during softmax floods clobber page (page 0) in cache with NaN values.
+                    )
                     m.items = [
                         req.start_position + len(req.input_token_ids)
                         for req in self.exec_requests
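
The fill-then-overwrite pattern in this hunk can be approximated with plain lists. This is a minimal sketch, not the shortfin API: `pad_batch` is a hypothetical helper, the list stands in for the mapped device buffer, and it assumes that assigning `m.items` overwrites only the leading slots and leaves the rest at the fill value. The request values are made up.

```python
def pad_batch(values: list[int], batch_size: int, pad: int) -> list[int]:
    """Fill a fixed-size buffer with `pad`, then overwrite the leading slots."""
    if len(values) > batch_size:
        raise ValueError("more requests than batch slots")
    buf = [pad] * batch_size      # stands in for m.fill(pad)
    buf[: len(values)] = values   # stands in for m.items = [...]
    return buf


# (start_position, number of input tokens) per request; made-up values
requests = [(7, 8), (3, 4)]

# After this commit: start positions pad with 0, sequence lengths pad with 1.
start_positions = pad_batch([s for s, _ in requests], batch_size=4, pad=0)
seq_lens = pad_batch([s + n for s, n in requests], batch_size=4, pad=1)
```

Here the two unused slots end up with start position 0 and sequence length 1, so no padded slot carries a zero sequence length into the model.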