Make DatabricksDeploymentClient support prediction with streaming response (mlflow#11580)

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
WeichenXu123 authored Apr 4, 2024
1 parent b3a2ee2 commit 745f0e0
Showing 2 changed files with 129 additions and 1 deletion.
14 changes: 14 additions & 0 deletions mlflow/deployments/base.py
@@ -225,6 +225,20 @@ def predict(self, deployment_name=None, inputs=None, endpoint=None):
"""
pass

    def predict_stream(self, deployment_name=None, inputs=None, endpoint=None):
        """
        Submit a query to a configured provider endpoint and get a streaming response.

        Args:
            deployment_name: Name of the deployment to predict against.
            inputs: The inputs to the query, as a dictionary.
            endpoint: The name of the endpoint to query.

        Returns:
            An iterator of dictionaries containing the response from the endpoint.
        """
        raise NotImplementedError()

    def explain(self, deployment_name=None, df=None, endpoint=None):
        """
        Generate explanations of model predictions on the specified input pandas DataFrame
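For callers, the new base-class method is consumed as a plain iterator of dictionaries. Below is a minimal usage sketch (not part of this diff) that accumulates a streamed chat completion, assuming a Databricks target and the chat-chunk shape shown later in this commit; the endpoint name and response fields are illustrative only.

from mlflow.deployments import get_deploy_client

client = get_deploy_client("databricks")
pieces = []
for chunk in client.predict_stream(
    endpoint="databricks-llama-2-70b-chat",  # illustrative endpoint name
    inputs={"messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 64},
):
    # Chat endpoints place incremental text under choices[0].delta.content;
    # the first chunk may carry only the role, hence the `or ""` fallback.
    pieces.append(chunk["choices"][0]["delta"].get("content") or "")
print("".join(pieces))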
116 changes: 115 additions & 1 deletion mlflow/deployments/databricks/__init__.py
@@ -1,5 +1,6 @@
import json
import posixpath
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Iterator, Optional

from mlflow.deployments import BaseDeploymentClient
from mlflow.deployments.constants import (
@@ -9,6 +10,7 @@
    MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT,
    MLFLOW_HTTP_REQUEST_TIMEOUT,
)
from mlflow.exceptions import MlflowException
from mlflow.utils import AttrDict
from mlflow.utils.annotations import experimental
from mlflow.utils.databricks_utils import get_databricks_host_creds
@@ -147,6 +149,42 @@ def _call_endpoint(
        augmented_raise_for_status(response)
        return DatabricksEndpoint(response.json())

    def _call_endpoint_stream(
        self,
        *,
        method: str,
        prefix: str = "/api/2.0",
        route: Optional[str] = None,
        json_body: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Iterator[str]:
        call_kwargs = {}
        if method.lower() == "get":
            call_kwargs["params"] = json_body
        else:
            call_kwargs["json"] = json_body

        response = http_request(
            host_creds=get_databricks_host_creds(self.target_uri),
            endpoint=posixpath.join(prefix, "serving-endpoints", route or ""),
            method=method,
            timeout=MLFLOW_HTTP_REQUEST_TIMEOUT.get() if timeout is None else timeout,
            raise_on_status=False,
            retry_codes=MLFLOW_DEPLOYMENT_CLIENT_REQUEST_RETRY_CODES,
            extra_headers={"X-Databricks-Endpoints-API-Client": "Databricks Deployment Client"},
            stream=True,  # Receive the response content as a stream.
            **call_kwargs,
        )
        augmented_raise_for_status(response)

        # Streaming response content is composed of multiple lines;
        # the format of each line depends on the specific endpoint.
        return (
            line.strip()
            for line in response.iter_lines(decode_unicode=True)
            if line.strip()  # Filter out keep-alive newlines.
        )
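The generator above relies on requests-style line iteration over a streamed body. As an aside (not part of this diff), here is a standalone sketch of that pattern; the URL, JSON body, and timeout are placeholders, not the actual call that `http_request` builds.

import requests

with requests.post(
    "https://example.com/serving-endpoints/my-endpoint/invocations",  # placeholder URL
    json={"stream": True},
    stream=True,  # Do not buffer the whole body; read it incrementally.
    timeout=120,
) as response:
    response.raise_for_status()
    for line in response.iter_lines(decode_unicode=True):
        if line.strip():  # Skip keep-alive blank lines.
            print(line.strip())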

    @experimental
    def predict(self, deployment_name=None, inputs=None, endpoint=None):
        """
@@ -207,6 +245,82 @@ def predict(self, deployment_name=None, inputs=None, endpoint=None):
            timeout=MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT.get(),
        )

    @experimental
    def predict_stream(
        self, deployment_name=None, inputs=None, endpoint=None
    ) -> Iterator[Dict[str, Any]]:
        """
        Submit a query to a configured provider endpoint and get a streaming response.

        Args:
            deployment_name: Unused.
            inputs: The inputs to the query, as a dictionary.
            endpoint: The name of the endpoint to query.

        Returns:
            An iterator of dictionaries containing the response from the endpoint.

        Example:

        .. code-block:: python

            from mlflow.deployments import get_deploy_client

            client = get_deploy_client("databricks")
            chunk_iter = client.predict_stream(
                endpoint="databricks-llama-2-70b-chat",
                inputs={
                    "messages": [{"role": "user", "content": "Hello!"}],
                    "temperature": 0.0,
                    "n": 1,
                    "max_tokens": 500,
                },
            )
            for chunk in chunk_iter:
                print(chunk)
                # Example:
                # {
                #     "id": "82a834f5-089d-4fc0-ad6c-db5c7d6a6129",
                #     "object": "chat.completion.chunk",
                #     "created": 1712133837,
                #     "model": "llama-2-70b-chat-030424",
                #     "choices": [
                #         {
                #             "index": 0,
                #             "delta": {"role": "assistant", "content": "Hello"},
                #             "finish_reason": None,
                #         }
                #     ],
                #     "usage": {"prompt_tokens": 11, "completion_tokens": 1, "total_tokens": 12},
                # }
        """
        inputs = inputs or {}

        # Add the stream=True param in the request body to get a streaming response.
        # See https://docs.databricks.com/api/workspace/servingendpoints/query#stream
        chunk_line_iter = self._call_endpoint_stream(
            method="POST",
            prefix="/",
            route=posixpath.join(endpoint, "invocations"),
            json_body={**inputs, "stream": True},
            timeout=MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT.get(),
        )

        for line in chunk_line_iter:
            splits = line.split(":", 1)
            if len(splits) < 2:
                raise MlflowException(f"Unknown streaming response format: '{line}'.")
            key, value = splits
            if key != "data":
                raise MlflowException(f"Unknown streaming response format with key '{key}'.")

            value = value.strip()
            if value == "[DONE]":
                # The Databricks endpoint streaming response ends with
                # a line of "data: [DONE]".
                return

            yield json.loads(value)
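The parsing loop above expects server-sent-events-style lines of the form `data: <json>`, terminated by `data: [DONE]`. As an illustrative aside (not part of this diff), a self-contained sketch of that parsing over hypothetical lines; the payloads are made up for illustration.

import json

lines = [
    'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    'data: {"choices": [{"delta": {"content": "lo"}}]}',
    "data: [DONE]",
]

for line in lines:
    key, _, value = line.partition(":")
    if key != "data":
        raise ValueError(f"Unknown streaming response format: '{line}'.")
    value = value.strip()
    if value == "[DONE]":
        break  # End-of-stream sentinel.
    print(json.loads(value)["choices"][0]["delta"]["content"])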

    @experimental
    def create_endpoint(self, name, config=None):
        """
