Skip to content

Commit

Permalink
Merge pull request #743 from Mirascope/any-type-tool-output
Browse files Browse the repository at this point in the history
Any type tool output
  • Loading branch information
willbakst authored Dec 14, 2024
2 parents f60ac2a + fb459ac commit ff35c81
Show file tree
Hide file tree
Showing 16 changed files with 219 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
- id: ruff-format
files: "^mirascope|^tests|^examples|^docs"
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.389
rev: v1.1.390
hooks:
- id: pyright
- repo: local
Expand Down
3 changes: 2 additions & 1 deletion mirascope/core/anthropic/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from pydantic import SerializeAsAny, computed_field

from ..base import BaseCallResponse
from ..base import BaseCallResponse, transform_tool_outputs
from ._utils import calculate_cost
from .call_params import AnthropicCallParams
from .dynamic_config import AnthropicDynamicConfig, AsyncAnthropicDynamicConfig
Expand Down Expand Up @@ -137,6 +137,7 @@ def tool(self) -> AnthropicTool | None:
return None

@classmethod
@transform_tool_outputs
def tool_message_params(
cls, tools_and_outputs: list[tuple[AnthropicTool, str]]
) -> list[MessageParam]:
Expand Down
3 changes: 2 additions & 1 deletion mirascope/core/azure/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from pydantic import SerializeAsAny, SkipValidation, computed_field

from ..base import BaseCallResponse
from ..base import BaseCallResponse, transform_tool_outputs
from ._utils import calculate_cost
from .call_params import AzureCallParams
from .dynamic_config import AsyncAzureDynamicConfig, AzureDynamicConfig
Expand Down Expand Up @@ -158,6 +158,7 @@ def _get_tool_message(cls, tool: AzureTool, output: str) -> ToolMessage:
return tool_message

@classmethod
@transform_tool_outputs
def tool_message_params(
cls, tools_and_outputs: list[tuple[AzureTool, str]]
) -> list[ToolMessage]:
Expand Down
3 changes: 2 additions & 1 deletion mirascope/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ._utils import BaseType
from .call_kwargs import BaseCallKwargs
from .call_params import BaseCallParams, CommonCallParams
from .call_response import BaseCallResponse
from .call_response import BaseCallResponse, transform_tool_outputs
from .call_response_chunk import BaseCallResponseChunk
from .dynamic_config import BaseDynamicConfig
from .from_call_args import FromCallArgs
Expand Down Expand Up @@ -58,6 +58,7 @@
"TextPart",
"ToolConfig",
"toolkit_tool",
"transform_tool_outputs",
"_partial",
"_utils",
]
68 changes: 66 additions & 2 deletions mirascope/core/base/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

from __future__ import annotations

import base64
import json
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Generic, TypeVar
from collections.abc import Callable
from functools import wraps
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar

from pydantic import (
BaseModel,
Expand All @@ -14,6 +18,7 @@
field_serializer,
)

from ._utils import BaseType
from .call_kwargs import BaseCallKwargs
from .call_params import BaseCallParams
from .dynamic_config import BaseDynamicConfig
Expand All @@ -27,6 +32,64 @@
_MessageParamT = TypeVar("_MessageParamT", bound=Any)
_CallParamsT = TypeVar("_CallParamsT", bound=BaseCallParams)
_UserMessageParamT = TypeVar("_UserMessageParamT", bound=Any)
_BaseToolT = TypeVar("_BaseToolT", bound=BaseTool)
_BaseCallResponseT = TypeVar("_BaseCallResponseT", bound="BaseCallResponse")


JsonableType: TypeAlias = (
str
| int
| float
| bool
| bytes
| list["JsonableType"]
| set["JsonableType"]
| tuple["JsonableType", ...]
| dict[str, "JsonableType"]
| BaseModel
)


