diff --git a/python/packages/autogen-core/docs/src/reference/index.md b/python/packages/autogen-core/docs/src/reference/index.md
index 869ffc2347c..fdaf598c002 100644
--- a/python/packages/autogen-core/docs/src/reference/index.md
+++ b/python/packages/autogen-core/docs/src/reference/index.md
@@ -51,6 +51,7 @@ python/autogen_ext.teams.magentic_one
 python/autogen_ext.models.cache
 python/autogen_ext.models.openai
 python/autogen_ext.models.replay
+python/autogen_ext.models.azure
 python/autogen_ext.models.semantic_kernel
 python/autogen_ext.tools.langchain
 python/autogen_ext.tools.graphrag
diff --git a/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst
new file mode 100644
index 00000000000..64c16a5a57d
--- /dev/null
+++ b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst
@@ -0,0 +1,8 @@
+autogen\_ext.models.azure
+==========================
+
+
+.. automodule:: autogen_ext.models.azure
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml
index 20579c99bae..0e404db2d20 100644
--- a/python/packages/autogen-ext/pyproject.toml
+++ b/python/packages/autogen-ext/pyproject.toml
@@ -20,7 +20,11 @@ dependencies = [
 
 [project.optional-dependencies]
 langchain = ["langchain_core~= 0.3.3"]
-azure = ["azure-core", "azure-identity"]
+azure = [
+    "azure-ai-inference>=1.0.0b7",
+    "azure-core",
+    "azure-identity",
+]
 docker = ["docker~=7.0"]
 openai = ["openai>=1.52.2", "tiktoken>=0.8.0", "aiofiles"]
 file-surfer = [
@@ -52,7 +56,6 @@ diskcache = [
 redis = [
     "redis>=5.2.1"
 ]
-
 grpc = [
     "grpcio~=1.62.0", # TODO: update this once we have a stable version.
 ]
@@ -60,47 +63,36 @@ jupyter-executor = [
     "ipykernel>=6.29.5",
     "nbclient>=0.10.2",
 ]
-
 semantic-kernel-core = [
     "semantic-kernel>=1.17.1",
 ]
-
 semantic-kernel-google = [
     "semantic-kernel[google]>=1.17.1",
 ]
-
 semantic-kernel-hugging-face = [
     "semantic-kernel[hugging_face]>=1.17.1",
 ]
-
 semantic-kernel-mistralai = [
     "semantic-kernel[mistralai]>=1.17.1",
 ]
-
 semantic-kernel-ollama = [
     "semantic-kernel[ollama]>=1.17.1",
 ]
-
 semantic-kernel-onnx = [
     "semantic-kernel[onnx]>=1.17.1",
 ]
-
 semantic-kernel-anthropic = [
     "semantic-kernel[anthropic]>=1.17.1",
 ]
-
 semantic-kernel-pandas = [
     "semantic-kernel[pandas]>=1.17.1",
 ]
-
 semantic-kernel-aws = [
     "semantic-kernel[aws]>=1.17.1",
 ]
-
 semantic-kernel-dapr = [
     "semantic-kernel[dapr]>=1.17.1",
 ]
-
 semantic-kernel-all = [
     "semantic-kernel[google,hugging_face,mistralai,ollama,onnx,anthropic,usearch,pandas,aws,dapr]>=1.17.1",
 ]
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py
new file mode 100644
index 00000000000..2dc7b9c70a9
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py
@@ -0,0 +1,4 @@
+from ._azure_ai_client import AzureAIChatCompletionClient
+from .config import AzureAIChatCompletionClientConfig
+
+__all__ = ["AzureAIChatCompletionClient", "AzureAIChatCompletionClientConfig"]
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
new file mode 100644
index 00000000000..7e36a869862
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
@@ -0,0 +1,501 @@
+import asyncio
+import re
+import warnings
+from asyncio import Task
+from typing import Sequence, Optional, Mapping, Any, List, Dict, cast
+from inspect import getfullargspec
+from azure.ai.inference.aio import ChatCompletionsClient
+from azure.ai.inference.models import (
+    ChatCompletions,
+    ChatRequestMessage,
+    CompletionsFinishReason,
+    ChatCompletionsToolCall,
+    ChatCompletionsToolDefinition,
+    FunctionDefinition,
+    ContentItem,
+    TextContentItem,
+    ImageContentItem,
+    ImageUrl,
+    ImageDetailLevel,
+    StreamingChatCompletionsUpdate,
+    SystemMessage as AzureSystemMessage,
+    UserMessage as AzureUserMessage,
+    AssistantMessage as AzureAssistantMessage,
+    ToolMessage as AzureToolMessage,
+    FunctionCall as AzureFunctionCall,
+)
+from typing_extensions import AsyncGenerator, Union, Unpack
+
+from autogen_core import CancellationToken
+from autogen_core import FunctionCall, Image
+from autogen_core.models import (
+    ChatCompletionClient,
+    LLMMessage,
+    CreateResult,
+    ModelInfo,
+    RequestUsage,
+    UserMessage,
+    SystemMessage,
+    AssistantMessage,
+    FunctionExecutionResultMessage,
+    FinishReasons,
+)
+from autogen_core.tools import Tool, ToolSchema
+from autogen_ext.models.azure.config import AzureAIChatCompletionClientConfig, GITHUB_MODELS_ENDPOINT
+
+create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
+
+
+def _is_github_model(endpoint: str) -> bool:
+    return endpoint == GITHUB_MODELS_ENDPOINT
+
+
+def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]:
+    result: List[ChatCompletionsToolDefinition] = []
+    for tool in tools:
+        if isinstance(tool, Tool):
+            tool_schema = tool.schema.copy()
+        else:
+            assert isinstance(tool, dict)
+            tool_schema = tool.copy()
+
+        # Pydantic-generated JSON schemas carry "title" keys the service does not expect; strip them.
+        if "parameters" in tool_schema:
+            for value in tool_schema["parameters"]["properties"].values():
+                if "title" in value:
+                    del value["title"]
+
+        result.append(
+            ChatCompletionsToolDefinition(
+                function=FunctionDefinition(
+                    name=tool_schema["name"],
+                    description=(tool_schema["description"] if "description" in tool_schema else ""),
+                    parameters=(tool_schema["parameters"] if "parameters" in tool_schema else {}),
+                ),
+            ),
+        )
+    return result
+
+
+def _func_call_to_azure(message: FunctionCall) -> ChatCompletionsToolCall:
+    return ChatCompletionsToolCall(
+        id=message.id,
+        function=AzureFunctionCall(arguments=message.arguments, name=message.name),
+    )
+
+
+def _system_message_to_azure(message: SystemMessage) -> AzureSystemMessage:
+    return AzureSystemMessage(content=message.content)
+
+
+def _user_message_to_azure(message: UserMessage) -> AzureUserMessage:
+    assert_valid_name(message.source)
+    if isinstance(message.content, str):
+        return AzureUserMessage(content=message.content)
+    else:
+        parts: List[ContentItem] = []
+        for part in message.content:
+            if isinstance(part, str):
+                parts.append(TextContentItem(text=part))
+            elif isinstance(part, Image):
+                # TODO: support url based images
+                # TODO: support specifying details
+                parts.append(ImageContentItem(image_url=ImageUrl(url=part.data_uri, detail=ImageDetailLevel.AUTO)))
+            else:
+                raise ValueError(f"Unknown content type: {message.content}")
+        return AzureUserMessage(content=parts)
+
+
+def _assistant_message_to_azure(message: AssistantMessage) -> AzureAssistantMessage:
+    assert_valid_name(message.source)
+    if isinstance(message.content, list):
+        return AzureAssistantMessage(
+            tool_calls=[_func_call_to_azure(x) for x in message.content],
+        )
+    else:
+        return AzureAssistantMessage(content=message.content)
+
+
+def _tool_message_to_azure(message: FunctionExecutionResultMessage) -> Sequence[AzureToolMessage]:
+    return [AzureToolMessage(content=x.content, tool_call_id=x.call_id) for x in message.content]
+
+
+def to_azure_message(message: LLMMessage) -> Sequence[ChatRequestMessage]:
+    if isinstance(message, SystemMessage):
+        return [_system_message_to_azure(message)]
+    elif isinstance(message, UserMessage):
+        return [_user_message_to_azure(message)]
+    elif isinstance(message, AssistantMessage):
+        return [_assistant_message_to_azure(message)]
+    else:
+        return _tool_message_to_azure(message)
+
+
+def normalize_name(name: str) -> str:
+    """
+    LLMs sometimes call functions while ignoring their own format requirements; use this to replace invalid characters with "_".
+
+    Prefer `assert_valid_name` for validating user configuration or input.
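+
+    A minimal illustration (the tool name is an assumed example)::
+
+        normalize_name("weather.get-forecast")  # -> "weather_get-forecast"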
+    """
+    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
+
+
+def assert_valid_name(name: str) -> str:
+    """
+    Ensure that configured names are valid; raises ValueError if not.
+
+    For munging LLM responses use `normalize_name` to ensure LLM-specified names don't break the API.
+    """
+    if not re.match(r"^[a-zA-Z0-9_-]+$", name):
+        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
+    if len(name) > 64:
+        raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
+    return name
+
+
+def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
+    if stop_reason is None:
+        return "unknown"
+
+    stop_reason = stop_reason.lower()
+
+    KNOWN_STOP_MAPPINGS: Dict[str, FinishReasons] = {
+        "stop": "stop",
+        "length": "length",
+        "content_filter": "content_filter",
+        "end_turn": "stop",
+        "tool_calls": "function_calls",
+        "function_calls": "function_calls",
+    }
+
+    return KNOWN_STOP_MAPPINGS.get(stop_reason, "unknown")
+
+
+class AzureAIChatCompletionClient(ChatCompletionClient):
+    """
+    Chat completion client for models hosted on Azure AI Foundry or GitHub Models.
+    See `here `_ for more info.
+
+    Args:
+        endpoint (str): The endpoint to use. **Required.**
+        credential (AzureKeyCredential | AsyncTokenCredential): The credential to use. **Required.**
+        model_info (ModelInfo): The capabilities of the model. **Required.**
+        model (str): The name of the model. **Required if the model is hosted on GitHub Models.**
+        frequency_penalty (float, optional):
+        presence_penalty (float, optional):
+        temperature (float, optional):
+        top_p (float, optional):
+        max_tokens (int, optional):
+        response_format (ChatCompletionsResponseFormat, optional):
+        stop (List[str], optional):
+        tools (List[ChatCompletionsToolDefinition], optional):
+        tool_choice (Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice], optional):
+        seed (int, optional):
+        model_extras (Dict[str, Any], optional):
+
+    To use this client, you must install the `azure` extra:
+
+    .. code-block:: bash
+
+        pip install 'autogen-ext[azure]==0.4.0.dev11'
+
+    The following code snippet shows how to use the client:
+
+    .. code-block:: python
+
+        import asyncio
+
+        from azure.core.credentials import AzureKeyCredential
+        from autogen_ext.models.azure import AzureAIChatCompletionClient
+        from autogen_core.models import UserMessage
+
+
+        async def main() -> None:
+            client = AzureAIChatCompletionClient(
+                endpoint="endpoint",
+                credential=AzureKeyCredential("api_key"),
+                model_info={
+                    "family": "unknown",
+                    "json_output": False,
+                    "function_calling": False,
+                    "vision": False,
+                },
+            )
+
+            result = await client.create([UserMessage(content="What is the capital of France?", source="user")])
+            print(result)
+
+
+        asyncio.run(main())
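+
+    To target GitHub Models instead, point the client at the GitHub Models endpoint and supply a
+    model name. A minimal sketch; the token and model name below are placeholders, not working values:
+
+    .. code-block:: python
+
+        client = AzureAIChatCompletionClient(
+            endpoint="https://models.inference.ai.azure.com",
+            credential=AzureKeyCredential("<your GitHub token>"),
+            model="model-name",  # placeholder: required when using GitHub Models
+            model_info={
+                "family": "unknown",
+                "json_output": False,
+                "function_calling": False,
+                "vision": False,
+            },
+        )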
config.pop("credential") + # self._model_capabilities = config.pop("model_capabilities") + # self._create_args = config.copy() + + def add_usage(self, usage: RequestUsage): + self._total_usage = RequestUsage( + self._total_usage.prompt_tokens + usage.prompt_tokens, + self._total_usage.completion_tokens + usage.completion_tokens, + ) + + async def create( + self, + messages: Sequence[LLMMessage], + tools: Sequence[Tool | ToolSchema] = [], + json_output: Optional[bool] = None, + extra_create_args: Mapping[str, Any] = {}, + cancellation_token: Optional[CancellationToken] = None, + ) -> CreateResult: + extra_create_args_keys = set(extra_create_args.keys()) + if not create_kwargs.issuperset(extra_create_args_keys): + raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}") + + # Copy the create args and overwrite anything in extra_create_args + create_args = self._create_args.copy() + create_args.update(extra_create_args) + + if self.model_info["vision"] is False: + for message in messages: + if isinstance(message, UserMessage): + if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content): + raise ValueError("Model does not support vision and image was provided") + + if json_output is not None: + if self.model_info["json_output"] is False and json_output is True: + raise ValueError("Model does not support JSON output") + + if json_output is True and "response_format" not in create_args: + create_args["response_format"] = "json-object" + + if self.model_info["json_output"] is False and json_output is True: + raise ValueError("Model does not support JSON output") + if self.model_info["function_calling"] is False and len(tools) > 0: + raise ValueError("Model does not support function calling") + + azure_messages_nested = [to_azure_message(msg) for msg in messages] + azure_messages = [item for sublist in azure_messages_nested for item in sublist] + + task: Task[ChatCompletions] + + if len(tools) > 0: + converted_tools = convert_tools(tools) + task = asyncio.create_task( + self._client.complete(messages=azure_messages, tools=converted_tools, **create_args) + ) + else: + task = asyncio.create_task( + self._client.complete( + messages=azure_messages, + **create_args, + ) + ) + + if cancellation_token is not None: + cancellation_token.link_future(task) + + result: ChatCompletions = await task + + usage = RequestUsage( + prompt_tokens=result.usage.prompt_tokens if result.usage else 0, + completion_tokens=result.usage.completion_tokens if result.usage else 0, + ) + + choice = result.choices[0] + if choice.finish_reason == CompletionsFinishReason.TOOL_CALLS: + assert choice.message.tool_calls is not None + + content = [ + FunctionCall( + id=x.id, + arguments=x.function.arguments, + name=normalize_name(x.function.name), + ) + for x in choice.message.tool_calls + ] + finish_reason = "function_calls" + else: + finish_reason = choice.finish_reason + content = choice.message.content or "" + + response = CreateResult( + finish_reason=normalize_stop_reason(finish_reason.value), # type: ignore + content=content, + usage=usage, + cached=False, + ) + + self.add_usage(usage) + + return response + + async def create_stream( + self, + messages: Sequence[LLMMessage], + tools: Sequence[Tool | ToolSchema] = [], + json_output: Optional[bool] = None, + extra_create_args: Mapping[str, Any] = {}, + cancellation_token: Optional[CancellationToken] = None, + ) -> AsyncGenerator[Union[str, CreateResult], None]: + extra_create_args_keys = 
+
+    async def create_stream(
+        self,
+        messages: Sequence[LLMMessage],
+        tools: Sequence[Tool | ToolSchema] = [],
+        json_output: Optional[bool] = None,
+        extra_create_args: Mapping[str, Any] = {},
+        cancellation_token: Optional[CancellationToken] = None,
+    ) -> AsyncGenerator[Union[str, CreateResult], None]:
+        extra_create_args_keys = set(extra_create_args.keys())
+        if not create_kwargs.issuperset(extra_create_args_keys):
+            raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
+
+        create_args = self._create_args.copy()
+        create_args.update(extra_create_args)
+
+        if self.model_info["vision"] is False:
+            for message in messages:
+                if isinstance(message, UserMessage):
+                    if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
+                        raise ValueError("Model does not support vision and image was provided")
+
+        if json_output is not None:
+            if self.model_info["json_output"] is False and json_output is True:
+                raise ValueError("Model does not support JSON output")
+            if json_output is True and "response_format" not in create_args:
+                create_args["response_format"] = "json_object"
+
+        if self.model_info["function_calling"] is False and len(tools) > 0:
+            raise ValueError("Model does not support function calling")
+
+        azure_messages_nested = [to_azure_message(msg) for msg in messages]
+        azure_messages = [item for sublist in azure_messages_nested for item in sublist]
+
+        if len(tools) > 0:
+            converted_tools = convert_tools(tools)
+            task = asyncio.create_task(
+                self._client.complete(messages=azure_messages, tools=converted_tools, stream=True, **create_args)
+            )
+        else:
+            task = asyncio.create_task(
+                self._client.complete(messages=azure_messages, stream=True, **create_args)
+            )
+
+        if cancellation_token is not None:
+            cancellation_token.link_future(task)
+
+        finish_reason = None
+        content_deltas: List[str] = []
+        full_tool_calls: Dict[str, FunctionCall] = {}
+        prompt_tokens = 0
+        completion_tokens = 0
+        chunk: Optional[StreamingChatCompletionsUpdate] = None
+        choice = None
+        async for chunk in await task:
+            # Some updates (e.g. usage-only chunks) carry no choices; skip them.
+            if len(chunk.choices) == 0:
+                continue
+            choice = chunk.choices[0]
+            if choice.finish_reason is not None:
+                finish_reason = choice.finish_reason.value
+
+            # We first try to load the content
+            if choice.delta.content is not None:
+                content_deltas.append(choice.delta.content)
+                yield choice.delta.content
+            # Otherwise, we try to load the tool calls
+            if choice.delta.tool_calls is not None:
+                for tool_call_chunk in choice.delta.tool_calls:
+                    if "index" in tool_call_chunk:
+                        idx = tool_call_chunk["index"]
+                    else:
+                        idx = tool_call_chunk.id
+                    if idx not in full_tool_calls:
+                        full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
+
+                    if tool_call_chunk.id is not None:
+                        full_tool_calls[idx].id += tool_call_chunk.id
+
+                    if tool_call_chunk.function is not None:
+                        if tool_call_chunk.function.name is not None:
+                            full_tool_calls[idx].name += tool_call_chunk.function.name
+                        if tool_call_chunk.function.arguments is not None:
+                            full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
+
+        if chunk and chunk.usage:
+            prompt_tokens = chunk.usage.prompt_tokens
+
+        if finish_reason is None:
+            raise ValueError("No stop reason found")
+
+        if choice and choice.finish_reason is CompletionsFinishReason.TOOL_CALLS:
+            finish_reason = "function_calls"
+
+        content: Union[str, List[FunctionCall]]
+
+        if len(content_deltas) > 0:
+            content = "".join(content_deltas)
+            if chunk and chunk.usage:
+                completion_tokens = chunk.usage.completion_tokens
+            else:
+                completion_tokens = 0
+        else:
+            content = list(full_tool_calls.values())
+
+        usage = RequestUsage(
+            completion_tokens=completion_tokens,
+            prompt_tokens=prompt_tokens,
+        )
+
+        result = CreateResult(
+            finish_reason=normalize_stop_reason(finish_reason),
+            content=content,
+            usage=usage,
+            cached=False,
+        )
+
+        self.add_usage(usage)
+
+        yield result
+
+    def actual_usage(self) -> RequestUsage:
+        return self._actual_usage
+
+    def total_usage(self) -> RequestUsage:
+        return self._total_usage
+
+    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+        # Token counting is not implemented for this client yet; 0 is a placeholder.
+        return 0
+
+    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+        # Token counting is not implemented for this client yet; 0 is a placeholder.
+        return 0
+
+    @property
+    def capabilities(self) -> ModelInfo:  # type: ignore
+        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
+        return self._model_info
+
+    @property
+    def model_info(self) -> ModelInfo:
+        return self._model_info
+
+    def __del__(self) -> None:
+        # TODO: This is a hack to close the open client
+        try:
+            asyncio.get_running_loop().create_task(self._client.close())
+        except RuntimeError:
+            asyncio.run(self._client.close())
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py
new file mode 100644
index 00000000000..492f868fc20
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py
@@ -0,0 +1,38 @@
+from typing import TypedDict, Union, Optional, List, Dict, Any
+from azure.ai.inference.models import (
+    JsonSchemaFormat,
+    ChatCompletionsToolDefinition,
+    ChatCompletionsToolChoicePreset,
+    ChatCompletionsNamedToolChoice,
+)
+
+from azure.core.credentials import AzureKeyCredential
+from azure.core.credentials_async import AsyncTokenCredential
+from autogen_core.models import ModelInfo
+
+GITHUB_MODELS_ENDPOINT = "https://models.inference.ai.azure.com"
+
+
+class AzureAIClientArguments(TypedDict, total=False):
+    endpoint: str
+    credential: Union[AzureKeyCredential, AsyncTokenCredential]
+    model_info: ModelInfo
+
+
+class AzureAICreateArguments(TypedDict, total=False):
+    frequency_penalty: Optional[float]
+    presence_penalty: Optional[float]
+    temperature: Optional[float]
+    top_p: Optional[float]
+    max_tokens: Optional[int]
+    response_format: Optional[Union[str, JsonSchemaFormat]]
+    stop: Optional[List[str]]
+    tools: Optional[List[ChatCompletionsToolDefinition]]
+    tool_choice: Optional[Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]]
+    seed: Optional[int]
+    model: Optional[str]
+    model_extras: Optional[Dict[str, Any]]
+
+
+class AzureAIChatCompletionClientConfig(AzureAIClientArguments, AzureAICreateArguments):
+    pass
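+
+# A minimal usage sketch (illustrative only; the key and model name are placeholders):
+#
+#     config: AzureAIChatCompletionClientConfig = {
+#         "endpoint": GITHUB_MODELS_ENDPOINT,
+#         "credential": AzureKeyCredential("<token>"),
+#         "model": "model-name",
+#         "model_info": {"family": "unknown", "json_output": False, "function_calling": False, "vision": False},
+#     }
+#
+#     client = AzureAIChatCompletionClient(**config)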
diff --git a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py
new file mode 100644
index 00000000000..e18d5ea2280
--- /dev/null
+++ b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py
@@ -0,0 +1,174 @@
+import asyncio
+from datetime import datetime
+from typing import AsyncGenerator, Any
+
+import pytest
+from azure.ai.inference.aio import (
+    ChatCompletionsClient,
+)
+from azure.ai.inference.models import (
+    ChatChoice,
+    ChatResponseMessage,
+    CompletionsUsage,
+)
+from azure.ai.inference.models import (
+    ChatCompletions,
+    StreamingChatCompletionsUpdate,
+    StreamingChatChoiceUpdate,
+    StreamingChatResponseMessageUpdate,
+)
+from azure.core.credentials import AzureKeyCredential
+
+from autogen_core import CancellationToken
+from autogen_core.models import UserMessage
+from autogen_ext.models.azure import AzureAIChatCompletionClient
+
+
+async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
+    mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
+
+    mock_chunks = [
+        StreamingChatChoiceUpdate(
+            index=0,
+            finish_reason="stop",
+            delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
+        )
+        for chunk_content in mock_chunks_content
+    ]
+
+    for mock_chunk in mock_chunks:
+        await asyncio.sleep(0.1)
+        yield StreamingChatCompletionsUpdate(
+            id="id",
+            choices=[mock_chunk],
+            created=datetime.now(),
+            model="model",
+            usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        )
+
+
+async def _mock_create(
+    *args: Any, **kwargs: Any
+) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
+    stream = kwargs.get("stream", False)
+
+    if not stream:
+        await asyncio.sleep(0.1)
+        return ChatCompletions(
+            id="id",
+            created=datetime.now(),
+            model="model",
+            choices=[
+                ChatChoice(
+                    index=0, finish_reason="stop", message=ChatResponseMessage(content="Hello", role="assistant")
+                )
+            ],
+            usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        )
+    else:
+        return _mock_create_stream(*args, **kwargs)
+
+
+@pytest.mark.asyncio
+async def test_azure_ai_chat_completion_client() -> None:
+    client = AzureAIChatCompletionClient(
+        endpoint="endpoint",
+        credential=AzureKeyCredential("api_key"),
+        model_info={
+            "family": "unknown",
+            "json_output": False,
+            "function_calling": False,
+            "vision": False,
+        },
+        model="model",
+    )
+    assert client
+
+
+@pytest.mark.asyncio
+async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
+    client = AzureAIChatCompletionClient(
+        endpoint="endpoint",
+        credential=AzureKeyCredential("api_key"),
+        model_info={
+            "family": "unknown",
+            "json_output": False,
+            "function_calling": False,
+            "vision": False,
+        },
+    )
+    result = await client.create(messages=[UserMessage(content="Hello", source="user")])
+    assert result.content == "Hello"
+
+
+@pytest.mark.asyncio
+async def test_azure_ai_chat_completion_client_create_stream(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
+    chunks = []
+    client = AzureAIChatCompletionClient(
+        endpoint="endpoint",
+        credential=AzureKeyCredential("api_key"),
+        model_info={
+            "family": "unknown",
+            "json_output": False,
+            "function_calling": False,
+            "vision": False,
+        },
+    )
+    async for chunk in client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
+        chunks.append(chunk)
+
+    assert chunks[0] == "Hello"
+    assert chunks[1] == " Another Hello"
+    assert chunks[2] == " Yet Another Hello"
"json_output": False, + "function_calling": False, + "vision": False, + }, + ) + task = asyncio.create_task( + client.create(messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token) + ) + cancellation_token.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + +@pytest.mark.asyncio +async def test_azure_ai_chat_completion_client_create_stream_cancel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create) + cancellation_token = CancellationToken() + client = AzureAIChatCompletionClient( + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + model_info={ + "family": "unknown", + "json_output": False, + "function_calling": False, + "vision": False, + }, + ) + stream = client.create_stream( + messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token + ) + cancellation_token.cancel() + with pytest.raises(asyncio.CancelledError): + async for _ in stream: + pass