diff --git a/python/openinference-instrumentation/examples/tracer.ipynb b/python/openinference-instrumentation/examples/tracer.ipynb index b44dd00eb..0dbfca901 100644 --- a/python/openinference-instrumentation/examples/tracer.ipynb +++ b/python/openinference-instrumentation/examples/tracer.ipynb @@ -11,14 +11,13 @@ "\n", "from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter\n", "from opentelemetry.sdk.resources import Resource\n", - "from opentelemetry.sdk.trace import TracerProvider\n", "from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor\n", "from opentelemetry.trace import Status, StatusCode, set_tracer_provider\n", "\n", - "from openinference.instrumentation import TraceConfig\n", "from openinference.instrumentation.config import (\n", - " OITracer,\n", + " OpenInferenceTracerProvider,\n", " chain,\n", + " get_current_span,\n", " get_input_value_and_mime_type,\n", " get_openinference_span_kind,\n", " get_output_value_and_mime_type,\n", @@ -34,11 +33,11 @@ "source": [ "endpoint = \"http://127.0.0.1:6006/v1/traces\"\n", "resource = Resource(attributes={ResourceAttributes.PROJECT_NAME: \"openinference-tracer\"})\n", - "tracer_provider = TracerProvider(resource=resource)\n", + "tracer_provider = OpenInferenceTracerProvider(resource=resource)\n", "tracer_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))\n", "tracer_provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))\n", "set_tracer_provider(tracer_provider)\n", - "tracer = OITracer(wrapped=tracer_provider.get_tracer(__name__), config=TraceConfig())" + "tracer = tracer_provider.get_tracer(__name__)" ] }, { @@ -293,6 +292,23 @@ "chain_runner = ChainRunner()\n", "chain_runner.decorated_chain_method(\"input\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@chain\n", + "def decorated_chain_with_input_and_output_set_inside_the_wrapped_function(input: str) -> str:\n", + " span = get_current_span()\n", + " span.set_input(\"overridden-input\")\n", + " span.set_output(\"overridden-output\")\n", + " return \"output\"\n", + "\n", + "\n", + "decorated_chain_with_input_and_output_set_inside_the_wrapped_function(\"input\")" + ] } ], "metadata": { diff --git a/python/openinference-instrumentation/src/openinference/instrumentation/config.py b/python/openinference-instrumentation/src/openinference/instrumentation/config.py index 7afccc6d4..fdf836d7a 100644 --- a/python/openinference-instrumentation/src/openinference/instrumentation/config.py +++ b/python/openinference-instrumentation/src/openinference/instrumentation/config.py @@ -44,6 +44,9 @@ get_tracer, use_span, ) +from opentelemetry.trace import ( + get_current_span as otel_get_current_span, +) from opentelemetry.util.types import Attributes, AttributeValue from typing_extensions import ParamSpec, TypeAlias, overload @@ -464,6 +467,10 @@ def sync_wrapper( ), ) as span: output = wrapped(*args, **kwargs) + span.set_status(Status(StatusCode.OK)) + has_output = OUTPUT_MIME_TYPE in span.attributes or OUTPUT_VALUE in span.attributes + if has_output: + return output # don't overwrite if the output is set inside the wrapped function if isinstance(output, (str, int, float, bool)): span.set_output( output, @@ -474,7 +481,6 @@ def sync_wrapper( safe_json_dumps(output), mime_type=OpenInferenceMimeTypeValues.JSON, ) - span.set_status(Status(StatusCode.OK)) return output @wrapt.decorator # type: ignore[misc] @@ -499,6 +505,10 @@ async def async_wrapper( ), ) as span: output = await wrapped(*args, **kwargs) + span.set_status(Status(StatusCode.OK)) + has_output = OUTPUT_MIME_TYPE in span.attributes or OUTPUT_VALUE in span.attributes + if has_output: + return output # don't overwrite if the output is set inside the wrapped function if isinstance(output, (str, int, float, bool)): span.set_output( output, @@ -509,7 +519,6 @@ async def async_wrapper( safe_json_dumps(output), mime_type=OpenInferenceMimeTypeValues.JSON, ) - span.set_status(Status(StatusCode.OK)) return output if wrapped_function is not None: @@ -565,6 +574,10 @@ def set_output( self.set_attributes(get_output_value_and_mime_type(value, mime_type)) +def get_current_span(context: Optional[Context] = None) -> OpenInferenceSpan: + return OpenInferenceSpan(otel_get_current_span(context), TraceConfig()) + + class ChainSpan(OpenInferenceSpan): def __init__(self, wrapped: Span, config: TraceConfig) -> None: super().__init__(wrapped, config)