Skip to content

Commit

Permalink
dry out tool wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy committed Jan 19, 2025
1 parent 949defd commit 9346ff5
Showing 1 changed file with 73 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -934,40 +934,17 @@ def sync_wrapper(
kwargs: Dict[str, Any],
) -> ReturnType:
tracer = self
span_name = name or wrapped.__name__
bound_args = inspect.signature(wrapped).bind(*args, **kwargs)
bound_args.apply_defaults()
arguments = bound_args.arguments
input_attributes = get_input_value_and_mime_type(value=arguments)
tool_description: Optional[str] = description
if (
not tool_description
and (docstring := wrapped.__doc__) is not None
and (stripped_docstring := docstring.strip())
):
tool_description = stripped_docstring
tool_attributes = get_tool_attributes(
name=name or wrapped.__name__,
description=tool_description,
parameters={},
)
with tracer.start_as_current_span(
span_name,
openinference_span_kind=OpenInferenceSpanKindValues.TOOL,
attributes={
**input_attributes,
**tool_attributes,
},
) as span:
with _tool_context(
tracer=tracer,
name=name,
description=description,
wrapped=wrapped,
instance=instance,
args=args,
kwargs=kwargs,
) as tool_context:
output = wrapped(*args, **kwargs)
span.set_status(Status(StatusCode.OK))
attributes = getattr(
span, "attributes", {}
) # INVALID_SPAN does not have the attributes property
has_output = OUTPUT_VALUE in attributes
if not has_output:
span.set_output(value=output)
return output
return tool_context.process_output(output)

@wrapt.decorator # type: ignore[misc]
async def async_wrapper(
Expand All @@ -977,42 +954,17 @@ async def async_wrapper(
kwargs: Dict[str, Any],
) -> ReturnType:
tracer = self
span_name = name or wrapped.__name__
bound_args = inspect.signature(wrapped).bind(*args, **kwargs)
bound_args.apply_defaults()
arguments = bound_args.arguments
input_attributes = get_input_value_and_mime_type(value=arguments)
tool_description: Optional[str] = description
if (
not tool_description
and (docstring := wrapped.__doc__) is not None
and (stripped_docstring := docstring.strip())
):
tool_description = stripped_docstring
tool_attributes = get_tool_attributes(
name=name or wrapped.__name__,
description=tool_description,
parameters={},
)
with tracer.start_as_current_span(
span_name,
openinference_span_kind=OpenInferenceSpanKindValues.TOOL,
attributes={
**input_attributes,
**tool_attributes,
},
) as span:
with _tool_context(
tracer=tracer,
name=name,
description=description,
wrapped=wrapped,
instance=instance,
args=args,
kwargs=kwargs,
) as tool_context:
output = await wrapped(*args, **kwargs)
span.set_status(Status(StatusCode.OK))
attributes = getattr(
span, "attributes", {}
) # INVALID_SPAN does not have the attributes property
has_output = OUTPUT_VALUE in attributes
if (
not has_output
): # don't overwrite if the output is set inside the wrapped function
span.set_output(value=output)
return output
return tool_context.process_output(output)

if wrapped_function is not None:
if asyncio.iscoroutinefunction(wrapped_function):
Expand Down Expand Up @@ -1123,6 +1075,59 @@ def _chain_context(
span.set_status(Status(StatusCode.OK))


class _ToolContext:
def __init__(self, span: OpenInferenceSpan) -> None:
self._span = span

def process_output(self, output: Any) -> Any:
attributes = getattr(self._span, "attributes", {})
has_output = OUTPUT_VALUE in attributes
if not has_output:
self._span.set_output(value=output)
return output


@contextmanager
def _tool_context(
*,
tracer: "OITracer",
name: Optional[str],
description: Optional[str],
wrapped: Callable[ParametersType, Union[ReturnType, Awaitable[ReturnType]]],
instance: Any,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Iterator[_ToolContext]:
span_name = name or wrapped.__name__
bound_args = inspect.signature(wrapped).bind(*args, **kwargs)
bound_args.apply_defaults()
arguments = bound_args.arguments
input_attributes = get_input_value_and_mime_type(value=arguments)
tool_description: Optional[str] = description
if (
not tool_description
and (docstring := wrapped.__doc__) is not None
and (stripped_docstring := docstring.strip())
):
tool_description = stripped_docstring
tool_attributes = get_tool_attributes(
name=name or wrapped.__name__,
description=tool_description,
parameters={},
)
with tracer.start_as_current_span(
span_name,
openinference_span_kind=OpenInferenceSpanKindValues.TOOL,
attributes={
**input_attributes,
**tool_attributes,
},
) as span:
context = _ToolContext(span=span)
yield context
span.set_status(Status(StatusCode.OK))


# span attributes
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
INPUT_VALUE = SpanAttributes.INPUT_VALUE
Expand Down

0 comments on commit 9346ff5

Please sign in to comment.