Concurrent CPU Integration Tests + Reuse Model Artifacts (#655)
# Description

This PR adds two things:
1. Reuse model artifacts when multiple tests in the same module request export/compilation of the same artifacts.
2. Add concurrency tests for 2, 4, and 8 simultaneous requests.

# Reusing model artifacts

Currently, our `cpu_llm_server_integration_tests` generate new model artifacts for each test, even when tests request the exact same artifacts. This makes the suite take much longer to run than it should, and makes it harder to add more tests without drastically increasing overall test time.

We add a static `MODEL_DIR_CACHE`, a hashmap that stores `{ request.params_hash: temporary_dir }`. If a test requests the same artifacts as a previous test, we reuse the existing artifacts instead of generating new ones.
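
The reuse logic lives in the `model_test_dir` fixture in `conftest.py` (see the diff below); a condensed sketch of the lookup, with the artifact-building steps elided:

```python
import hashlib

import pytest

# Module-level cache shared by tests in the session:
# { md5(request.param): directory containing the built artifacts }
MODEL_DIR_CACHE = {}


@pytest.fixture(scope="module")
def model_test_dir(request, tmp_path_factory):
    param_key = hashlib.md5(str(request.param).encode()).hexdigest()
    if (directory := MODEL_DIR_CACHE.get(param_key)) is not None:
        # A previous test already built these exact artifacts; reuse them.
        yield directory
        return

    tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test")
    # ... download the model and tokenizer, export to MLIR, compile to VMFB ...
    MODEL_DIR_CACHE[param_key] = tmp_dir
    yield tmp_dir
```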

# Adding concurrency tests

We recently found a concurrency bug in the Shortfin LLM Server: when multiple requests are sent at the same time, the responses contain incorrect tokens.

This adds basic concurrent integration tests for 2, 4, and 8 requests sent in parallel. They are currently xfailed, but once the fix lands we can use them to validate it and to guard against future concurrency regressions. We will extend the `periodic SGLang Integration tests` to further exercise concurrency on GPU with more complex prompts, but for a PR-triggered test, this should serve as a good guard.
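
The new tests fan the same prompt out over a thread pool, mirroring `do_generate` in the test diff below; a condensed sketch (the payload fields here are illustrative, not the exact request schema):

```python
import concurrent.futures

import requests


def generate_concurrently(prompt: str, port: int, concurrent_requests: int) -> list[str]:
    """Send the same prompt `concurrent_requests` times in parallel and collect the bodies."""
    url = f"http://localhost:{port}/generate"
    payload = {"text": prompt, "sampling_params": {"max_completion_tokens": 15}}

    with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
        futures = [
            executor.submit(requests.post, url, json=payload)
            for _ in range(concurrent_requests)
        ]
        responses = [f.result() for f in concurrent.futures.as_completed(futures)]

    for response in responses:
        response.raise_for_status()
    # The concurrency bug shows up as responses whose tokens do not match
    # the output produced by the same prompt sent as a single request.
    return [r.text for r in responses]
```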
stbaione authored Dec 6, 2024
1 parent b4cc54c commit 2faadd2
Showing 2 changed files with 169 additions and 71 deletions.
122 changes: 69 additions & 53 deletions app_tests/integration_tests/llm/shortfin/conftest.py
@@ -4,12 +4,10 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import hashlib
import json
import logging
import os
from pathlib import Path
import pytest
import shutil

pytest.importorskip("transformers")
from ..utils import (
@@ -25,6 +23,8 @@

logger = logging.getLogger(__name__)

MODEL_DIR_CACHE = {}


@pytest.fixture(scope="module")
def model_test_dir(request, tmp_path_factory):
@@ -46,58 +46,73 @@ def model_test_dir(request, tmp_path_factory):
"Preparing model artifacts..." + start_log_group("Preparing model artifacts")
)

param_key = hashlib.md5(str(request.param).encode()).hexdigest()
if (directory := MODEL_DIR_CACHE.get(param_key)) is not None:
logger.info(
f"Reusing existing model artifacts directory: {directory}" + end_log_group()
)
yield MODEL_DIR_CACHE[param_key]
return

repo_id = request.param["repo_id"]
model_file = request.param["model_file"]
tokenizer_id = request.param["tokenizer_id"]
settings = request.param["settings"]
batch_sizes = request.param["batch_sizes"]
tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test")

# Download model if it doesn't exist
model_path = tmp_dir / model_file
download_huggingface_model(tmp_dir, repo_id, model_file)

# Set up tokenizer if it doesn't exist
download_tokenizer(tmp_dir, tokenizer_id)

# Export model
mlir_path = tmp_dir / "model.mlir"
config_path = tmp_dir / "config.json"
export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)

# Compile model
vmfb_path = tmp_dir / "model.vmfb"
compile_model(mlir_path, vmfb_path, settings)

logger.info("Model artifacts setup successfully" + end_log_group())
MODEL_DIR_CACHE[param_key] = tmp_dir
yield tmp_dir


@pytest.fixture(scope="module")
def write_config(request, model_test_dir):
batch_sizes = request.param["batch_sizes"]
prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"]

tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test")
hf_home = os.environ.get("HF_HOME", None)
hf_home = Path(hf_home) if hf_home is not None else tmp_dir
try:
# Download model if it doesn't exist
model_path = hf_home / model_file
download_huggingface_model(hf_home, repo_id, model_file)

# Set up tokenizer if it doesn't exist
download_tokenizer(hf_home, tokenizer_id)

# Export model
mlir_path = tmp_dir / "model.mlir"
config_path = tmp_dir / "config.json"
export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)

# Compile model
vmfb_path = tmp_dir / "model.vmfb"
compile_model(mlir_path, vmfb_path, settings)

# Write config
edited_config_path = tmp_dir / "edited_config.json"
config = {
"module_name": "module",
"module_abi_version": 1,
"max_seq_len": 2048,
"attn_head_count": 32,
"attn_head_dim": 100,
"prefill_batch_sizes": batch_sizes,
"decode_batch_sizes": batch_sizes,
"transformer_block_count": 26,
"paged_kv_cache": {
"block_seq_stride": 16,
"device_block_count": 256,
"prefix_sharing_algorithm": prefix_sharing_algorithm,
},
}
logger.info(f"Saving edited config to: {edited_config_path}\n")
logger.info(f"Config: {json.dumps(config, indent=2)}")
with open(edited_config_path, "w") as f:
json.dump(config, f)
logger.info("Model artifacts setup successfully" + end_log_group())
yield hf_home, tmp_dir
finally:
shutil.rmtree(tmp_dir)
config_path = (
model_test_dir
/ f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json"
)

config = {
"module_name": "module",
"module_abi_version": 1,
"max_seq_len": 2048,
"attn_head_count": 32,
"attn_head_dim": 100,
"prefill_batch_sizes": batch_sizes,
"decode_batch_sizes": batch_sizes,
"transformer_block_count": 26,
"paged_kv_cache": {
"block_seq_stride": 16,
"device_block_count": 256,
"prefix_sharing_algorithm": prefix_sharing_algorithm,
},
}
logger.info(f"Saving edited config to: {config_path}\n")
logger.info(f"Config: {json.dumps(config, indent=2)}")
with open(config_path, "w") as f:
json.dump(config, f)

yield config_path


@pytest.fixture(scope="module")
@@ -106,7 +121,7 @@ def available_port():


@pytest.fixture(scope="module")
def llm_server(request, model_test_dir, available_port):
def llm_server(request, model_test_dir, write_config, available_port):
"""Start the LLM server.
Args:
@@ -120,14 +135,15 @@ def llm_server(request, model_test_dir, available_port):
subprocess.Popen: The server process that was started.
"""
logger.info("Starting LLM server..." + start_log_group("Starting LLM server"))
hf_home, tmp_dir = model_test_dir
tmp_dir = model_test_dir
config_path = write_config

model_file = request.param["model_file"]
settings = request.param["settings"]

tokenizer_path = hf_home / "tokenizer.json"
config_path = tmp_dir / "edited_config.json"
tokenizer_path = tmp_dir / "tokenizer.json"
vmfb_path = tmp_dir / "model.vmfb"
parameters_path = hf_home / model_file
parameters_path = tmp_dir / model_file

# Start llm server
server_process = start_llm_server(
118 changes: 100 additions & 18 deletions app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py
@@ -4,6 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import concurrent.futures
import logging
import os
import pytest
@@ -31,7 +32,7 @@
}


def do_generate(prompt, port):
def do_generate(prompt, port, concurrent_requests=1):
logger.info("Generating request...")
headers = {"Content-Type": "application/json"}
# Create a GenerateReqInput-like structure
@@ -48,22 +49,40 @@ def do_generate(prompt, port):
logger.info("Prompt text:")
logger.info(data["text"])
BASE_URL = f"http://localhost:{port}"
response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data)
logger.info(f"Generate endpoint status code: {response.status_code}")
if response.status_code == 200:
logger.info("Generated text:")
data = response.text
assert data.startswith("data: ")
data = data[6:]
assert data.endswith("\n\n")
data = data[:-2]
return data
else:
response.raise_for_status()

response_data = []
with concurrent.futures.ThreadPoolExecutor(
max_workers=concurrent_requests
) as executor:
futures = [
executor.submit(
lambda: requests.post(
f"{BASE_URL}/generate", headers=headers, json=data
)
)
for _ in range(concurrent_requests)
]
for future in concurrent.futures.as_completed(futures):
response = future.result()

logger.info(f"Generate endpoint status code: {response.status_code}")
if response.status_code == 200:
logger.info("Generated text:")
data = response.text
assert data.startswith("data: ")
data = data[6:]
assert data.endswith("\n\n")
data = data[:-2]
logger.info(data)
response_data.append(data)
else:
response.raise_for_status()

return response_data


@pytest.mark.parametrize(
"model_test_dir,llm_server",
"model_test_dir,write_config,llm_server",
[
pytest.param(
{
@@ -72,8 +91,8 @@ def do_generate(prompt, port):
"tokenizer_id": "openlm-research/open_llama_3b_v2",
"settings": CPU_SETTINGS,
"batch_sizes": [1, 4],
"prefix_sharing_algorithm": "trie",
},
{"batch_sizes": [1, 4], "prefix_sharing_algorithm": "none"},
{"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS},
),
pytest.param(
@@ -83,8 +102,8 @@
"tokenizer_id": "openlm-research/open_llama_3b_v2",
"settings": CPU_SETTINGS,
"batch_sizes": [1, 4],
"prefix_sharing_algorithm": "none",
},
{"batch_sizes": [1, 4], "prefix_sharing_algorithm": "trie"},
{"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS},
),
],
@@ -100,17 +119,80 @@ def test_llm_server(llm_server, available_port):
"Sending HTTP Generation Request"
+ start_log_group("Sending HTTP Generation Request")
)
output = do_generate(PROMPT, available_port)
output = do_generate(PROMPT, available_port)[0]
# log to GITHUB_STEP_SUMMARY if we are in a GitHub Action
if "GITHUB_ACTION" in os.environ:
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
# log prompt
f.write("LLM results:\n")
f.write(f"- llm_prompt:`{PROMPT}`\n")
f.write(f"- llm_output:`{output}`\n")
logger.info(output)
if not output.startswith(expected_output_prefix):
raise AccuracyValidationException(
f"Expected '{output}' to start with '{expected_output_prefix}'"
)
logger.info("HTTP Generation Request Successful" + end_log_group())


@pytest.mark.parametrize(
"model_test_dir,write_config,llm_server",
[
pytest.param(
{
"repo_id": "SlyEcho/open_llama_3b_v2_gguf",
"model_file": "open-llama-3b-v2-f16.gguf",
"tokenizer_id": "openlm-research/open_llama_3b_v2",
"settings": CPU_SETTINGS,
"batch_sizes": [1, 4],
},
{"batch_sizes": [1, 4], "prefix_sharing_algorithm": "none"},
{"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS},
),
pytest.param(
{
"repo_id": "SlyEcho/open_llama_3b_v2_gguf",
"model_file": "open-llama-3b-v2-f16.gguf",
"tokenizer_id": "openlm-research/open_llama_3b_v2",
"settings": CPU_SETTINGS,
"batch_sizes": [1, 4],
},
{"batch_sizes": [1, 4], "prefix_sharing_algorithm": "trie"},
{"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS},
),
],
indirect=True,
)
@pytest.mark.parametrize(
"concurrent_requests",
[2, 4, 8],
)
@pytest.mark.xfail(
raises=AccuracyValidationException,
reason="Concurreny issues in Shortfin batch processing",
)
def test_llm_server_concurrent(llm_server, available_port, concurrent_requests):
logger.info("Testing concurrent invocations")

assert llm_server.poll() is None
PROMPT = "1 2 3 4 5 "
expected_output_prefix = "6 7 8"
logger.info(
"Sending HTTP Generation Request"
+ start_log_group("Sending HTTP Generation Request")
)
outputs = do_generate(PROMPT, available_port, concurrent_requests)

for output in outputs:
# log to GITHUB_STEP_SUMMARY if we are in a GitHub Action
if "GITHUB_ACTION" in os.environ:
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
# log prompt
f.write("LLM results:\n")
f.write(f"- llm_prompt:`{PROMPT}`\n")
f.write(f"- llm_output:`{output}`\n")

if not output.startswith(expected_output_prefix):
raise AccuracyValidationException(
f"Expected '{output}' to start with '{expected_output_prefix}'"
)
logger.info("HTTP Generation Request Successful" + end_log_group())
