diff --git a/python/instrumentation/openinference-instrumentation-smolagents/tests/openinference/instrumentation/smolagents/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-smolagents/tests/openinference/instrumentation/smolagents/test_instrumentor.py index a7ef24a54..d12f86cd2 100644 --- a/python/instrumentation/openinference-instrumentation-smolagents/tests/openinference/instrumentation/smolagents/test_instrumentor.py +++ b/python/instrumentation/openinference-instrumentation-smolagents/tests/openinference/instrumentation/smolagents/test_instrumentor.py @@ -11,6 +11,8 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from smolagents import OpenAIServerModel from smolagents.tools import Tool +from smolagents.agents import CodeAgent, ManagedAgent, ToolCallingAgent +from smoalgents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition from openinference.instrumentation.smolagents import SmolagentsInstrumentor from openinference.semconv.trace import ( @@ -254,6 +256,134 @@ def forward(self, location: str) -> str: assert json.loads(tool_call_arguments_json) == {"location": "Paris"} assert not attributes +class TestRun: + @pytest.mark.vcr( + decode_compressed_response=True, + before_record_request=remove_all_vcr_request_headers, + before_record_response=remove_all_vcr_response_headers, + ) + def test_multiagents(self): + class FakeModelMultiagentsManagerAgent: + def __call__( + self, + messages, + stop_sequences=None, + grammar=None, + tools_to_call_from=None, + ): + if tools_to_call_from is not None: + if len(messages) < 3: + return ChatMessage( + role="assistant", + content="", + tool_calls=[ + ChatMessageToolCall( + id="call_0", + type="function", + function=ChatMessageToolCallDefinition( + name="search_agent", + arguments="Who is the current US president?", + ), + ) + ], + ) + else: + assert "Report on the current US president" in str(messages) + return ChatMessage( + role="assistant", + content="", + tool_calls=[ + ChatMessageToolCall( + id="call_0", + type="function", + function=ChatMessageToolCallDefinition( + name="final_answer", arguments="Final report." + ), + ) + ], + ) + else: + if len(messages) < 3: + return ChatMessage( + role="assistant", + content=""" +Thought: Let's call our search agent. +Code: +```py +result = search_agent("Who is the current US president?") +``` +""", + ) + else: + assert "Report on the current US president" in str(messages) + return ChatMessage( + role="assistant", + content=""" +Thought: Let's return the report. +Code: +```py +final_answer("Final report.") +``` +""", + ) + + manager_model = FakeModelMultiagentsManagerAgent() + + class FakeModelMultiagentsManagedAgent: + def __call__( + self, + messages, + tools_to_call_from=None, + stop_sequences=None, + grammar=None, + ): + return ChatMessage( + role="assistant", + content="", + tool_calls=[ + ChatMessageToolCall( + id="call_0", + type="function", + function=ChatMessageToolCallDefinition( + name="final_answer", + arguments="Report on the current US president", + ), + ) + ], + ) + + managed_model = FakeModelMultiagentsManagedAgent() + + web_agent = ToolCallingAgent( + tools=[], + model=managed_model, + max_steps=10, + ) + + managed_web_agent = ManagedAgent( + agent=web_agent, + name="search_agent", + description="Runs web searches for you. Give it your request as an argument. Make the request as detailed as needed, you can ask for thorough reports", + ) + + manager_code_agent = CodeAgent( + tools=[], + model=manager_model, + managed_agents=[managed_web_agent], + additional_authorized_imports=["time", "numpy", "pandas"], + ) + + report = manager_code_agent.run("Fake question.") + assert report == "Final report." + + manager_toolcalling_agent = ToolCallingAgent( + tools=[], + model=manager_model, + managed_agents=[managed_web_agent], + ) + + report = manager_toolcalling_agent.run("Fake question.") + assert report == "Final report." # message attributes MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT