Skip to content

Commit

Permalink
simplify server management a little further
Browse files Browse the repository at this point in the history
  • Loading branch information
renxida committed Jan 13, 2025
1 parent ebfd09e commit 188ff8f
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 25 deletions.
75 changes: 56 additions & 19 deletions app_tests/integration_tests/llm/server_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import socket
from contextlib import closing
from dataclasses import dataclass, field
from dataclasses import dataclass
import subprocess
import time
import requests
Expand All @@ -18,16 +18,40 @@
class ServerConfig:
"""Configuration for server instance."""

port: int
artifacts: ModelArtifacts
device_settings: DeviceSettings

# things we need to write to config
prefix_sharing_algorithm: str = "none"


class ServerManager:
"""Manages server lifecycle and configuration."""
class ServerInstance:
"""An instance of the shortfin llm inference server.
Example usage:
```
from shortfin_apps.llm.server_management import ServerInstance, ServerConfig
# Create and start server
server = ServerInstance(config=ServerConfig(
artifacts=model_artifacts,
device_settings=device_settings,
prefix_sharing_algorithm="none"
))
server.start() # This starts the server and waits for it to be ready
# Use the server
print(f"Server running on port {server.port}")
# Cleanup when done
server.stop()
```
"""

# Keep the supplied configuration and initialise the runtime state to "not started".
# Nothing is launched here; start() fills in the fields below.
def __init__(self, config: ServerConfig):
self.config = config
# Handle of the spawned server subprocess; None until start() assigns it.
self.process: Optional[subprocess.Popen] = None
# Port passed to the server via --port; chosen in start() through find_available_port().
self.port: Optional[int] = None
# Path returned by _write_config(); set in start().
self.config_path: Optional[Path] = None

@staticmethod
def find_available_port() -> int:
Expand All @@ -37,9 +61,6 @@ def find_available_port() -> int:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]

def __init__(self, config: ServerConfig):
self.config = config

def _write_config(self) -> Path:
"""Creates server config by extending the exported model config."""
# TODO: eliminate this by moving prefix sharing algorithm to be a cmdline arg of server.py
Expand All @@ -59,31 +80,47 @@ def _write_config(self) -> Path:
json.dump(config, f)
return server_config_path

def start(self) -> None:
    """Start the server subprocess and block until it is ready.

    Writes the server config with ``_write_config``, picks a port with
    ``find_available_port``, launches ``shortfin_apps.llm.server`` using the
    current interpreter, then waits via ``wait_for_ready``.

    Raises:
        RuntimeError: if ``self.process`` is already set.
        Exception: anything raised by ``wait_for_ready`` is re-raised after
            the freshly spawned process has been stopped.
    """
    if self.process is not None:
        raise RuntimeError("Server is already running")

    self.config_path = self._write_config()
    self.port = self.find_available_port()

    artifacts = self.config.artifacts
    cmd = [
        sys.executable,
        "-m",
        "shortfin_apps.llm.server",
        f"--tokenizer_json={artifacts.tokenizer_path}",
        f"--model_config={self.config_path}",
        f"--vmfb={artifacts.vmfb_path}",
        f"--parameters={artifacts.weights_path}",
        f"--port={self.port}",
    ]
    cmd.extend(self.config.device_settings.server_flags)

    self.process = subprocess.Popen(cmd)
    try:
        self.wait_for_ready()
    except BaseException:
        # Do not leak the child (or leave this instance stuck in the
        # "already running" state) when readiness is never reached or the
        # wait is interrupted.
        self.stop()
        raise

def wait_for_ready(self, timeout: int = 10) -> None:
    """Poll the server's ``/health`` endpoint until it responds.

    Any HTTP response counts as ready; the status code is not inspected.

    Args:
        timeout: Overall budget in seconds for the server to respond.

    Raises:
        RuntimeError: if no port has been assigned yet, or if the tracked
            server process exits before it answers.
        TimeoutError: if no response arrives within ``timeout`` seconds.
    """
    if self.port is None:
        raise RuntimeError("Server hasn't been started")

    url = f"http://localhost:{self.port}/health"
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Fail fast when the child already died instead of polling a dead port.
        if self.process is not None and self.process.poll() is not None:
            raise RuntimeError(
                "Server process exited with code "
                f"{self.process.returncode} before becoming ready"
            )
        try:
            # Bound each request so a hung connection cannot outlive the deadline.
            requests.get(url, timeout=remaining)
            return
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
        ):
            time.sleep(min(1.0, max(0.0, deadline - time.monotonic())))
    raise TimeoutError(f"Server failed to start within {timeout} seconds")

def stop(self, timeout: float = 10.0) -> None:
    """Stop the server process and reset the instance to "not started".

    A running process is asked to terminate; if it has not exited within
    ``timeout`` seconds it is killed. Safe to call when nothing is running.

    Args:
        timeout: Seconds to wait after ``terminate()`` before using ``kill()``.
    """
    proc = self.process
    try:
        if proc is not None and proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                # The server ignored the termination request; do not hang teardown.
                proc.kill()
                proc.wait()
    finally:
        # Always clear state so a later start() is not blocked by a stale handle.
        self.process = None
        self.port = None
11 changes: 5 additions & 6 deletions app_tests/integration_tests/llm/shortfin/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
AzureConfig,
ModelArtifacts,
)
from ..server_management import ServerManager, ServerConfig
from ..server_management import ServerInstance, ServerConfig
from .. import device_settings

# Example model configurations
Expand Down Expand Up @@ -77,16 +77,15 @@ def server(model_artifacts, request):
model_config = TEST_MODELS[model_id]

server_config = ServerConfig(
port=ServerManager.find_available_port(),
artifacts=model_artifacts,
device_settings=model_config.device_settings,
prefix_sharing_algorithm=request.param.get("prefix_sharing", "none"),
)

server_manager = ServerManager(server_config)
process = server_manager.start()

yield process, server_config.port
server_instance = ServerInstance(server_config)
server_instance.start()
process, port = server_instance.process, server_instance.port
yield process, port

process.terminate()
process.wait()

0 comments on commit 188ff8f

Please sign in to comment.