tool decorators

axiomofjoy committed Jan 17, 2025
1 parent 426bf15 commit 6487179
Showing 2 changed files with 258 additions and 4 deletions.
77 changes: 76 additions & 1 deletion python/openinference-instrumentation/examples/tracer.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -26,6 +26,7 @@
" get_output_value_and_mime_type,\n",
" get_tool_attributes,\n",
" suppress_tracing,\n",
" tool,\n",
")\n",
"from openinference.semconv.resource import ResourceAttributes"
]
@@ -467,6 +468,80 @@
" )\n",
" span.set_status(Status(StatusCode.OK))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@tool\n",
"def decorated_tool(input1: str, input2: int) -> None:\n",
" \"\"\"\n",
" tool-description\n",
" \"\"\"\n",
"\n",
"\n",
"decorated_tool(\"input1\", 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@tool\n",
"async def decorated_tool_async(input1: str, input2: int) -> None:\n",
" \"\"\"\n",
" tool-description\n",
" \"\"\"\n",
"\n",
"\n",
"await decorated_tool_async(\"input1\", 1) # type: ignore[top-level-await]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@tool(\n",
" name=\"decorated-tool-with-overriden-name\",\n",
" description=\"overriden-tool-description\",\n",
")\n",
"def this_tool_name_should_be_overriden(input1: str, input2: int) -> None:\n",
" \"\"\"\n",
" this tool description should be overriden\n",
" \"\"\"\n",
"\n",
"\n",
"this_tool_name_should_be_overriden(\"input1\", 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@tool\n",
"def tool_with_changes_inside_the_wrapped_function(input1: str, input2: int) -> str:\n",
" span = get_current_openinference_span()\n",
" print(type(span))\n",
" span.set_input(\"inside-input\")\n",
" span.set_output(\"inside-output\")\n",
" span.set_tool(\n",
" name=\"inside-tool-name\",\n",
" description=\"inside-tool-description\",\n",
" parameters={\"inside-input\": \"inside-input\"},\n",
" )\n",
" return \"output\"\n",
"\n",
"\n",
"tool_with_changes_inside_the_wrapped_function(\"input1\", 1)"
]
}
],
"metadata": {
@@ -621,6 +621,134 @@ async def async_wrapper(
    return lambda x: sync_wrapper(x)


# overload for @tool usage (no parameters)
@overload
def tool(
    wrapped_function: Callable[ParametersType, ReturnType],
    /,
    *,
    name: None = None,
    description: Optional[str] = None,
) -> Callable[ParametersType, ReturnType]: ...


# overload for @tool(name="name") usage (with parameters)
@overload
def tool(
    wrapped_function: None = None,
    /,
    *,
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> Callable[[Callable[ParametersType, ReturnType]], Callable[ParametersType, ReturnType]]: ...
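# Together, the overloads above describe the two supported call forms:
#   @tool                            -> the function itself arrives as wrapped_function
#   @tool(name=..., description=...) -> wrapped_function is None and a decorator is returned
# The implementation below dispatches on whether wrapped_function was supplied and on
# whether the wrapped callable is a coroutine function.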


def tool(
    wrapped_function: Optional[Callable[ParametersType, ReturnType]] = None,
    /,
    *,
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> Union[
    Callable[ParametersType, ReturnType],
    Callable[[Callable[ParametersType, ReturnType]], Callable[ParametersType, ReturnType]],
]:
    @wrapt.decorator  # type: ignore[misc]
    def sync_wrapper(
        wrapped: Callable[ParametersType, ReturnType],
        instance: Any,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> ReturnType:
        tracer = OITracer(get_tracer(__name__), config=TraceConfig())
        span_name = name or wrapped.__name__
        bound_args = inspect.signature(wrapped).bind(*args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments

        if len(arguments) == 1:
            argument = next(iter(arguments.values()))
            input_attributes = get_input_value_and_mime_type(value=argument)
        else:
            input_attributes = get_input_value_and_mime_type(value=arguments)
        tool_parameters = safe_json_dumps_io_value(arguments)
        tool_description: Optional[str] = description
        if tool_description is None and (docstring := wrapped.__doc__) is not None:
            tool_description = docstring.strip()
        tool_attributes = get_tool_attributes(
            name=name or wrapped.__name__,
            description=tool_description,
            parameters=tool_parameters,
        )
        with tracer.start_as_current_span(
            span_name,
            openinference_span_kind=OpenInferenceSpanKindValues.TOOL,
            attributes={
                **input_attributes,
                **tool_attributes,
            },
        ) as span:
            output = wrapped(*args, **kwargs)
            span.set_status(Status(StatusCode.OK))
            attributes = getattr(
                span, "attributes", {}
            )  # INVALID_SPAN does not have the attributes property
            has_output = OUTPUT_VALUE in attributes
            if not has_output:
                span.set_output(value=output)
            return output

    @wrapt.decorator  # type: ignore[misc]
    async def async_wrapper(
        wrapped: Callable[ParametersType, Awaitable[ReturnType]],
        instance: Any,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> ReturnType:
        tracer = OITracer(get_tracer(__name__), config=TraceConfig())
        span_name = name or wrapped.__name__
        bound_args = inspect.signature(wrapped).bind(*args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments

        if len(arguments) == 1:
            argument = next(iter(arguments.values()))
            input_attributes = get_input_value_and_mime_type(value=argument)
        else:
            input_attributes = get_input_value_and_mime_type(value=arguments)
        tool_parameters = safe_json_dumps_io_value(arguments)
        tool_description: Optional[str] = description
        if tool_description is None and (docstring := wrapped.__doc__) is not None:
            tool_description = docstring.strip()
        tool_attributes = get_tool_attributes(
            name=name or wrapped.__name__,
            description=tool_description,
            parameters=tool_parameters,
        )
        with tracer.start_as_current_span(
            span_name,
            openinference_span_kind=OpenInferenceSpanKindValues.TOOL,
            attributes={
                **input_attributes,
                **tool_attributes,
            },
        ) as span:
            output = await wrapped(*args, **kwargs)
            span.set_status(Status(StatusCode.OK))
            attributes = getattr(
                span, "attributes", {}
            )  # INVALID_SPAN does not have the attributes property
            has_output = OUTPUT_VALUE in attributes
            if not has_output:  # don't overwrite if the output is set inside the wrapped function
                span.set_output(value=output)
            return output

    if wrapped_function is not None:
        if asyncio.iscoroutinefunction(wrapped_function):
            return async_wrapper(wrapped_function)  # type: ignore[no-any-return]
        return sync_wrapper(wrapped_function)  # type: ignore[no-any-return]
    # called as @tool(...): wrapped_function is None here, so defer the sync/async
    # check until the returned decorator receives the actual function
    return lambda f: async_wrapper(f) if asyncio.iscoroutinefunction(f) else sync_wrapper(f)
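
# Expected usage, mirroring the notebook examples above: the bare form traces a plain
# or async function as a TOOL span, and the parameterized form overrides the recorded
# tool name and description. A minimal sketch (my_tool and my_renamed_tool are
# illustrative names, not part of the library):
#
#     @tool
#     def my_tool(x: str, n: int) -> None:
#         """tool-description"""
#
#     @tool(name="custom-name", description="custom-description")
#     def my_renamed_tool(x: str, n: int) -> None: ...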


class OpenInferenceSpan(wrapt.ObjectProxy):  # type: ignore[misc]
    def __init__(self, wrapped: Span, config: TraceConfig) -> None:
        super().__init__(wrapped)
@@ -700,23 +828,23 @@ def set_tool(
def get_current_openinference_span(
context: Optional[Context] = None,
*,
kind: Literal["chain"],
kind: Literal["chain"] = "chain",
) -> ChainSpan: ...


@overload
def get_current_openinference_span(
context: Optional[Context] = None,
*,
kind: Literal["tool"],
kind: Literal["tool"] = "tool",
) -> ToolSpan: ...


@overload
def get_current_openinference_span(
context: Optional[Context] = None,
*,
kind: None,
kind: None = None,
) -> OpenInferenceSpan: ...


@@ -759,6 +887,57 @@ def __init__(self, wrapped: Tracer, config: TraceConfig) -> None:
    def id_generator(self) -> IdGenerator:
        return self._self_id_generator

# @contextmanager
# @overload
# def start_as_current_span(
# self,
# name: str,
# context: Optional[Context] = None,
# kind: SpanKind = SpanKind.INTERNAL,
# attributes: Attributes = None,
# links: Optional["Sequence[Link]"] = (),
# start_time: Optional[int] = None,
# record_exception: bool = True,
# set_status_on_exception: bool = True,
# end_on_exit: bool = True,
# *,
# openinference_span_kind: Literal["chain"],
# ) -> Iterator[ChainSpan]: ...

# @contextmanager
# @overload
# def start_as_current_span(
# self,
# name: str,
# context: Optional[Context] = None,
# kind: SpanKind = SpanKind.INTERNAL,
# attributes: Attributes = None,
# links: Optional["Sequence[Link]"] = (),
# start_time: Optional[int] = None,
# record_exception: bool = True,
# set_status_on_exception: bool = True,
# end_on_exit: bool = True,
# *,
# openinference_span_kind: Literal["tool"],
# ) -> Iterator[ToolSpan]: ...

# @contextmanager
# @overload
# def start_as_current_span(
# self,
# name: str,
# context: Optional[Context] = None,
# kind: SpanKind = SpanKind.INTERNAL,
# attributes: Attributes = None,
# links: Optional["Sequence[Link]"] = (),
# start_time: Optional[int] = None,
# record_exception: bool = True,
# set_status_on_exception: bool = True,
# end_on_exit: bool = True,
# *,
# openinference_span_kind: Optional[OpenInferenceSpanKind] = None,
# ) -> Iterator[OpenInferenceSpan]: ...

    @contextmanager
    def start_as_current_span(
        self,