Skip to content

Commit

Permalink
set up and pass ci for smolagents instrumentation and add examples (#1)
Browse files Browse the repository at this point in the history
* Add examples and pass CI
  • Loading branch information
axiomofjoy authored Jan 11, 2025
1 parent 62aac6d commit 8c71f4a
Show file tree
Hide file tree
Showing 13 changed files with 347 additions and 10,980 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from io import BytesIO

import requests
from PIL import Image
from smolagents import CodeAgent, GradioUI, HfApiModel, Tool
from smolagents.default_tools import VisitWebpageTool


class GetCatImageTool(Tool):
name = "get_cat_image"
description = "Get a cat image"
inputs = {}
output_type = "image"

def __init__(self):
super().__init__()
self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"

def forward(self):
response = requests.get(self.url)

return Image.open(BytesIO(response.content))


get_cat_image = GetCatImageTool()

agent = CodeAgent(
tools=[get_cat_image, VisitWebpageTool()],
model=HfApiModel(),
additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
use_e2b_executor=True,
)

agent.run(
"Return me an image of a cat. Directly use the image provided in your state.",
additional_args={"cat_image": get_cat_image()},
)

GradioUI(agent).launch()
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, ManagedAgent, ToolCallingAgent

from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

endpoint = "http://0.0.0.0:6006/v1/traces"
trace_provider = TracerProvider()
Expand All @@ -13,16 +12,24 @@
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
SmolagentsInstrumentor()._instrument(tracer_provider=trace_provider)

from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, ManagedAgent, ToolCallingAgent

agent = ToolCallingAgent(tools=[DuckDuckGoSearchTool()], model=HfApiModel(), max_steps=3)

managed_agent = ManagedAgent(
agent=agent,
name="managed_agent",
description="This is an agent that can do web search. When solving a task, ask him directly first, he gives good answers. Then you can double check."
description=(
"This is an agent that can do web search. "
"When solving a task, ask him directly first, he gives good answers. "
"Then you can double check."
),
)

manager_agent = CodeAgent(tools=[DuckDuckGoSearchTool()], model=HfApiModel(), managed_agents=[managed_agent])
manager_agent = CodeAgent(
tools=[DuckDuckGoSearchTool()], model=HfApiModel(), managed_agents=[managed_agent]
)

manager_agent.run("How many seconds would it take for a leopard at full speed to run through Pont des Arts? ASK YOUR MANAGED AGENT FOR LEOPARD SPEED FIRST")
manager_agent.run(
"How many seconds would it take for a leopard at full speed to run through Pont des Arts? "
"ASK YOUR MANAGED AGENT FOR LEOPARD SPEED FIRST"
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import os

import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
SimpleSpanProcessor,
)
from smolagents import CodeAgent, OpenAIServerModel, Tool

from openinference.instrumentation.smolagents import SmolagentsInstrumentor

endpoint = "http://0.0.0.0:6006/v1/traces"
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
SmolagentsInstrumentor()._instrument(tracer_provider=trace_provider)

knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
knowledge_base = knowledge_base.filter(
lambda row: row["source"].startswith("huggingface/transformers")
)

source_docs = [
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
for doc in knowledge_base
]

text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
add_start_index=True,
strip_whitespace=True,
separators=["\n\n", "\n", ".", " ", ""],
)
docs_processed = text_splitter.split_documents(source_docs)


class RetrieverTool(Tool):
name = "retriever"
description = (
"Uses semantic search to retrieve the parts of transformers documentation "
"that could be most relevant to answer your query."
)
inputs = {
"query": {
"type": "string",
"description": (
"The query to perform. "
"This should be semantically close to your target documents. "
"Use the affirmative form rather than a question."
),
}
}
output_type = "string"

def __init__(self, docs, **kwargs):
super().__init__(**kwargs)
self.retriever = BM25Retriever.from_documents(docs, k=10)

def forward(self, query: str) -> str:
assert isinstance(query, str), "Your search query must be a string"

docs = self.retriever.invoke(
query,
)
return "\nRetrieved documents:\n" + "".join(
[
f"\n\n===== Document {str(i)} =====\n" + doc.page_content
for i, doc in enumerate(docs)
]
)


retriever_tool = RetrieverTool(docs_processed)
agent = CodeAgent(
tools=[retriever_tool],
model=OpenAIServerModel(
"gpt-4o",
api_base="https://api.openai.com/v1",
api_key=os.environ["OPENAI_API_KEY"],
),
max_steps=4,
verbose=True,
)

agent_output = agent.run(
"For a transformers model training, which is slower, the forward or the backward pass?"
)

print("Final output:")
print(agent_output)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
datasets
langchain
langchain-community
opentelemetry-exporter-otlp
opentelemetry-exporter-otlp-proto-http
opentelemetry-sdk
rank_bm25
requests
sqlalchemy
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os

from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
SimpleSpanProcessor,
)
from smolagents import (
CodeAgent,
OpenAIServerModel,
tool,
)
from sqlalchemy import (
Column,
Float,
Integer,
MetaData,
String,
Table,
create_engine,
insert,
inspect,
text,
)

from openinference.instrumentation.smolagents import SmolagentsInstrumentor

endpoint = "http://0.0.0.0:6006/v1/traces"
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
SmolagentsInstrumentor()._instrument(tracer_provider=trace_provider)

engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# create city SQL table
table_name = "receipts"
receipts = Table(
table_name,
metadata_obj,
Column("receipt_id", Integer, primary_key=True),
Column("customer_name", String(16), primary_key=True),
Column("price", Float),
Column("tip", Float),
)
metadata_obj.create_all(engine)

rows = [
{"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
{"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
{"receipt_id": 3, "customer_name": "Woodrow Wilson", "price": 53.43, "tip": 5.43},
{"receipt_id": 4, "customer_name": "Margaret James", "price": 21.11, "tip": 1.00},
]
for row in rows:
stmt = insert(receipts).values(**row)
with engine.begin() as connection:
cursor = connection.execute(stmt)

inspector = inspect(engine)
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]

table_description = "Columns:\n" + "\n".join(
[f" - {name}: {col_type}" for name, col_type in columns_info]
)
print(table_description)


@tool
def sql_engine(query: str) -> str:
"""
Allows you to perform SQL queries on the table. Returns a string representation of the result.
The table is named 'receipts'. Its description is as follows:
Columns:
- receipt_id: INTEGER
- customer_name: VARCHAR(16)
- price: FLOAT
- tip: FLOAT
Args:
query: The query to perform. This should be correct SQL.
"""
output = ""
with engine.connect() as con:
rows = con.execute(text(query))
for row in rows:
output += "\n" + str(row)
return output


agent = CodeAgent(
tools=[sql_engine],
model=OpenAIServerModel(
"gpt-4o-mini",
api_base="https://api.openai.com/v1",
api_key=os.environ["OPENAI_API_KEY"],
),
)
agent.run("Can you give me the name of the client who got the most expensive receipt?")
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional

from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
SimpleSpanProcessor,
)
from smolagents import (
LiteLLMModel,
tool,
)
from smolagents.agents import ToolCallingAgent

from openinference.instrumentation.smolagents import SmolagentsInstrumentor

endpoint = "http://0.0.0.0:6006/v1/traces"
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
SmolagentsInstrumentor()._instrument(tracer_provider=trace_provider)

# Choose which LLM engine to use!
# model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")

# For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
model = LiteLLMModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
Get weather in the next days at given location.
Secretly this tool does not care about the location, it hates the weather everywhere.
Args:
location: the location
celsius: the temperature
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)

print(agent.run("What's the weather like in Paris?"))
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
TraceConfig,
)
from openinference.instrumentation.smolagents._wrappers import (
_ModelWrapper,
_RunWrapper,
_StepWrapper,
_ModelWrapper,
_ToolCallWrapper,
)
from openinference.instrumentation.smolagents.version import __version__
Expand All @@ -28,7 +28,7 @@ class SmolagentsInstrumentor(BaseInstrumentor): # type: ignore
"_original_run",
"_original_step",
"_original_tool_call",
"_original_model_call",
"_original_model",
"_tracer",
)

Expand Down Expand Up @@ -64,14 +64,17 @@ def _instrument(self, **kwargs: Any) -> None:
)

step_wrapper_tool_calling = _StepWrapper(tracer=self._tracer)
self._original_step = getattr(import_module("smolagents.agents").ToolCallingAgent, "step", None)
self._original_step = getattr(
import_module("smolagents.agents").ToolCallingAgent, "step", None
)
wrap_function_wrapper(
module="smolagents",
name="ToolCallingAgent.step",
wrapper=step_wrapper_tool_calling,
)

from smolagents import Model

model_subclasses = Model.__subclasses__()
for model_subclass in model_subclasses:
model_subclass_wrapper = _ModelWrapper(tracer=self._tracer)
Expand Down
Loading

0 comments on commit 8c71f4a

Please sign in to comment.