Skip to content

Commit

Permalink
Merge pull request #755 from Mirascope/release/v1.14
Browse files Browse the repository at this point in the history
Release/v1.14
  • Loading branch information
willbakst authored Dec 21, 2024
2 parents 23c3614 + 337fce4 commit fee724b
Show file tree
Hide file tree
Showing 54 changed files with 826 additions and 240 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
- id: ruff-format
files: "^mirascope|^tests|^examples|^docs"
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.390
rev: v1.1.391
hooks:
- id: pyright
- repo: local
Expand Down
14 changes: 6 additions & 8 deletions mirascope/core/anthropic/_utils/_setup_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
AsyncAnthropicVertex,
)
from anthropic.types import Message, MessageParam, MessageStreamEvent
from pydantic import BaseModel

from ...base import BaseMessageParam, BaseTool, _utils
from ...base._utils import AsyncCreateFn, CreateFn
Expand All @@ -36,7 +37,7 @@ def setup_call(
tools: list[type[BaseTool] | Callable] | None,
json_mode: bool,
call_params: AnthropicCallParams,
extract: bool,
response_model: type[BaseModel] | None,
stream: bool | StreamConfig,
) -> tuple[
AsyncCreateFn[Message, MessageStreamEvent],
Expand All @@ -58,7 +59,7 @@ def setup_call(
tools: list[type[BaseTool] | Callable] | None,
json_mode: bool,
call_params: AnthropicCallParams,
extract: bool,
response_model: type[BaseModel] | None,
stream: bool | StreamConfig,
) -> tuple[
CreateFn[Message, MessageStreamEvent],
Expand All @@ -85,7 +86,7 @@ def setup_call(
tools: list[type[BaseTool] | Callable] | None,
json_mode: bool,
call_params: AnthropicCallParams,
extract: bool,
response_model: type[BaseModel] | None,
stream: bool | StreamConfig,
) -> tuple[
Callable[..., Message | Awaitable[Message]],
Expand All @@ -111,17 +112,14 @@ def setup_call(
call_kwargs["system"] = messages.pop(0)["content"] # pyright: ignore [reportGeneralTypeIssues]

if json_mode:
json_mode_content = _utils.json_mode_content(
tool_types[0] if tool_types else None
)
json_mode_content = _utils.json_mode_content(response_model)
if isinstance(messages[-1]["content"], str):
messages[-1]["content"] += json_mode_content
else:
messages[-1]["content"] = list(messages[-1]["content"]) + [
{"type": "text", "text": json_mode_content}
]
call_kwargs.pop("tools", None)
elif extract:
elif response_model:
assert tool_types, "At least one tool must be provided for extraction."
call_kwargs["tool_choice"] = {"type": "tool", "name": tool_types[0]._name()}
call_kwargs |= {
Expand Down
8 changes: 5 additions & 3 deletions mirascope/core/anthropic/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
usage docs: learn/calls.md#handling-responses
"""

from functools import cached_property

from anthropic.types import (
Message,
MessageParam,
Expand Down Expand Up @@ -97,13 +99,13 @@ def cost(self) -> float | None:
return calculate_cost(self.input_tokens, self.output_tokens, self.model)

@computed_field
@property
@cached_property
def message_param(self) -> SerializeAsAny[MessageParam]:
"""Returns the assistants's response as a message parameter."""
return MessageParam(**self.response.model_dump(include={"content", "role"}))

@computed_field
@property
@cached_property
def tools(self) -> list[AnthropicTool] | None:
"""Returns any available tool calls as their `AnthropicTool` definition.
Expand All @@ -125,7 +127,7 @@ def tools(self) -> list[AnthropicTool] | None:
return extracted_tools

@computed_field
@property
@cached_property
def tool(self) -> AnthropicTool | None:
"""Returns the 0th tool for the 0th choice message.
Expand Down
30 changes: 17 additions & 13 deletions mirascope/core/azure/_utils/_setup_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,16 @@
UserMessage,
)
from azure.core.credentials import AzureKeyCredential
from pydantic import BaseModel

from ...base import BaseMessageParam, BaseTool, _utils
from ...base._utils import AsyncCreateFn, CreateFn, get_async_create_fn, get_create_fn
from ...base._utils import (
DEFAULT_TOOL_DOCSTRING,
AsyncCreateFn,
CreateFn,
get_async_create_fn,
get_create_fn,
)
from ...base.call_params import CommonCallParams
from ...base.stream_config import StreamConfig
from .._call_kwargs import AzureCallKwargs
Expand All @@ -41,7 +48,7 @@ def setup_call(
tools: list[type[BaseTool] | Callable] | None,
json_mode: bool,
call_params: AzureCallParams | CommonCallParams,
extract: bool,
response_model: type[BaseModel] | None,
stream: bool | StreamConfig,
) -> tuple[
AsyncCreateFn[ChatCompletions, StreamingChatCompletionsUpdate],
Expand All @@ -63,7 +70,7 @@ def setup_call(
tools: list[type[BaseTool] | Callable] | None,
json_mode: bool,
call_params: AzureCallParams | CommonCallParams,
extract: bool,
response_model: type[BaseModel] | None,
stream: bool | StreamConfig,
) -> tuple[
CreateFn[ChatCompletions, StreamingChatCompletionsUpdate],
Expand All @@ -85,7 +92,7 @@ def setup_call(
tools: list[type[BaseTool] | Callable] | None,
json_mode: bool,
call_params: AzureCallParams | CommonCallParams,
extract: bool,
response_model: type[BaseModel] | None,
stream: bool | StreamConfig,
) -> tuple[
CreateFn[ChatCompletions, StreamingChatCompletionsUpdate]
Expand All @@ -108,25 +115,22 @@ def setup_call(
messages = cast(list[BaseMessageParam | ChatRequestMessage], messages)
messages = convert_message_params(messages)
if json_mode:
if tool_types and tool_types[0].model_config.get("strict", False):
if response_model and response_model.model_config.get("strict", False):
call_kwargs["response_format"] = ChatCompletionsResponseFormatJSON(
{
"name": tool_types[0]._name(),
"description": tool_types[0]._description(),
"name": response_model.__name__,
"description": response_model.__doc__ or DEFAULT_TOOL_DOCSTRING,
"strict": True,
"schema": tool_types[0].model_json_schema(
"schema": response_model.model_json_schema(
schema_generator=GenerateAzureStrictToolJsonSchema
),
}
)
else:
call_kwargs["response_format"] = ChatCompletionsResponseFormatJSON()
json_mode_content = _utils.json_mode_content(
tool_types[0] if tool_types else None
).strip()
json_mode_content = _utils.json_mode_content(response_model).strip()
messages.append(UserMessage(content=json_mode_content))
call_kwargs.pop("tools", None)
elif extract:
elif response_model:
assert tool_types, "At least one tool must be provided for extraction."
if tool_types and tool_types[0].model_config.get("strict", False):
warnings.warn(
Expand Down
8 changes: 5 additions & 3 deletions mirascope/core/azure/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
usage docs: learn/calls.md#handling-responses
"""

from functools import cached_property

from azure.ai.inference.models import (
AssistantMessage,
ChatCompletions,
Expand Down Expand Up @@ -101,7 +103,7 @@ def cost(self) -> float | None:
return calculate_cost(self.input_tokens, self.output_tokens, self.model)

@computed_field
@property
@cached_property
def message_param(self) -> SerializeAsAny[AssistantMessage]:
"""Returns the assistants's response as a message parameter."""
message_param = self.response.choices[0].message
Expand All @@ -110,7 +112,7 @@ def message_param(self) -> SerializeAsAny[AssistantMessage]:
)

@computed_field
@property
@cached_property
def tools(self) -> list[AzureTool] | None:
"""Returns any available tool calls as their `AzureTool` definition.
Expand All @@ -134,7 +136,7 @@ def tools(self) -> list[AzureTool] | None:
return extracted_tools

@computed_field
@property
@cached_property
def tool(self) -> AzureTool | None:
"""Returns the 0th tool for the 0th choice message.
Expand Down
1 change: 1 addition & 0 deletions mirascope/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"CacheControlPart",
"call_factory",
"CommonCallParams",
"DocumentPart",
"FromCallArgs",
"GenerateJsonSchemaNoTitles",
"ImagePart",
Expand Down
16 changes: 16 additions & 0 deletions mirascope/core/base/_call_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ._create import create_factory
from ._extract import extract_factory
from ._extract_with_tools import extract_with_tools_factory
from ._utils import (
BaseType,
GetJsonOutput,
Expand Down Expand Up @@ -192,6 +193,20 @@ def base_call(
client=client,
call_params=call_params,
) # pyright: ignore [reportReturnType, reportCallIssue]
elif tools:
return partial(
extract_with_tools_factory(
TCallResponse=TCallResponse,
setup_call=setup_call,
get_json_output=get_json_output,
),
model=model,
tools=tools,
response_model=response_model,
output_parser=output_parser,
client=client,
call_params=call_params,
) # pyright: ignore [reportReturnType, reportCallIssue]
else:
return partial(
extract_factory(
Expand Down Expand Up @@ -228,6 +243,7 @@ def base_call(
create_factory(TCallResponse=TCallResponse, setup_call=setup_call),
model=model,
tools=tools,
response_model=None,
output_parser=output_parser,
json_mode=json_mode,
client=client,
Expand Down
11 changes: 9 additions & 2 deletions mirascope/core/base/_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from functools import wraps
from typing import Any, ParamSpec, TypeVar, cast, overload

from pydantic import BaseModel

from ._utils import (
SameSyncAndAsyncClientSetupCall,
SetupCall,
Expand Down Expand Up @@ -72,6 +74,7 @@ def decorator(
fn: Callable[_P, _BaseDynamicConfigT],
model: str,
tools: list[type[BaseTool] | Callable] | None,
response_model: type[BaseModel] | None,
output_parser: Callable[[_BaseCallResponseT], _ParsedOutputT] | None,
json_mode: bool,
client: _SameSyncAndAsyncClientT | _SyncBaseClientT | None,
Expand All @@ -83,6 +86,7 @@ def decorator(
fn: Callable[_P, Messages.Type],
model: str,
tools: list[type[BaseTool] | Callable] | None,
response_model: type[BaseModel] | None,
output_parser: Callable[[_BaseCallResponseT], _ParsedOutputT] | None,
json_mode: bool,
client: _SameSyncAndAsyncClientT | _SyncBaseClientT | None,
Expand All @@ -98,6 +102,7 @@ def decorator(
],
model: str,
tools: list[type[BaseTool] | Callable] | None,
response_model: type[BaseModel] | None,
output_parser: Callable[[_BaseCallResponseT], _ParsedOutputT] | None,
json_mode: bool,
client: _SameSyncAndAsyncClientT | _AsyncBaseClientT | None,
Expand All @@ -112,6 +117,7 @@ def decorator(
fn: Callable[_P, Awaitable[Messages.Type] | Coroutine[Any, Any, Messages.Type]],
model: str,
tools: list[type[BaseTool] | Callable] | None,
response_model: type[BaseModel] | None,
output_parser: Callable[[_BaseCallResponseT], _ParsedOutputT] | None,
json_mode: bool,
client: _SameSyncAndAsyncClientT | _AsyncBaseClientT | None,
Expand All @@ -132,6 +138,7 @@ def decorator(
| Callable[_P, Awaitable[Messages.Type] | Coroutine[Any, Any, Messages.Type]],
model: str,
tools: list[type[BaseTool] | Callable] | None,
response_model: type[BaseModel] | None,
output_parser: Callable[[_BaseCallResponseT], _ParsedOutputT] | None,
json_mode: bool,
client: _SameSyncAndAsyncClientT | _AsyncBaseClientT | _SyncBaseClientT | None,
Expand Down Expand Up @@ -174,7 +181,7 @@ async def inner_async(
tools=tools,
json_mode=json_mode,
call_params=call_params,
extract=False,
response_model=response_model,
stream=False,
)
start_time = datetime.datetime.now().timestamp() * 1000
Expand Down Expand Up @@ -218,7 +225,7 @@ def inner(
tools=tools,
json_mode=json_mode,
call_params=call_params,
extract=False,
response_model=response_model,
stream=False,
)
start_time = datetime.datetime.now().timestamp() * 1000
Expand Down
9 changes: 6 additions & 3 deletions mirascope/core/base/_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def extract_factory( # noqa: ANN202
):
"""Returns the wrapped function with the provider specific interfaces."""
create_decorator = create_factory(
TCallResponse=TCallResponse, setup_call=setup_call
TCallResponse=TCallResponse,
setup_call=setup_call,
)

@overload
Expand Down Expand Up @@ -110,10 +111,12 @@ def decorator(
]:
fn._model = model # pyright: ignore [reportFunctionMemberAccess]
fn.__mirascope_call__ = True # pyright: ignore [reportFunctionMemberAccess]
tool = setup_extract_tool(response_model, TToolType)
create_decorator_kwargs = {
"model": model,
"tools": [tool],
"tools": [setup_extract_tool(response_model, TToolType)]
if not json_mode
else None,
"response_model": response_model,
"output_parser": None,
"json_mode": json_mode,
"client": client,
Expand Down
Loading

0 comments on commit fee724b

Please sign in to comment.