Skip to content

Commit

Permalink
add smolagents tests for llm (#2)
Browse files Browse the repository at this point in the history
* initial test for llm spans

* llm input messages

* llm tool definitions

* llm tool calls

* record tests

* nix encoder
  • Loading branch information
axiomofjoy authored Jan 13, 2025
1 parent 8c71f4a commit 2784a6b
Show file tree
Hide file tree
Showing 13 changed files with 551 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
SmolagentsInstrumentor()._instrument(tracer_provider=trace_provider)

from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
SmolagentsInstrumentor()._instrument(tracer_provider=trace_provider)


agent = ToolCallingAgent(tools=[DuckDuckGoSearchTool()], model=HfApiModel(), max_steps=3)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os

from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
SimpleSpanProcessor,
)
from smolagents import OpenAIServerModel

from openinference.instrumentation.smolagents import SmolagentsInstrumentor

endpoint = "http://0.0.0.0:6006/v1/traces"
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)

model = OpenAIServerModel(
model_id="gpt-4o", api_key=os.environ["OPENAI_API_KEY"], api_base="https://api.openai.com/v1"
)
output = model(messages=[{"role": "user", "content": "hello world"}])
print(output)
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os

from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
SimpleSpanProcessor,
)
from smolagents import OpenAIServerModel
from smolagents.tools import Tool

from openinference.instrumentation.smolagents import SmolagentsInstrumentor

endpoint = "http://0.0.0.0:6006/v1/traces"
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)


class GetWeatherTool(Tool):
name = "get_weather"
description = "Get the weather for a given city"
inputs = {"location": {"type": "string", "description": "The city to get the weather for"}}
output_type = "string"

def forward(self, location: str) -> str:
return "sunny"


model = OpenAIServerModel(
model_id="gpt-4o", api_key=os.environ["OPENAI_API_KEY"], api_base="https://api.openai.com/v1"
)
output_message = model(
messages=[
{
"role": "user",
"content": "What is the weather in Paris?",
}
],
tools_to_call_from=[GetWeatherTool()],
)
print(output_message)
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
SmolagentsInstrumentor()._instrument(tracer_provider=trace_provider)

knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
knowledge_base = knowledge_base.filter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
SmolagentsInstrumentor()._instrument(tracer_provider=trace_provider)

engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
SmolagentsInstrumentor()._instrument(tracer_provider=trace_provider)
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)

# Choose which LLM engine to use!
# model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,12 @@ dependencies = [

[project.optional-dependencies]
instruments = [
"smolagents >= 1.1.0",
"smolagents>=1.2.2",
]
test = [
"smolagents == 1.1.0",
"smolagents>=1.2.2",
"opentelemetry-sdk",
"responses",
"vcrpy",
"pytest-recording",
]

[project.urls]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from openinference.instrumentation.smolagents.version import __version__

_instruments = ("smolagents >= 1.1.0",)
_instruments = ("smolagents >= 1.2.2",)

logger = logging.getLogger(__name__)

Expand All @@ -28,7 +28,7 @@ class SmolagentsInstrumentor(BaseInstrumentor): # type: ignore
"_original_run",
"_original_step",
"_original_tool_call",
"_original_model",
"_original_model_calls",
"_tracer",
)

Expand Down Expand Up @@ -76,8 +76,10 @@ def _instrument(self, **kwargs: Any) -> None:
from smolagents import Model

model_subclasses = Model.__subclasses__()
self._original_model_calls = {}
for model_subclass in model_subclasses:
model_subclass_wrapper = _ModelWrapper(tracer=self._tracer)
self._original_model_calls[model_subclass] = getattr(model_subclass, "__call__")
wrap_function_wrapper(
module="smolagents",
name=model_subclass.__name__ + ".__call__",
Expand All @@ -103,10 +105,11 @@ def _uninstrument(self, **kwargs: Any) -> None:
smolagents_module.MultiStepAgent.step = self._original_step
self._original_step = None

if self._original_model_generate is not None:
if self._original_model_calls is not None:
smolagents_module = import_module("smolagents.models")
smolagents_module.MultimodelAgent.model = self._original_model_generate
self._original_model = None
for model_subclass, original_model_call in self._original_model_calls.items():
setattr(model_subclass, "__call__", original_model_call)
self._original_model_calls = None

if self._original_tool_call is not None:
tool_usage_module = import_module("smolagents.tools")
Expand Down
Loading

0 comments on commit 2784a6b

Please sign in to comment.