def transform_tool_outputs(
fn: Callable[[type[_BaseCallResponseT], list[tuple[_BaseToolT, str]]], list[Any]],
) -> Callable[
[type[_BaseCallResponseT], list[tuple[_BaseToolT, JsonableType]]],
list[Any],
]:
@wraps(fn)
def wrapper(
cls: type[_BaseCallResponseT],
tools_and_outputs: list[tuple[_BaseToolT, JsonableType]],
) -> list[Any]:
def recursive_serializer(value: JsonableType) -> BaseType:
if isinstance(value, str):
return value
if isinstance(value, int | float | bool):
return value # Don't serialize primitives yet
if isinstance(value, bytes):
return base64.b64encode(value).decode("utf-8")
if isinstance(value, BaseModel):
return value.model_dump()
if isinstance(value, list | set | tuple):
return [recursive_serializer(item) for item in value]
if isinstance(value, dict):
return {k: recursive_serializer(v) for k, v in value.items()}
raise TypeError(f"Unsupported type for serialization: {type(value)}")

transformed_tools_and_outputs = [
(
tool,
output.model_dump_json()
if isinstance(output, BaseModel)
else str(recursive_serializer(output))
if isinstance(output, str | bytes)
else json.dumps(recursive_serializer(output)),
)
for tool, output in tools_and_outputs
]
return fn(cls, transformed_tools_and_outputs)

return wrapper


class BaseCallResponse(
Expand Down Expand Up @@ -180,8 +243,9 @@ def tool(self) -> _BaseToolT | None:

@classmethod
@abstractmethod
@transform_tool_outputs
def tool_message_params(
cls, tools_and_outputs: list[tuple[_BaseToolT, Any]]
cls, tools_and_outputs: list[tuple[_BaseToolT, str]]
) -> list[Any]:
"""Returns the tool message parameters for tool call results.
Expand Down
4 changes: 2 additions & 2 deletions mirascope/core/base/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from .call_kwargs import BaseCallKwargs
from .call_params import BaseCallParams
from .call_response import BaseCallResponse
from .call_response import BaseCallResponse, JsonableType
from .call_response_chunk import BaseCallResponseChunk
from .dynamic_config import BaseDynamicConfig
from .messages import Messages
Expand Down Expand Up @@ -217,7 +217,7 @@ def _construct_message_param(
...

def tool_message_params(
self, tools_and_outputs: list[tuple[_BaseToolT, str]]
self, tools_and_outputs: list[tuple[_BaseToolT, JsonableType]]
) -> list[_ToolMessageParamT]:
"""Returns the tool message parameters for tool call results.
Expand Down
3 changes: 2 additions & 1 deletion mirascope/core/bedrock/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ConverseResponseTypeDef as AsyncConverseResponseTypeDef,
)

