From a4f58810473a52ef1fab659cfab1002872490b87 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 18 Jan 2025 19:48:22 -0800 Subject: [PATCH] add agent --- .../examples/tracer.ipynb | 50 ++++++++++++ .../openinference/instrumentation/config.py | 65 +++++++++++---- .../tests/test_manual_instrumentation.py | 80 +++++++++++++++++++ 3 files changed, 181 insertions(+), 14 deletions(-) diff --git a/python/openinference-instrumentation/examples/tracer.ipynb b/python/openinference-instrumentation/examples/tracer.ipynb index 3920e7e32..535ecb888 100644 --- a/python/openinference-instrumentation/examples/tracer.ipynb +++ b/python/openinference-instrumentation/examples/tracer.ipynb @@ -412,6 +412,56 @@ " decorated_chain_with_context_attributes(\"input\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Agents" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Context Managers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with tracer.start_as_current_span(\n", + " \"agent-span-with-plain-text-io\",\n", + " openinference_span_kind=\"agent\",\n", + ") as span:\n", + " span.set_input(\"input\")\n", + " span.set_output(\"output\")\n", + " span.set_status(Status(StatusCode.OK))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Decorators" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@tracer.agent\n", + "def decorated_agent(input: str) -> str:\n", + " return \"output\"\n", + "\n", + "\n", + "decorated_agent(\"input\")" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/python/openinference-instrumentation/src/openinference/instrumentation/config.py b/python/openinference-instrumentation/src/openinference/instrumentation/config.py index 7f61c52fa..b7b8582f2 100644 --- a/python/openinference-instrumentation/src/openinference/instrumentation/config.py +++ b/python/openinference-instrumentation/src/openinference/instrumentation/config.py @@ -556,6 +556,12 @@ def set_output( self.set_attributes(get_output_value_and_mime_type(value, mime_type)) +class AgentSpan(OpenInferenceSpan): + def __init__(self, wrapped: Span, config: TraceConfig) -> None: + super().__init__(wrapped, config) + self.__wrapped__.set_attributes(get_span_kind(OpenInferenceSpanKindValues.AGENT)) + + class ChainSpan(OpenInferenceSpan): def __init__(self, wrapped: Span, config: TraceConfig) -> None: super().__init__(wrapped, config) @@ -768,8 +774,41 @@ def start_span( span.set_attributes(dict(get_attributes_from_context())) return span - # overload for @tracer.chain usage (no parameters) - @overload + @overload # for @tracer.agent usage (no parameters) + def agent( + self, + wrapped_function: Callable[ParametersType, ReturnType], + /, + *, + name: None = None, + ) -> Callable[ParametersType, ReturnType]: ... + + @overload # for @tracer.agent(name="name") usage (with parameters) + def agent( + self, + wrapped_function: None = None, + /, + *, + name: Optional[str] = None, + ) -> Callable[[Callable[ParametersType, ReturnType]], Callable[ParametersType, ReturnType]]: ... + + def agent( + self, + wrapped_function: Optional[Callable[ParametersType, ReturnType]] = None, + /, + *, + name: Optional[str] = None, + ) -> Union[ + Callable[ParametersType, ReturnType], + Callable[[Callable[ParametersType, ReturnType]], Callable[ParametersType, ReturnType]], + ]: + return self._chain( + wrapped_function, + kind=OpenInferenceSpanKindValues.AGENT, # chains and agents differ only in span kind + name=name, + ) + + @overload # for @tracer.chain usage (no parameters) def chain( self, wrapped_function: Callable[ParametersType, ReturnType], @@ -778,8 +817,7 @@ def chain( name: None = None, ) -> Callable[ParametersType, ReturnType]: ... - # overload for @tracer.chain(name="name") usage (with parameters) - @overload + @overload # for @tracer.chain(name="name") usage (with parameters) def chain( self, wrapped_function: None = None, @@ -798,13 +836,14 @@ def chain( Callable[ParametersType, ReturnType], Callable[[Callable[ParametersType, ReturnType]], Callable[ParametersType, ReturnType]], ]: - return self._chain(wrapped_function, name=name) + return self._chain(wrapped_function, kind=OpenInferenceSpanKindValues.CHAIN, name=name) def _chain( self, wrapped_function: Optional[Callable[ParametersType, ReturnType]] = None, /, *, + kind: OpenInferenceSpanKindValues, name: Optional[str] = None, ) -> Union[ Callable[ParametersType, ReturnType], @@ -830,7 +869,7 @@ def sync_wrapper( input_attributes = get_input_value_and_mime_type(value=arguments) with tracer.start_as_current_span( span_name, - openinference_span_kind=OpenInferenceSpanKindValues.CHAIN, + openinference_span_kind=kind, attributes=input_attributes, ) as span: output = wrapped(*args, **kwargs) @@ -866,7 +905,7 @@ async def async_wrapper( input_attributes = get_input_value_and_mime_type(value=arguments) with tracer.start_as_current_span( span_name, - openinference_span_kind=OpenInferenceSpanKindValues.CHAIN, + openinference_span_kind=kind, attributes=input_attributes, ) as span: output = await wrapped(*args, **kwargs) @@ -886,12 +925,9 @@ async def async_wrapper( if asyncio.iscoroutinefunction(wrapped_function): return async_wrapper(wrapped_function) # type: ignore[no-any-return] return sync_wrapper(wrapped_function) # type: ignore[no-any-return] - if asyncio.iscoroutinefunction(wrapped_function): - return lambda x: async_wrapper(x) - return lambda x: sync_wrapper(x) + return lambda f: async_wrapper(f) if asyncio.iscoroutinefunction(f) else sync_wrapper(f) - # overload for @tool usage (no parameters) - @overload + @overload # for @tracer.tool usage (no parameters) def tool( self, wrapped_function: Callable[ParametersType, ReturnType], @@ -901,8 +937,7 @@ def tool( description: Optional[str] = None, ) -> Callable[ParametersType, ReturnType]: ... - # overload for @tool(name="name") usage (with parameters) - @overload + @overload # for @tracer.tool(name="name") usage (with parameters) def tool( self, wrapped_function: None = None, @@ -1064,6 +1099,8 @@ def _normalize_openinference_span_kind(kind: OpenInferenceSpanKind) -> OpenInfer def _get_span_wrapper_cls(kind: OpenInferenceSpanKindValues) -> Type[OpenInferenceSpan]: + if kind is OpenInferenceSpanKindValues.AGENT: + return AgentSpan if kind is OpenInferenceSpanKindValues.CHAIN: return ChainSpan if kind is OpenInferenceSpanKindValues.TOOL: diff --git a/python/openinference-instrumentation/tests/test_manual_instrumentation.py b/python/openinference-instrumentation/tests/test_manual_instrumentation.py index 7c3368df0..aba1bb0a9 100644 --- a/python/openinference-instrumentation/tests/test_manual_instrumentation.py +++ b/python/openinference-instrumentation/tests/test_manual_instrumentation.py @@ -119,6 +119,33 @@ class OutputModel: assert json.loads(output_value) == {"string_output": "output", "int_output": 2} assert not attributes + def test_agent( + self, + in_memory_span_exporter: InMemorySpanExporter, + tracer: OITracer, + ) -> None: + with tracer.start_as_current_span( + "agent-span-name", + openinference_span_kind="agent", + ) as agent_span: + agent_span.set_input("input") + agent_span.set_output("output") + agent_span.set_status(Status(StatusCode.OK)) + + spans = in_memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + span = spans[0] + assert span.name == "agent-span-name" + assert span.status.is_ok + assert not span.events + attributes = dict(span.attributes or {}) + assert attributes.pop(OPENINFERENCE_SPAN_KIND) == AGENT + assert attributes.pop(INPUT_MIME_TYPE) == TEXT + assert attributes.pop(INPUT_VALUE) == "input" + assert attributes.pop(OUTPUT_MIME_TYPE) == TEXT + assert attributes.pop(OUTPUT_VALUE) == "output" + assert not attributes + def test_tool( self, in_memory_span_exporter: InMemorySpanExporter, @@ -714,6 +741,58 @@ def decorated_chain_with_session(input: str) -> str: assert attributes[SESSION_ID] == "123" +class TestAgentDecorator: + def test_plain_text_input_and_output( + self, + in_memory_span_exporter: InMemorySpanExporter, + tracer: OITracer, + ) -> None: + @tracer.agent + def decorated_agent(input: str) -> str: + return "output" + + decorated_agent("input") + + spans = in_memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + span = spans[0] + assert span.name == "decorated_agent" + assert span.status.is_ok + assert not span.events + attributes = dict(span.attributes or {}) + assert attributes.pop(OPENINFERENCE_SPAN_KIND) == AGENT + assert attributes.pop(INPUT_MIME_TYPE) == TEXT + assert attributes.pop(INPUT_VALUE) == "input" + assert attributes.pop(OUTPUT_MIME_TYPE) == TEXT + assert attributes.pop(OUTPUT_VALUE) == "output" + assert not attributes + + async def test_async_with_overridden_name( + self, + in_memory_span_exporter: InMemorySpanExporter, + tracer: OITracer, + ) -> None: + @tracer.agent(name="custom-name") + async def decorated_agent(input: str) -> str: + return "output" + + await decorated_agent("input") + + spans = in_memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + span = spans[0] + assert span.name == "custom-name" + assert span.status.is_ok + assert not span.events + attributes = dict(span.attributes or {}) + assert attributes.pop(OPENINFERENCE_SPAN_KIND) == AGENT + assert attributes.pop(INPUT_MIME_TYPE) == TEXT + assert attributes.pop(INPUT_VALUE) == "input" + assert attributes.pop(OUTPUT_MIME_TYPE) == TEXT + assert attributes.pop(OUTPUT_VALUE) == "output" + assert not attributes + + class TestTracerToolDecorator: def test_tool_with_one_argument_and_docstring( self, @@ -970,6 +1049,7 @@ def tool_function(input_str: str) -> str: JSON = OpenInferenceMimeTypeValues.JSON.value # span kinds +AGENT = OpenInferenceSpanKindValues.AGENT.value CHAIN = OpenInferenceSpanKindValues.CHAIN.value TOOL = OpenInferenceSpanKindValues.TOOL.value