diff --git a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py index 7c64133a6..4e828c8c5 100644 --- a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py +++ b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py @@ -85,10 +85,6 @@ def test_basic_generation(self, server: tuple[Any, int]) -> None: indirect=True, ) @pytest.mark.parametrize("concurrent_requests", [2, 4, 8]) - @pytest.mark.xfail( - raises=AccuracyValidationException, - reason="Concurreny issues in Shortfin batch processing", - ) def test_concurrent_generation( self, server: tuple[Any, int], concurrent_requests: int ) -> None: diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index a1e57ce84..9f08e82b1 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -388,22 +388,19 @@ async def run(self): seq_lens_host.copy_to(seq_lens) # For decode, populate start_positions and seq_lens. - # paged_llm_v1 and export_paged_llm_v1 do some funky things with start_positions and seq_lens - # TODO: make them not so funky if self.phase == InferencePhase.DECODE: start_positions_host = start_positions.for_transfer() with start_positions_host.map(discard=True) as m: - m.fill( - 1 - ) # Pad unused requests. Must pad with nonzero value because division by 0 floods clobber page (page 0) in cache with NaN values. + m.fill(0) m.items = [req.start_position for req in self.exec_requests] start_positions_host.copy_to(start_positions) seq_lens_host = seq_lens.for_transfer() with seq_lens_host.map(discard=True) as m: + # Pad unused requests. m.fill( - 1 - ) # Pad unused requests. Must pad with nonzero value because division by 0 floods clobber page (page 0) in cache with NaN values. + 1 # Must pad with a nonzero value because a division by 0 during softmax floods clobber page (page 0) in cache with NaN values. + ) m.items = [ req.start_position + len(req.input_token_ids) for req in self.exec_requests