from ..base import BaseCallResponse
from ..base import BaseCallResponse, transform_tool_outputs
from ._call_kwargs import BedrockCallKwargs
from ._types import (
AssistantMessageTypeDef,
Expand Down Expand Up @@ -178,6 +178,7 @@ def tool(self) -> BedrockTool | None:
return None

@classmethod
@transform_tool_outputs
def tool_message_params(
cls, tools_and_outputs: list[tuple[BedrockTool, str]]
) -> list[ToolResultBlockMessageTypeDef]:
Expand Down
3 changes: 2 additions & 1 deletion mirascope/core/cohere/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from pydantic import SkipValidation, computed_field

from ..base import BaseCallResponse
from ..base import BaseCallResponse, transform_tool_outputs
from ._utils import calculate_cost
from .call_params import CohereCallParams
from .dynamic_config import AsyncCohereDynamicConfig, CohereDynamicConfig
Expand Down Expand Up @@ -146,6 +146,7 @@ def tool(self) -> CohereTool | None:
return None

@classmethod
@transform_tool_outputs
def tool_message_params(
cls,
tools_and_outputs: list[tuple[CohereTool, str]],
Expand Down
5 changes: 3 additions & 2 deletions mirascope/core/gemini/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from pydantic import computed_field

from ..base import BaseCallResponse
from ..base import BaseCallResponse, transform_tool_outputs
from ._utils import calculate_cost
from .call_params import GeminiCallParams
from .dynamic_config import GeminiDynamicConfig
Expand Down Expand Up @@ -156,8 +156,9 @@ def tool(self) -> GeminiTool | None:
return None

@classmethod
@transform_tool_outputs
def tool_message_params(
cls, tools_and_outputs: list[tuple[GeminiTool, object]]
cls, tools_and_outputs: list[tuple[GeminiTool, str]]
) -> list[ContentDict]:
"""Returns the tool message parameters for tool call results.
Expand Down
3 changes: 2 additions & 1 deletion mirascope/core/groq/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from groq.types.completion_usage import CompletionUsage
from pydantic import SerializeAsAny, computed_field

from ..base import BaseCallResponse
from ..base import BaseCallResponse, transform_tool_outputs
from ._utils import calculate_cost
from .call_params import GroqCallParams
from .dynamic_config import AsyncGroqDynamicConfig, GroqDynamicConfig
Expand Down Expand Up @@ -140,6 +140,7 @@ def tool(self) -> GroqTool | None:
return None

@classmethod
@transform_tool_outputs
def tool_message_params(
cls, tools_and_outputs: list[tuple[GroqTool, str]]
) -> list[ChatCompletionToolMessageParam]:
Expand Down
3 changes: 2 additions & 1 deletion mirascope/core/mistral/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from pydantic import computed_field

from ..base import BaseCallResponse
from ..base import BaseCallResponse, transform_tool_outputs
from ._utils import calculate_cost
from .call_params import MistralCallParams
from .dynamic_config import MistralDynamicConfig
Expand Down Expand Up @@ -147,6 +147,7 @@ def tool(self) -> MistralTool | None:
return None

@classmethod
@transform_tool_outputs
def tool_message_params(
cls, tools_and_outputs: list[tuple[MistralTool, str]]
) -> list[ToolMessage]:
Expand Down
3 changes: 2 additions & 1 deletion mirascope/core/openai/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from openai.types.completion_usage import CompletionUsage
from pydantic import SerializeAsAny, SkipValidation, computed_field

from ..base import BaseCallResponse
from ..base import BaseCallResponse, transform_tool_outputs
from ._utils import calculate_cost
from .call_params import OpenAICallParams
from .dynamic_config import OpenAIDynamicConfig
Expand Down Expand Up @@ -169,6 +169,7 @@ def tool(self) -> OpenAITool | None:
return None

@classmethod
@transform_tool_outputs
def tool_message_params(
cls, tools_and_outputs: list[tuple[OpenAITool, str]]
) -> list[ChatCompletionToolMessageParam]:
Expand Down
5 changes: 3 additions & 2 deletions mirascope/core/vertex/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import computed_field
from vertexai.generative_models import Content, GenerationResponse, Part, Tool

from ..base import BaseCallResponse
from ..base import BaseCallResponse, transform_tool_outputs
from ._utils import calculate_cost
from .call_params import VertexCallParams
from .dynamic_config import VertexDynamicConfig
Expand Down Expand Up @@ -146,8 +146,9 @@ def tool(self) -> VertexTool | None:
return None

@classmethod
@transform_tool_outputs
def tool_message_params(
cls, tools_and_outputs: list[tuple[VertexTool, object]]
cls, tools_and_outputs: list[tuple[VertexTool, str]]
) -> list[Content]:
"""Returns the tool message parameters for tool call results.
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "mirascope"
version = "1.12.1"
version = "1.13.0"
description = "LLM abstractions that aren't obstructions"
readme = "README.md"
license = { file = "LICENSE" }
Expand Down Expand Up @@ -89,7 +89,7 @@ mcp = ["mcp>=1.0.0"]
dev-dependencies = [
"ruff>=0.6.1",
"pytest>=8.3.2",
"pyright>=1.1.389",
"pyright>=1.1.390",
"pytest-asyncio>=0.23.8",
"pytest-cov>=5.0.0",
"pre-commit>=3.8.0",
Expand Down
Loading

0 comments on commit ff35c81

Please sign in to comment.