Skip to content

Commit

Permalink
add agent
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy committed Jan 19, 2025
1 parent 018f2ee commit a4f5881
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 14 deletions.
50 changes: 50 additions & 0 deletions python/openinference-instrumentation/examples/tracer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,56 @@
" decorated_chain_with_context_attributes(\"input\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agents"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Context Managers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with tracer.start_as_current_span(\n",
" \"agent-span-with-plain-text-io\",\n",
" openinference_span_kind=\"agent\",\n",
") as span:\n",
" span.set_input(\"input\")\n",
" span.set_output(\"output\")\n",
" span.set_status(Status(StatusCode.OK))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Decorators"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@tracer.agent\n",
"def decorated_agent(input: str) -> str:\n",
" return \"output\"\n",
"\n",
"\n",
"decorated_agent(\"input\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,12 @@ def set_output(
self.set_attributes(get_output_value_and_mime_type(value, mime_type))


class AgentSpan(OpenInferenceSpan):
def __init__(self, wrapped: Span, config: TraceConfig) -> None:
super().__init__(wrapped, config)
self.__wrapped__.set_attributes(get_span_kind(OpenInferenceSpanKindValues.AGENT))


class ChainSpan(OpenInferenceSpan):
def __init__(self, wrapped: Span, config: TraceConfig) -> None:
super().__init__(wrapped, config)
Expand Down Expand Up @@ -768,8 +774,41 @@ def start_span(
span.set_attributes(dict(get_attributes_from_context()))
return span

# overload for @tracer.chain usage (no parameters)
@overload
@overload # for @tracer.agent usage (no parameters)
def agent(
self,
wrapped_function: Callable[ParametersType, ReturnType],
/,
*,
name: None = None,
) -> Callable[ParametersType, ReturnType]: ...

@overload # for @tracer.agent(name="name") usage (with parameters)
def agent(
self,
wrapped_function: None = None,
/,
*,
name: Optional[str] = None,
) -> Callable[[Callable[ParametersType, ReturnType]], Callable[ParametersType, ReturnType]]: ...

def agent(
self,
wrapped_function: Optional[Callable[ParametersType, ReturnType]] = None,
/,
*,
name: Optional[str] = None,
) -> Union[
Callable[ParametersType, ReturnType],
Callable[[Callable[ParametersType, ReturnType]], Callable[ParametersType, ReturnType]],
]:
return self._chain(
wrapped_function,
kind=OpenInferenceSpanKindValues.AGENT, # chains and agents differ only in span kind
name=name,
)

@overload # for @tracer.chain usage (no parameters)
def chain(
self,
wrapped_function: Callable[ParametersType, ReturnType],
Expand All @@ -778,8 +817,7 @@ def chain(
name: None = None,
) -> Callable[ParametersType, ReturnType]: ...

# overload for @tracer.chain(name="name") usage (with parameters)
@overload
@overload # for @tracer.chain(name="name") usage (with parameters)
def chain(
self,
wrapped_function: None = None,
Expand All @@ -798,13 +836,14 @@ def chain(
Callable[ParametersType, ReturnType],
Callable[[Callable[ParametersType, ReturnType]], Callable[ParametersType, ReturnType]],
]:
return self._chain(wrapped_function, name=name)
return self._chain(wrapped_function, kind=OpenInferenceSpanKindValues.CHAIN, name=name)

def _chain(
self,
wrapped_function: Optional[Callable[ParametersType, ReturnType]] = None,
/,
*,
kind: OpenInferenceSpanKindValues,
name: Optional[str] = None,
) -> Union[
Callable[ParametersType, ReturnType],
Expand All @@ -830,7 +869,7 @@ def sync_wrapper(
input_attributes = get_input_value_and_mime_type(value=arguments)
with tracer.start_as_current_span(
span_name,
openinference_span_kind=OpenInferenceSpanKindValues.CHAIN,
openinference_span_kind=kind,
attributes=input_attributes,
) as span:
output = wrapped(*args, **kwargs)
Expand Down Expand Up @@ -866,7 +905,7 @@ async def async_wrapper(
input_attributes = get_input_value_and_mime_type(value=arguments)
with tracer.start_as_current_span(
span_name,
openinference_span_kind=OpenInferenceSpanKindValues.CHAIN,
openinference_span_kind=kind,
attributes=input_attributes,
) as span:
output = await wrapped(*args, **kwargs)
Expand All @@ -886,12 +925,9 @@ async def async_wrapper(
if asyncio.iscoroutinefunction(wrapped_function):
return async_wrapper(wrapped_function) # type: ignore[no-any-return]
return sync_wrapper(wrapped_function) # type: ignore[no-any-return]
if asyncio.iscoroutinefunction(wrapped_function):
return lambda x: async_wrapper(x)
return lambda x: sync_wrapper(x)
return lambda f: async_wrapper(f) if asyncio.iscoroutinefunction(f) else sync_wrapper(f)

# overload for @tool usage (no parameters)
@overload
@overload # for @tracer.tool usage (no parameters)
def tool(
self,
wrapped_function: Callable[ParametersType, ReturnType],
Expand All @@ -901,8 +937,7 @@ def tool(
description: Optional[str] = None,
) -> Callable[ParametersType, ReturnType]: ...

# overload for @tool(name="name") usage (with parameters)
@overload
@overload # for @tracer.tool(name="name") usage (with parameters)
def tool(
self,
wrapped_function: None = None,
Expand Down Expand Up @@ -1064,6 +1099,8 @@ def _normalize_openinference_span_kind(kind: OpenInferenceSpanKind) -> OpenInfer


def _get_span_wrapper_cls(kind: OpenInferenceSpanKindValues) -> Type[OpenInferenceSpan]:
if kind is OpenInferenceSpanKindValues.AGENT:
return AgentSpan
if kind is OpenInferenceSpanKindValues.CHAIN:
return ChainSpan
if kind is OpenInferenceSpanKindValues.TOOL:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,33 @@ class OutputModel:
assert json.loads(output_value) == {"string_output": "output", "int_output": 2}
assert not attributes

def test_agent(
self,
in_memory_span_exporter: InMemorySpanExporter,
tracer: OITracer,
) -> None:
with tracer.start_as_current_span(
"agent-span-name",
openinference_span_kind="agent",
) as agent_span:
agent_span.set_input("input")
agent_span.set_output("output")
agent_span.set_status(Status(StatusCode.OK))

spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 1
span = spans[0]
assert span.name == "agent-span-name"
assert span.status.is_ok
assert not span.events
attributes = dict(span.attributes or {})
assert attributes.pop(OPENINFERENCE_SPAN_KIND) == AGENT
assert attributes.pop(INPUT_MIME_TYPE) == TEXT
assert attributes.pop(INPUT_VALUE) == "input"
assert attributes.pop(OUTPUT_MIME_TYPE) == TEXT
assert attributes.pop(OUTPUT_VALUE) == "output"
assert not attributes

def test_tool(
self,
in_memory_span_exporter: InMemorySpanExporter,
Expand Down Expand Up @@ -714,6 +741,58 @@ def decorated_chain_with_session(input: str) -> str:
assert attributes[SESSION_ID] == "123"


class TestAgentDecorator:
def test_plain_text_input_and_output(
self,
in_memory_span_exporter: InMemorySpanExporter,
tracer: OITracer,
) -> None:
@tracer.agent
def decorated_agent(input: str) -> str:
return "output"

decorated_agent("input")

spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 1
span = spans[0]
assert span.name == "decorated_agent"
assert span.status.is_ok
assert not span.events
attributes = dict(span.attributes or {})
assert attributes.pop(OPENINFERENCE_SPAN_KIND) == AGENT
assert attributes.pop(INPUT_MIME_TYPE) == TEXT
assert attributes.pop(INPUT_VALUE) == "input"
assert attributes.pop(OUTPUT_MIME_TYPE) == TEXT
assert attributes.pop(OUTPUT_VALUE) == "output"
assert not attributes

async def test_async_with_overridden_name(
self,
in_memory_span_exporter: InMemorySpanExporter,
tracer: OITracer,
) -> None:
@tracer.agent(name="custom-name")
async def decorated_agent(input: str) -> str:
return "output"

await decorated_agent("input")

spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 1
span = spans[0]
assert span.name == "custom-name"
assert span.status.is_ok
assert not span.events
attributes = dict(span.attributes or {})
assert attributes.pop(OPENINFERENCE_SPAN_KIND) == AGENT
assert attributes.pop(INPUT_MIME_TYPE) == TEXT
assert attributes.pop(INPUT_VALUE) == "input"
assert attributes.pop(OUTPUT_MIME_TYPE) == TEXT
assert attributes.pop(OUTPUT_VALUE) == "output"
assert not attributes


class TestTracerToolDecorator:
def test_tool_with_one_argument_and_docstring(
self,
Expand Down Expand Up @@ -970,6 +1049,7 @@ def tool_function(input_str: str) -> str:
JSON = OpenInferenceMimeTypeValues.JSON.value

# span kinds
AGENT = OpenInferenceSpanKindValues.AGENT.value
CHAIN = OpenInferenceSpanKindValues.CHAIN.value
TOOL = OpenInferenceSpanKindValues.TOOL.value

Expand Down

0 comments on commit a4f5881

Please sign in to comment.