From daf43de5283783cb102f16eebeda2d920a446f29 Mon Sep 17 00:00:00 2001 From: Rohan Thacker Date: Tue, 24 Dec 2024 13:15:38 +0530 Subject: [PATCH] WIP: Azure AI Client * Added: object-level usage data * Added: doc string * Added: check existing response_format value * Added: _validate_config and _create_client --- .../autogen-core/docs/src/reference/index.md | 1 + .../python/autogen_ext.models.azure.rst | 8 ++ .../src/autogen_ext/models/azure/__init__.py | 3 +- .../models/azure/_azure_ai_client.py | 109 +++++++++++++++--- .../models/test_azure_ai_model_client.py | 42 ++++--- 5 files changed, 131 insertions(+), 32 deletions(-) create mode 100644 python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst diff --git a/python/packages/autogen-core/docs/src/reference/index.md b/python/packages/autogen-core/docs/src/reference/index.md index cfe36eded2c2..4893b9964b93 100644 --- a/python/packages/autogen-core/docs/src/reference/index.md +++ b/python/packages/autogen-core/docs/src/reference/index.md @@ -48,6 +48,7 @@ python/autogen_ext.agents.video_surfer.tools python/autogen_ext.teams.magentic_one python/autogen_ext.models.openai python/autogen_ext.models.replay +python/autogen_ext.models.azure python/autogen_ext.tools.langchain python/autogen_ext.code_executors.local python/autogen_ext.code_executors.docker diff --git a/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst new file mode 100644 index 000000000000..64c16a5a57d4 --- /dev/null +++ b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst @@ -0,0 +1,8 @@ +autogen\_ext.models.azure +========================== + + +.. automodule:: autogen_ext.models.azure + :members: + :undoc-members: + :show-inheritance: diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py index 02d4392e5a8b..2dc7b9c70a98 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py @@ -1,3 +1,4 @@ from ._azure_ai_client import AzureAIChatCompletionClient +from .config import AzureAIChatCompletionClientConfig -__all__ = ["AzureAIChatCompletionClient"] +__all__ = ["AzureAIChatCompletionClient", "AzureAIChatCompletionClientConfig"] diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py index 4cd058808ddd..1297060601fa 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py @@ -1,5 +1,6 @@ import asyncio import re +import warnings from asyncio import Task from typing import Sequence, Optional, Mapping, Any, List, Unpack, Dict, cast from inspect import getfullargspec @@ -154,25 +155,95 @@ def assert_valid_name(name: str) -> str: class AzureAIChatCompletionClient(ChatCompletionClient): + """ + Chat completion client for models hosted on Azure AI Foundry or GitHub Models. + See `here `_ for more info. + + Args: + endpoint (str): The endpoint to use. **Required.** + credentials (union, AzureKeyCredential, AsyncTokenCredential): The credentials to use. **Required** + model_capabilities (ModelCapabilities): The capabilities of the model. **Required.** + model (str): The name of the model. **Required if model is hosted on GitHub Models.** + frequency_penalty: (optional,float) + presence_penalty: (optional,float) + temperature: (optional,float) + top_p: (optional,float) + max_tokens: (optional,int) + response_format: (optional,ChatCompletionsResponseFormat) + stop: (optional,List[str]) + tools: (optional,List[ChatCompletionsToolDefinition]) + tool_choice: (optional,Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]]) + seed: (optional,int) + model_extras: (optional,Dict[str, Any]) + + To use this client, you must install the `azure-ai-inference` extension: + + .. code-block:: bash + + pip install 'autogen-ext[azure-ai-inference]==0.4.0.dev11' + + The following code snippet shows how to use the client: + + .. code-block:: python + + from azure.core.credentials import AzureKeyCredential + from autogen_ext.models.azure import AzureAIChatCompletionClient + from autogen_core.models import UserMessage + + client = AzureAIChatCompletionClient( + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + model_capabilities={ + "json_output": False, + "function_calling": False, + "vision": False, + }, + ) + + result = await client.create([UserMessage(content="What is the capital of France?", source="user")]) # type: ignore + print(result) + + """ + def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]): - if "endpoint" not in kwargs: + config = self._validate_config(kwargs) + self._model_capabilities = config["model_capabilities"] + self._client = self._create_client(config) + self._create_args = self._prepare_create_args(config) + + self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + + @staticmethod + def _validate_config(config: Dict) -> AzureAIChatCompletionClientConfig: + if "endpoint" not in config: raise ValueError("endpoint is required for AzureAIChatCompletionClient") - if "credential" not in kwargs: + if "credential" not in config: raise ValueError("credential is required for AzureAIChatCompletionClient") - if "model_capabilities" not in kwargs: + if "model_capabilities" not in config: raise ValueError("model_capabilities is required for AzureAIChatCompletionClient") - if _is_github_model(kwargs['endpoint']) and "model" not in kwargs: + if _is_github_model(config["endpoint"]) and "model" not in config: raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient") - - # TODO: Change - _endpoint = kwargs.pop("endpoint") - _credential = kwargs.pop("credential") - self._model_capabilities = kwargs.pop("model_capabilities") - self._create_args = kwargs.copy() - - self._client = ChatCompletionsClient(_endpoint, _credential, **self._create_args) - self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) - self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + return config + + @staticmethod + def _create_client(config: AzureAIChatCompletionClientConfig): + return ChatCompletionsClient(**config) + + @staticmethod + def _prepare_create_args(config: Mapping[str, Any]) -> Mapping[str, Any]: + create_args = {k: v for k, v in config.items() if k in create_kwargs} + return create_args + # self._endpoint = config.pop("endpoint") + # self._credential = config.pop("credential") + # self._model_capabilities = config.pop("model_capabilities") + # self._create_args = config.copy() + + def add_usage(self, usage: RequestUsage): + self._total_usage = RequestUsage( + self._total_usage.prompt_tokens + usage.prompt_tokens, + self._total_usage.completion_tokens + usage.completion_tokens, + ) async def create( self, @@ -200,7 +271,7 @@ async def create( if self.capabilities["json_output"] is False and json_output is True: raise ValueError("Model does not support JSON output") - if json_output is True: + if json_output is True and "response_format" not in create_args: create_args["response_format"] = ChatCompletionsResponseFormatJSON() if self.capabilities["json_output"] is False and json_output is True: @@ -259,6 +330,9 @@ async def create( usage=usage, cached=False, ) + + self.add_usage(usage) + return response async def create_stream( @@ -286,7 +360,7 @@ async def create_stream( if self.capabilities["json_output"] is False and json_output is True: raise ValueError("Model does not support JSON output") - if json_output is True: + if json_output is True and "response_format" not in create_args: create_args["response_format"] = ChatCompletionsResponseFormatJSON() if self.capabilities["json_output"] is False and json_output is True: @@ -380,6 +454,9 @@ async def create_stream( usage=usage, cached=False, ) + + self.add_usage(usage) + yield result def actual_usage(self) -> RequestUsage: diff --git a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py index fd888d7f4b45..22bd7cf74ee1 100644 --- a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py @@ -7,15 +7,20 @@ ChatCompletionsClient, ) + from azure.ai.inference.models import ( ChatChoice, ChatResponseMessage, CompletionsUsage, + ChatCompletionsResponseFormatJSON, +) +from azure.ai.inference.models import ( + ChatCompletions, + StreamingChatCompletionsUpdate, + StreamingChatChoiceUpdate, + StreamingChatResponseMessageUpdate, ) -from azure.ai.inference.models import (ChatCompletions, - StreamingChatCompletionsUpdate, StreamingChatChoiceUpdate, - StreamingChatResponseMessageUpdate) from azure.core.credentials import AzureKeyCredential @@ -32,7 +37,8 @@ async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[Strea index=0, finish_reason="stop", delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content), - ) for chunk_content in mock_chunks_content + ) + for chunk_content in mock_chunks_content ] for mock_chunk in mock_chunks: @@ -46,7 +52,9 @@ async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[Strea ) -async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]: +async def _mock_create( + *args: Any, **kwargs: Any +) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]: stream = kwargs.get("stream", False) if not stream: @@ -54,12 +62,10 @@ async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletions | AsyncGene return ChatCompletions( id="id", created=datetime.now(), - model='model', + model="model", choices=[ ChatChoice( - index=0, - finish_reason="stop", - message=ChatResponseMessage(content="Hello", role="assistant") + index=0, finish_reason="stop", message=ChatResponseMessage(content="Hello", role="assistant") ) ], usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), @@ -68,20 +74,21 @@ async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletions | AsyncGene return _mock_create_stream(*args, **kwargs) - @pytest.mark.asyncio async def test_azure_ai_chat_completion_client() -> None: client = AzureAIChatCompletionClient( endpoint="endpoint", credential=AzureKeyCredential("api_key"), - model_capabilities = { + model_capabilities={ "json_output": False, "function_calling": False, "vision": False, }, + model="model", ) assert client + @pytest.mark.asyncio async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.MonkeyPatch) -> None: # monkeypatch.setattr(AsyncCompletions, "create", _mock_create) @@ -89,7 +96,7 @@ async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.Monkey client = AzureAIChatCompletionClient( endpoint="endpoint", credential=AzureKeyCredential("api_key"), - model_capabilities = { + model_capabilities={ "json_output": False, "function_calling": False, "vision": False, @@ -98,14 +105,15 @@ async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.Monkey result = await client.create(messages=[UserMessage(content="Hello", source="user")]) assert result.content == "Hello" + @pytest.mark.asyncio -async def test_azure_ai_chat_completion_client_create_stream(monkeypatch:pytest.MonkeyPatch) -> None: +async def test_azure_ai_chat_completion_client_create_stream(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create) chunks = [] client = AzureAIChatCompletionClient( endpoint="endpoint", credential=AzureKeyCredential("api_key"), - model_capabilities = { + model_capabilities={ "json_output": False, "function_calling": False, "vision": False, @@ -118,6 +126,7 @@ async def test_azure_ai_chat_completion_client_create_stream(monkeypatch:pytest. assert chunks[1] == " Another Hello" assert chunks[2] == " Yet Another Hello" + @pytest.mark.asyncio async def test_azure_ai_chat_completion_client_create_cancel(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create) @@ -138,6 +147,7 @@ async def test_azure_ai_chat_completion_client_create_cancel(monkeypatch: pytest with pytest.raises(asyncio.CancelledError): await task + @pytest.mark.asyncio async def test_azure_ai_chat_completion_client_create_stream_cancel(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create) @@ -151,7 +161,9 @@ async def test_azure_ai_chat_completion_client_create_stream_cancel(monkeypatch: "vision": False, }, ) - stream=client.create_stream(messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token) + stream = client.create_stream( + messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token + ) cancellation_token.cancel() with pytest.raises(asyncio.CancelledError): async for _ in stream: