Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(autogen2): basic tool calling #1216

Merged
merged 4 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from importlib import import_module

import autogen
from openinference.instrumentation.autogen import AutogenInstrumentor
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor

from openinference.instrumentation.autogen import AutogenInstrumentor


def main():
trace_provider = TracerProvider()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import json

import autogen
from autogen import ConversableAgent
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Context, Status, StatusCode
from opentelemetry.trace import Status, StatusCode

from autogen import ConversableAgent


class AutogenInstrumentor:
def __init__(self):
self.tracer = trace.get_tracer(__name__)
self._original_generate = None
self._original_initiate_chat = None
self._original_execute_function = None

def _safe_json_dumps(self, obj):
try:
Expand All @@ -22,6 +22,7 @@ def _safe_json_dumps(self, obj):
def instrument(self):
self._original_generate = ConversableAgent.generate_reply
self._original_initiate_chat = ConversableAgent.initiate_chat
self._original_execute_function = ConversableAgent.execute_function
instrumentor = self

def wrapped_generate(self, messages=None, sender=None, **kwargs):
Expand Down Expand Up @@ -95,9 +96,75 @@ def wrapped_initiate_chat(self, recipient, *args, **kwargs):
span.record_exception(e)
raise

def wrapped_execute_function(self, func_call, call_id=None, verbose=False):
try:
current_context = trace.get_current_span().get_span_context()

# Handle both dictionary and string inputs
if isinstance(func_call, str):
function_name = func_call
func_call = {"name": function_name}
else:
function_name = func_call.get("name", "unknown")

with instrumentor.tracer.start_as_current_span(
f"{function_name}",
context=trace.set_span_in_context(trace.get_current_span()),
links=[trace.Link(current_context)],
) as span:
span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, "TOOL")
span.set_attribute(SpanAttributes.TOOL_NAME, function_name)

# Record input
span.set_attribute(
SpanAttributes.INPUT_VALUE, instrumentor._safe_json_dumps(func_call)
)
span.set_attribute(SpanAttributes.INPUT_MIME_TYPE, "application/json")

# Record tool-specific attributes
if hasattr(self, "_function_map") and function_name in self._function_map:
func = self._function_map[function_name]
if hasattr(func, "__annotations__"):
span.set_attribute(
SpanAttributes.TOOL_PARAMETERS,
instrumentor._safe_json_dumps(func.__annotations__),
)

# Record function call details
if isinstance(func_call, dict):
# Record function arguments
if "arguments" in func_call:
span.set_attribute(
SpanAttributes.TOOL_CALL_FUNCTION_ARGUMENTS,
instrumentor._safe_json_dumps(func_call["arguments"]),
)

# Record function name
span.set_attribute(SpanAttributes.TOOL_CALL_FUNCTION_NAME, function_name)

# Execute function
result = instrumentor._original_execute_function(
self, func_call, call_id=call_id, verbose=verbose
)

# Record output
span.set_attribute(
SpanAttributes.OUTPUT_VALUE, instrumentor._safe_json_dumps(result)
)
span.set_attribute(SpanAttributes.OUTPUT_MIME_TYPE, "application/json")

return result

except Exception as e:
if "span" in locals():
span.set_status(Status(StatusCode.ERROR))
span.record_exception(e)
raise

# Replace methods
ConversableAgent.generate_reply = wrapped_generate
ConversableAgent.initiate_chat = wrapped_initiate_chat
ConversableAgent.execute_function = wrapped_execute_function

return self

Expand All @@ -106,8 +173,10 @@ def uninstrument(self):
if self._original_generate and self._original_initiate_chat:
ConversableAgent.generate_reply = self._original_generate
ConversableAgent.initiate_chat = self._original_initiate_chat
ConversableAgent.execute_function = self._original_execute_function
self._original_generate = None
self._original_initiate_chat = None
self._original_execute_function = None
return self


Expand Down
Loading