From 188ff8f5fe1edac7276d44560bb17527a5cb298b Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 13 Jan 2025 08:43:49 -0800 Subject: [PATCH] simplify server management a little further --- .../llm/server_management.py | 75 ++++++++++++++----- .../llm/shortfin/conftest.py | 11 ++- 2 files changed, 61 insertions(+), 25 deletions(-) diff --git a/app_tests/integration_tests/llm/server_management.py b/app_tests/integration_tests/llm/server_management.py index cf7211655..36c7b31b2 100644 --- a/app_tests/integration_tests/llm/server_management.py +++ b/app_tests/integration_tests/llm/server_management.py @@ -2,7 +2,7 @@ import json import socket from contextlib import closing -from dataclasses import dataclass, field +from dataclasses import dataclass import subprocess import time import requests @@ -18,16 +18,40 @@ class ServerConfig: """Configuration for server instance.""" - port: int artifacts: ModelArtifacts device_settings: DeviceSettings - - # things we need to write to config prefix_sharing_algorithm: str = "none" -class ServerManager: - """Manages server lifecycle and configuration.""" +class ServerInstance: + """An instance of the shortfin llm inference server. 
+ + Example usage: + + ``` + from shortfin_apps.llm.server_management import ServerInstance, ServerConfig + # Create and start server + server = ServerInstance(config=ServerConfig( + artifacts=model_artifacts, + device_settings=device_settings, + prefix_sharing_algorithm="none" + )) + + server.start() # This starts the server and waits for it to be ready + + # Use the server + print(f"Server running on port {server.port}") + + # Cleanup when done + server.stop() + ``` + """ + + def __init__(self, config: ServerConfig): + self.config = config + self.process: Optional[subprocess.Popen] = None + self.port: Optional[int] = None + self.config_path: Optional[Path] = None @staticmethod def find_available_port() -> int: @@ -37,9 +61,6 @@ def find_available_port() -> int: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] - def __init__(self, config: ServerConfig): - self.config = config - def _write_config(self) -> Path: """Creates server config by extending the exported model config.""" # TODO: eliminate this by moving prefix sharing algorithm to be a cmdline arg of server.py @@ -59,31 +80,47 @@ def _write_config(self) -> Path: json.dump(config, f) return server_config_path - def start(self) -> subprocess.Popen: + def start(self) -> None: """Starts the server process.""" - config_path = self._write_config() + if self.process is not None: + raise RuntimeError("Server is already running") + + self.config_path = self._write_config() + self.port = self.find_available_port() + cmd = [ sys.executable, "-m", "shortfin_apps.llm.server", f"--tokenizer_json={self.config.artifacts.tokenizer_path}", - f"--model_config={config_path}", + f"--model_config={self.config_path}", f"--vmfb={self.config.artifacts.vmfb_path}", f"--parameters={self.config.artifacts.weights_path}", - f"--port={self.config.port}", + f"--port={self.port}", ] cmd.extend(self.config.device_settings.server_flags) - process = subprocess.Popen(cmd) - self._wait_for_server(timeout=10) - return 
process - def _wait_for_server(self, timeout: int = 10): - """Waits for server to be ready.""" + self.process = subprocess.Popen(cmd) + self.wait_for_ready() + + def wait_for_ready(self, timeout: int = 10) -> None: + """Waits for server to be ready and responding to health checks.""" + if self.port is None: + raise RuntimeError("Server hasn't been started") + start = time.time() while time.time() - start < timeout: try: - requests.get(f"http://localhost:{self.config.port}/health") + requests.get(f"http://localhost:{self.port}/health") return except requests.exceptions.ConnectionError: time.sleep(1) raise TimeoutError(f"Server failed to start within {timeout} seconds") + + def stop(self) -> None: + """Stops the server process.""" + if self.process is not None and self.process.poll() is None: + self.process.terminate() + self.process.wait() + self.process = None + self.port = None diff --git a/app_tests/integration_tests/llm/shortfin/conftest.py b/app_tests/integration_tests/llm/shortfin/conftest.py index d58b604bd..18fb4f3ed 100644 --- a/app_tests/integration_tests/llm/shortfin/conftest.py +++ b/app_tests/integration_tests/llm/shortfin/conftest.py @@ -10,7 +10,7 @@ AzureConfig, ModelArtifacts, ) -from ..server_management import ServerManager, ServerConfig +from ..server_management import ServerInstance, ServerConfig from .. 
import device_settings # Example model configurations @@ -77,16 +77,15 @@ def server(model_artifacts, request): model_config = TEST_MODELS[model_id] server_config = ServerConfig( - port=ServerManager.find_available_port(), artifacts=model_artifacts, device_settings=model_config.device_settings, prefix_sharing_algorithm=request.param.get("prefix_sharing", "none"), ) - server_manager = ServerManager(server_config) - process = server_manager.start() - - yield process, server_config.port + server_instance = ServerInstance(server_config) + server_instance.start() + process, port = server_instance.process, server_instance.port + yield process, port process.terminate() process.wait()