switch to grpc for deploy and eval (#11643)
* switch to grpc for deploy and eval

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

* add docstring

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

* Apply isort and black reformatting

Signed-off-by: HuiyingLi <HuiyingLi@users.noreply.github.com>

* remove rest service according to comment

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

* format fixes

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

* use existing query_llm

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

* Apply isort and black reformatting

Signed-off-by: HuiyingLi <HuiyingLi@users.noreply.github.com>

* minor fix and docstring changes

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

* remove conversion to list

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

---------

Signed-off-by: Huiying Li <willwin.lee@gmail.com>
Signed-off-by: HuiyingLi <HuiyingLi@users.noreply.github.com>
Co-authored-by: HuiyingLi <HuiyingLi@users.noreply.github.com>
HuiyingLi and HuiyingLi authored Dec 30, 2024
1 parent 3394c22 commit 836c376
Showing 7 changed files with 39 additions and 119 deletions.
57 changes: 9 additions & 48 deletions nemo/collections/llm/api.py
@@ -326,7 +326,8 @@ def deploy(
model_type: str = "llama",
triton_model_name: str = "triton_model",
triton_model_version: Optional[int] = 1,
triton_port: int = 8000,
triton_http_port: int = 8000,
triton_grpc_port: int = 8001,
triton_http_address: str = "0.0.0.0",
triton_request_timeout: int = 60,
triton_model_repository: Path = None,
@@ -337,16 +338,10 @@
max_input_len: int = 256,
max_output_len: int = 256,
max_batch_size: int = 8,
start_rest_service: bool = True,
rest_service_http_address: str = "0.0.0.0",
rest_service_port: int = 8080,
openai_format_response: bool = True,
output_generation_logits: bool = True,
):
"""
Deploys nemo model on a PyTriton server by converting the nemo ckpt to trtllm.
Also starts rest service that is used to send OpenAI API compatible input request
to the PyTiton server.
Args:
nemo_checkpoint (Path): Path for nemo checkpoint.
@@ -355,7 +350,8 @@
name is passed to the evaluate method for the model to be accessible while sending evaluation requests.
Default: 'triton_model'.
triton_model_version (Optional[int]): Version for the triton model. Default: 1.
triton_port (int): Port for the PyTriton server. Default: 8000.
triton_http_port (int): HTTP port for the PyTriton server. Default: 8000.
triton_grpc_port (int): gRPC port for the PyTriton server. Default: 8001.
triton_http_address (str): HTTP address for the PyTriton server. Default: "0.0.0.0".
triton_request_timeout (int): Timeout in seconds for Triton server. Default: 60.
triton_model_repository (Path): Folder for the trt-llm conversion, trt-llm engine gets saved in this specified
@@ -367,10 +363,7 @@
max_input_len (int): Max input length of the model. Default: 256.
max_output_len (int): Max output length of the model. Default: 256.
max_batch_size (int): Max batch size of the model. Default: 8.
start_rest_service (bool): Start rest service that is used to send evaluation requests to the PyTriton server.
Needs to be True to be able to run evaluation. Default: True.
rest_service_http_address (str): HTTP address for the rest service. Default: "0.0.0.0".
rest_service_port (int): Port for the rest service. Default: 8080.
openai_format_response (bool): Return the response from PyTriton server in OpenAI compatible format. Needs to
be True while running evaluation. Default: True.
output_generation_logits (bool): If True builds trtllm engine with gather_generation_logits set to True.
@@ -380,16 +373,6 @@
from nemo.deploy import DeployPyTriton

unset_environment_variables()
if start_rest_service:
if triton_port == rest_service_port:
logging.error("REST service port and Triton server port cannot use the same port.")
return
# Store triton ip, port and other args relevant for REST API as env vars to be accessible by rest_model_api.py
os.environ["TRITON_HTTP_ADDRESS"] = triton_http_address
os.environ["TRITON_PORT"] = str(triton_port)
os.environ["TRITON_REQUEST_TIMEOUT"] = str(triton_request_timeout)
os.environ["OPENAI_FORMAT_RESPONSE"] = str(openai_format_response)
os.environ["OUTPUT_GENERATION_LOGITS"] = str(output_generation_logits)

triton_deployable = get_trtllm_deployable(
nemo_checkpoint,
@@ -411,7 +394,8 @@
triton_model_name=triton_model_name,
triton_model_version=triton_model_version,
max_batch_size=max_batch_size,
port=triton_port,
http_port=triton_http_port,
grpc_port=triton_grpc_port,
address=triton_http_address,
)

@@ -422,26 +406,8 @@
logging.error("Error message has occurred during deploy function. Error message: " + str(error))
return

uvicorn_supported = True
try:
import uvicorn
except ImportError as error:
logging.warning(f"uvicorn could not be imported: {error}")
uvicorn_supported = False

try:
logging.info("Model serving on Triton is will be started.")
if start_rest_service and uvicorn_supported:
try:
logging.info("REST service will be started.")
uvicorn.run(
"nemo.deploy.service.rest_model_api:app",
host=rest_service_http_address,
port=rest_service_port,
reload=True,
)
except Exception as error:
logging.error("Error message has occurred during REST service start. Error message: " + str(error))
logging.info("Model serving on Triton will be started.")
nm.serve()
except Exception as error:
logging.error("Error message has occurred during deploy function. Error message: " + str(error))
@@ -453,7 +419,7 @@ def deploy(

def evaluate(
nemo_checkpoint_path: Path,
url: str = "http://0.0.0.0:8080/v1",
url: str = "grpc://0.0.0.0:8001",
model_name: str = "triton_model",
eval_task: str = "gsm8k",
num_fewshot: Optional[int] = None,
@@ -473,10 +439,7 @@ def evaluate(
Args:
nemo_checkpoint_path (Path): Path for nemo 2.0 checkpoint. This is used to get the tokenizer from the ckpt
which is required to tokenize the evaluation input and output prompts.
url (str): rest service url and port that were used in the deploy method above in the format:
http://{rest_service_http}:{rest_service_port}. Post requests with evaluation input prompts
(from lm-eval-harness) are sent to this url which is then passed to the model deployed on PyTriton server.
The rest service url and port serve as the entry point to evaluate model deployed on PyTriton server.
url (str): grpc service url that was used in the deploy method above, in the format: grpc://{grpc_service_ip}:{grpc_port}.
model_name (str): Name of the model that is deployed on PyTriton server. It should be the same as
triton_model_name passed to the deploy method above to be able to launch evaluation. Default: "triton_model".
eval_task (str): task to be evaluated on. For ex: "gsm8k", "gsm8k_cot", "mmlu", "lambada". Default: "gsm8k".
@@ -513,8 +476,6 @@ def evaluate(

# Get tokenizer from nemo ckpt. This works only with NeMo 2.0 ckpt.
tokenizer = io.load_context(nemo_checkpoint_path + "/context", subpath="model.tokenizer")
# Wait for rest service to be ready before starting evaluation
evaluation.wait_for_rest_service(rest_url=f"{url}/v1/health")
# Create an object of the NeMoFWLM which is passed as a model to evaluator.simple_evaluate
model = evaluation.NeMoFWLMEval(
model_name, url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos
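For illustration only (not part of the commit): a minimal sketch of how the updated deploy and evaluate entry points might be called after this change, assuming both are exposed from nemo.collections.llm. The checkpoint path and model name are placeholders, and since deploy() blocks while serving, evaluate() would run from a separate process.

from nemo.collections import llm

# Start the PyTriton server: HTTP stays on port 8000, gRPC is now exposed on 8001.
llm.deploy(
    nemo_checkpoint="/workspace/ckpts/llama_nemo2",  # placeholder checkpoint path
    model_type="llama",
    triton_model_name="triton_model",
    triton_http_port=8000,
    triton_grpc_port=8001,
    max_batch_size=8,
)

# From a separate process, point evaluation at the gRPC endpoint instead of the
# removed REST service (previously http://0.0.0.0:8080/v1).
llm.evaluate(
    nemo_checkpoint_path="/workspace/ckpts/llama_nemo2",  # placeholder checkpoint path
    url="grpc://0.0.0.0:8001",
    model_name="triton_model",
    eval_task="gsm8k",
)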
1 change: 0 additions & 1 deletion nemo/collections/llm/deploy/base.py
@@ -102,7 +102,6 @@ def get_trtllm_deployable(
trt_llm_exporter.export(
nemo_checkpoint_path=nemo_checkpoint,
model_type=model_type,
n_gpus=num_gpus,
tensor_parallelism_size=tensor_parallelism_size,
pipeline_parallelism_size=pipeline_parallelism_size,
max_input_len=max_input_len,
4 changes: 2 additions & 2 deletions nemo/collections/llm/evaluation/__init__.py
@@ -1,3 +1,3 @@
from nemo.collections.llm.evaluation.base import NeMoFWLMEval, wait_for_rest_service
from nemo.collections.llm.evaluation.base import NeMoFWLMEval

__all__ = ["NeMoFWLMEval", "wait_for_rest_service"]
__all__ = ["NeMoFWLMEval"]
77 changes: 16 additions & 61 deletions nemo/collections/llm/evaluation/base.py
@@ -12,18 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import requests
import torch
import torch.nn.functional as F
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from requests.exceptions import RequestException
from tqdm import tqdm

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
from nemo.utils import logging
from nemo.deploy.nlp import NemoQueryLLM


class NeMoFWLMEval(LM):
@@ -49,21 +46,22 @@ def _generate_tokens_logits(self, payload, return_text: bool = False, return_log
A private method that sends post request to the model on PyTriton server and returns either generated text or
logits.
"""
# send a post request to /v1/completions/ endpoint with the payload
response = requests.post(f"{self.api_url}/v1/completions/", json=payload)
response_data = response.json()

if 'error' in response_data:
raise Exception(f"API Error: {response_data['error']}")
nq = NemoQueryLLM(url=self.api_url, model_name=payload['model'])

response = nq.query_llm(
prompts=payload['prompt'] if isinstance(payload['prompt'], list) else [payload['prompt']],
max_output_len=payload['max_tokens'],
top_k=payload['top_k'],
top_p=payload['top_p'],
temperature=payload['temperature'],
output_generation_logits=True,
openai_format_response=True,
)

# Assuming the response is in OpenAI format
if return_text:
# in case of generate_until tasks return just the text
return response_data['choices'][0]['text']

return response["choices"][0]["text"] # shape[batch_size, 1]
if return_logits:
# in case of loglikelihood tasks return the logits
return response_data['choices'][0]['generation_logits']
return response["choices"][0]["generation_logits"] # shape[batch_size, 1, num_tokens, vocab_size]

def tokenizer_type(self, tokenizer):
"""
Expand Down Expand Up @@ -93,7 +91,7 @@ def loglikelihood(self, requests: list[Instance]):
special_tokens_kwargs['add_special_tokens'] = self.add_bos

results = []
for request in requests:
for request in tqdm(requests):
# get the input prompt from the request
context = request.arguments[0]
# get the output prompt from the request
@@ -165,46 +163,3 @@ def generate_until(self, inputs: list[Instance]):
results.append(generated_text)

return results


def wait_for_rest_service(rest_url, max_retries=600, retry_interval=2):
"""
Wait for REST service to be ready.
Args:
rest_url (str): URL of the REST service's health endpoint
max_retries (int): Maximum number of retry attempts. Defaul: 60.
retry_interval (int): Time to wait between retries in seconds. Default: 2.
Returns:
bool: True if rest service is ready, False otherwise
"""

def check_service(url):
"""
Check if the service is ready by making a GET request to its health endpoint.
Args:
url (str): URL of the service's health endpoint
Returns:
bool: True if the service is ready, False otherwise
"""
try:
response = requests.get(url, timeout=5)
return response.status_code == 200
except RequestException:
return False

for _ in range(max_retries):
rest_ready = check_service(rest_url)

if rest_ready:
logging.info("REST service is ready.")
return True

logging.info(f"REST Service not ready yet. Retrying in {retry_interval} seconds...")
time.sleep(retry_interval)

logging.info("Timeout: REST service did not become ready.")
return False
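
For reference, a rough sketch of the gRPC query path that NeMoFWLMEval now relies on, calling NemoQueryLLM directly with the same arguments used in _generate_tokens_logits above; the prompt and sampling values are illustrative.

from nemo.deploy.nlp import NemoQueryLLM

nq = NemoQueryLLM(url="grpc://0.0.0.0:8001", model_name="triton_model")
response = nq.query_llm(
    prompts=["What is 2 + 2?"],
    max_output_len=32,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    output_generation_logits=True,
    openai_format_response=True,
)
text = response["choices"][0]["text"]                 # generated text
logits = response["choices"][0]["generation_logits"]  # [batch_size, 1, num_tokens, vocab_size]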
6 changes: 4 additions & 2 deletions nemo/deploy/deploy_base.py
@@ -42,7 +42,8 @@ def __init__(
checkpoint_path: str = None,
model=None,
max_batch_size: int = 128,
port: int = 8000,
http_port: int = 8000,
grpc_port: int = 8001,
address="0.0.0.0",
allow_grpc=True,
allow_http=True,
@@ -54,7 +55,8 @@
self.triton_model_version = triton_model_version
self.max_batch_size = max_batch_size
self.model = model
self.port = port
self.http_port = http_port
self.grpc_port = grpc_port
self.address = address
self.triton = None
self.allow_grpc = allow_grpc
10 changes: 7 additions & 3 deletions nemo/deploy/deploy_pytriton.py
@@ -66,7 +66,8 @@ def __init__(
checkpoint_path: str = None,
model=None,
max_batch_size: int = 128,
port: int = 8000,
http_port: int = 8000,
grpc_port: int = 8001,
address="0.0.0.0",
allow_grpc=True,
allow_http=True,
@@ -92,7 +93,8 @@ def __init__(
checkpoint_path=checkpoint_path,
model=model,
max_batch_size=max_batch_size,
port=port,
http_port=http_port,
grpc_port=grpc_port,
address=address,
allow_grpc=allow_grpc,
allow_http=allow_http,
@@ -128,7 +130,9 @@ def deploy(self):
else:
triton_config = TritonConfig(
http_address=self.address,
http_port=self.port,
http_port=self.http_port,
grpc_address=self.address,
grpc_port=self.grpc_port,
allow_grpc=self.allow_grpc,
allow_http=self.allow_http,
)
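For context, a sketch of constructing DeployPyTriton directly with the split HTTP/gRPC ports; the exportable model object is a placeholder standing in for what get_trtllm_deployable returns.

from nemo.deploy import DeployPyTriton

triton_deployable = ...  # placeholder for the TensorRT-LLM deployable returned by get_trtllm_deployable

nm = DeployPyTriton(
    model=triton_deployable,
    triton_model_name="triton_model",
    triton_model_version=1,
    max_batch_size=8,
    http_port=8000,   # HTTP endpoint (unchanged default)
    grpc_port=8001,   # gRPC endpoint, now configured explicitly
    address="0.0.0.0",
)
nm.deploy()  # builds the TritonConfig with both http_port and grpc_port
nm.serve()   # blocks and serves requests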
3 changes: 1 addition & 2 deletions nemo/deploy/nlp/query_llm.py
@@ -299,9 +299,8 @@ def query_llm(
"model": self.model_name,
"choices": [{"text": str(sentences)}],
}
# Convert gneration logits to a list to make it json serializable and add it to openai_response dict
if output_generation_logits:
openai_response["choices"][0]["generation_logits"] = result_dict["generation_logits"].tolist()
openai_response["choices"][0]["generation_logits"] = result_dict["generation_logits"]
return openai_response
else:
return sentences
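Since the .tolist() conversion is removed, generation_logits in the OpenAI-style response is no longer guaranteed to be JSON-serializable; continuing from the NemoQueryLLM sketch above, a caller that needs to serialize it would convert explicitly, roughly:

import json

# "response" as returned by NemoQueryLLM.query_llm(...) in the sketch above,
# with output_generation_logits=True and openai_format_response=True
logits = response["choices"][0]["generation_logits"]  # array/tensor, not a Python list
serializable = json.dumps({"generation_logits": logits.tolist()})  # convert explicitly when needed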
