diff --git a/python/openinference-instrumentation/src/openinference/instrumentation/config.py b/python/openinference-instrumentation/src/openinference/instrumentation/config.py index 421d498a1..f34756d8a 100644 --- a/python/openinference-instrumentation/src/openinference/instrumentation/config.py +++ b/python/openinference-instrumentation/src/openinference/instrumentation/config.py @@ -934,40 +934,17 @@ def sync_wrapper( kwargs: Dict[str, Any], ) -> ReturnType: tracer = self - span_name = name or wrapped.__name__ - bound_args = inspect.signature(wrapped).bind(*args, **kwargs) - bound_args.apply_defaults() - arguments = bound_args.arguments - input_attributes = get_input_value_and_mime_type(value=arguments) - tool_description: Optional[str] = description - if ( - not tool_description - and (docstring := wrapped.__doc__) is not None - and (stripped_docstring := docstring.strip()) - ): - tool_description = stripped_docstring - tool_attributes = get_tool_attributes( - name=name or wrapped.__name__, - description=tool_description, - parameters={}, - ) - with tracer.start_as_current_span( - span_name, - openinference_span_kind=OpenInferenceSpanKindValues.TOOL, - attributes={ - **input_attributes, - **tool_attributes, - }, - ) as span: + with _tool_context( + tracer=tracer, + name=name, + description=description, + wrapped=wrapped, + instance=instance, + args=args, + kwargs=kwargs, + ) as tool_context: output = wrapped(*args, **kwargs) - span.set_status(Status(StatusCode.OK)) - attributes = getattr( - span, "attributes", {} - ) # INVALID_SPAN does not have the attributes property - has_output = OUTPUT_VALUE in attributes - if not has_output: - span.set_output(value=output) - return output + return tool_context.process_output(output) @wrapt.decorator # type: ignore[misc] async def async_wrapper( @@ -977,42 +954,17 @@ async def async_wrapper( kwargs: Dict[str, Any], ) -> ReturnType: tracer = self - span_name = name or wrapped.__name__ - bound_args = inspect.signature(wrapped).bind(*args, **kwargs) - bound_args.apply_defaults() - arguments = bound_args.arguments - input_attributes = get_input_value_and_mime_type(value=arguments) - tool_description: Optional[str] = description - if ( - not tool_description - and (docstring := wrapped.__doc__) is not None - and (stripped_docstring := docstring.strip()) - ): - tool_description = stripped_docstring - tool_attributes = get_tool_attributes( - name=name or wrapped.__name__, - description=tool_description, - parameters={}, - ) - with tracer.start_as_current_span( - span_name, - openinference_span_kind=OpenInferenceSpanKindValues.TOOL, - attributes={ - **input_attributes, - **tool_attributes, - }, - ) as span: + with _tool_context( + tracer=tracer, + name=name, + description=description, + wrapped=wrapped, + instance=instance, + args=args, + kwargs=kwargs, + ) as tool_context: output = await wrapped(*args, **kwargs) - span.set_status(Status(StatusCode.OK)) - attributes = getattr( - span, "attributes", {} - ) # INVALID_SPAN does not have the attributes property - has_output = OUTPUT_VALUE in attributes - if ( - not has_output - ): # don't overwrite if the output is set inside the wrapped function - span.set_output(value=output) - return output + return tool_context.process_output(output) if wrapped_function is not None: if asyncio.iscoroutinefunction(wrapped_function): @@ -1123,6 +1075,59 @@ def _chain_context( span.set_status(Status(StatusCode.OK)) +class _ToolContext: + def __init__(self, span: OpenInferenceSpan) -> None: + self._span = span + + def process_output(self, output: Any) -> Any: + attributes = getattr(self._span, "attributes", {}) + has_output = OUTPUT_VALUE in attributes + if not has_output: + self._span.set_output(value=output) + return output + + +@contextmanager +def _tool_context( + *, + tracer: "OITracer", + name: Optional[str], + description: Optional[str], + wrapped: Callable[ParametersType, Union[ReturnType, Awaitable[ReturnType]]], + instance: Any, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], +) -> Iterator[_ToolContext]: + span_name = name or wrapped.__name__ + bound_args = inspect.signature(wrapped).bind(*args, **kwargs) + bound_args.apply_defaults() + arguments = bound_args.arguments + input_attributes = get_input_value_and_mime_type(value=arguments) + tool_description: Optional[str] = description + if ( + not tool_description + and (docstring := wrapped.__doc__) is not None + and (stripped_docstring := docstring.strip()) + ): + tool_description = stripped_docstring + tool_attributes = get_tool_attributes( + name=name or wrapped.__name__, + description=tool_description, + parameters={}, + ) + with tracer.start_as_current_span( + span_name, + openinference_span_kind=OpenInferenceSpanKindValues.TOOL, + attributes={ + **input_attributes, + **tool_attributes, + }, + ) as span: + context = _ToolContext(span=span) + yield context + span.set_status(Status(StatusCode.OK)) + + # span attributes INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE INPUT_VALUE = SpanAttributes.INPUT_VALUE