diff --git a/weave/integrations/smolagents/__init__.py b/weave/integrations/smolagents/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/weave/integrations/smolagents/smolagents_sdk.py b/weave/integrations/smolagents/smolagents_sdk.py new file mode 100644 index 000000000000..1cb1b7594db5 --- /dev/null +++ b/weave/integrations/smolagents/smolagents_sdk.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING, Callable, Any + +import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings +from weave.trace.op_extensions.accumulator import add_accumulator +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher +from weave.trace.serialize import dictify +from smolagents import PythonInterpreterTool + +_smolagents_patcher: MultiPatcher | None = None + + +def smolagents_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + if "self" in inputs: + inputs["self"] = dictify(inputs["self"]) + return inputs + + +def smolagents_wrapper(name: str) -> Callable[[Callable], Callable]: + def wrapper(fn: Callable) -> Callable: + op = weave.op(fn, postprocess_inputs=smolagents_postprocess_inputs, postprocess_output=lambda x: dictify(x)) + op.name = name # type: ignore + return op + + return wrapper + + +def get_smolagents_patcher(): + multi_step_agent_patcher = [ + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "MultiStepAgent.run", + smolagents_wrapper("smolagents.MultiStepAgent.run"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "MultiStepAgent.run", + smolagents_wrapper("smolagents.MultiStepAgent.direct_run"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "MultiStepAgent.run", + smolagents_wrapper("smolagents.MultiStepAgent.execute_tool_call"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "MultiStepAgent.run", + smolagents_wrapper("smolagents.MultiStepAgent.extract_action"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "MultiStepAgent.run", + smolagents_wrapper("smolagents.MultiStepAgent.planning_step"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "MultiStepAgent.run", + smolagents_wrapper("smolagents.MultiStepAgent.provide_final_answer"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "MultiStepAgent.run", + smolagents_wrapper("smolagents.MultiStepAgent.step"), + ), + ] + + additional_agents_patcher = [ + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "ToolCallingAgent.run", + smolagents_wrapper("smolagents.ToolCallingAgent.run"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "ToolCallingAgent.run", + smolagents_wrapper("smolagents.ToolCallingAgent.step"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "CodeAgent.run", + smolagents_wrapper("smolagents.CodeAgent.run"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "CodeAgent.run", + smolagents_wrapper("smolagents.CodeAgent.step"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "ManagedAgent.__call__", + smolagents_wrapper("smolagents.ManagedAgent"), + ), + ] + + models_patcher = [ + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "Model.__call__", + smolagents_wrapper("smolagents.Model"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "HfApiModel.__call__", + smolagents_wrapper("smolagents.HfApiModel"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "TransformersModel.__call__", + smolagents_wrapper("smolagents.TransformersModel"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "LiteLLMModel.__call__", + smolagents_wrapper("smolagents.LiteLLMModel"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "OpenAIServerModel.__call__", + smolagents_wrapper("smolagents.OpenAIServerModel"), + ), + ] + + tools_patcher = [ + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "Tool.__call__", + smolagents_wrapper("smolagents.Tool"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "Tool.forward", + smolagents_wrapper("smolagents.Tool.forward"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "PythonInterpreterTool.forward", + smolagents_wrapper("smolagents.PythonInterpreterTool.forward"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "FinalAnswerTool.forward", + smolagents_wrapper("smolagents.FinalAnswerTool.forward"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "UserInputTool.forward", + smolagents_wrapper("smolagents.UserInputTool.forward"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "DuckDuckGoSearchTool.forward", + smolagents_wrapper("smolagents.DuckDuckGoSearchTool.forward"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "GoogleSearchTool.forward", + smolagents_wrapper("smolagents.GoogleSearchTool.forward"), + ), + SymbolPatcher( + lambda: importlib.import_module("smolagents"), + "VisitWebpageTool.forward", + smolagents_wrapper("smolagents.VisitWebpageTool.forward"), + ), + ] + + return MultiPatcher( + [ + *multi_step_agent_patcher, + *additional_agents_patcher, + *models_patcher, + *tools_patcher, + ] + ) diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py index bc77752957c2..f8b29e2025fe 100644 --- a/weave/trace/autopatch.py +++ b/weave/trace/autopatch.py @@ -47,6 +47,7 @@ class AutopatchSettings(BaseModel): openai: IntegrationSettings = Field(default_factory=IntegrationSettings) vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings) chatnvidia: IntegrationSettings = Field(default_factory=IntegrationSettings) + # smolagents: IntegrationSettings = Field(default_factory=IntegrationSettings) @validate_call @@ -70,6 +71,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None: from weave.integrations.notdiamond.tracing import get_notdiamond_patcher from weave.integrations.openai.openai_sdk import get_openai_patcher from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher + from weave.integrations.smolagents.smolagents_sdk import get_smolagents_patcher if settings is None: settings = AutopatchSettings() @@ -87,6 +89,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None: get_notdiamond_patcher(settings.notdiamond).attempt_patch() get_vertexai_patcher(settings.vertexai).attempt_patch() get_nvidia_ai_patcher(settings.chatnvidia).attempt_patch() + get_smolagents_patcher().attempt_patch() llamaindex_patcher.attempt_patch() langchain_patcher.attempt_patch() @@ -112,6 +115,7 @@ def reset_autopatch() -> None: from weave.integrations.notdiamond.tracing import get_notdiamond_patcher from weave.integrations.openai.openai_sdk import get_openai_patcher from weave.integrations.vertexai.vertexai_sdk import get_vertexai_patcher + from weave.integrations.smolagents.smolagents_sdk import get_smolagents_patcher get_openai_patcher().undo_patch() get_mistral_patcher().undo_patch() @@ -126,6 +130,7 @@ def reset_autopatch() -> None: get_notdiamond_patcher().undo_patch() get_vertexai_patcher().undo_patch() get_nvidia_ai_patcher().undo_patch() + get_smolagents_patcher().undo_patch() llamaindex_patcher.undo_patch() langchain_patcher.undo_patch()