Commit 9029fa8

wip

RogerHYang committed Jan 11, 2024
1 parent 76a3777

Showing 15 changed files with 662 additions and 166 deletions.
@@ -1,6 +1,3 @@
-"""
-Phoenix collector should be running in the background.
-"""
 import asyncio
 
 import openai
@@ -1,6 +1,3 @@
-"""
-Phoenix collector should be running in the background.
-"""
 import contextvars
 import inspect
 import logging
@@ -1,6 +1,3 @@
-"""
-Phoenix collector should be running in the background.
-"""
 import asyncio
 import inspect
 import logging
@@ -1,6 +1,3 @@
-"""
-Phoenix collector should be running in the background.
-"""
 import asyncio
 import inspect
 import logging
@@ -32,8 +32,13 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
+instruments = [
+  "openai >= 1.0.0",
+]
 test = [
   "openai == 1.0.0",
+  "opentelemetry-sdk",
+  "respx",
 ]
 
 [project.urls]
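The new "instruments" extra records the same version pin that the instrumentor advertises at runtime through instrumentation_dependencies(). As a standalone sketch of how such a pin can be checked (the dependencies_satisfied helper below is hypothetical, not part of this commit; it assumes _instruments matches the extra):

from importlib.metadata import version

from packaging.requirements import Requirement

_instruments = ("openai >= 1.0.0",)  # same pin as the "instruments" extra

def dependencies_satisfied() -> bool:
    # True only if every instrumented package is installed at a
    # version inside the pinned range.
    for spec in _instruments:
        requirement = Requirement(spec)
        if version(requirement.name) not in requirement.specifier:
            return False
    return True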
@@ -1,4 +1,5 @@
 import logging
+from importlib import import_module
 from typing import Any, Collection
 
 from openinference.instrumentation.openai._request import (
@@ -19,34 +20,36 @@
 
 class OpenAIInstrumentor(BaseInstrumentor):  # type: ignore
     """
-    An instrumentor for openai.OpenAI.request and openai.AsyncOpenAI.request
+    An instrumentor for openai
     """
 
+    __slots__ = (
+        "_original_request",
+        "_original_async_request",
+    )
+
     def instrumentation_dependencies(self) -> Collection[str]:
         return _instruments
 
     def _instrument(self, **kwargs: Any) -> None:
-        if (include_extra_attributes := kwargs.get("include_extra_attributes")) is None:
-            include_extra_attributes = True
         if not (tracer_provider := kwargs.get("tracer_provider")):
             tracer_provider = trace_api.get_tracer_provider()
         tracer = trace_api.get_tracer(__name__, __version__, tracer_provider)
+        openai = import_module(_MODULE)
+        self._original_request = openai.OpenAI.request
+        self._original_async_request = openai.AsyncOpenAI.request
         wrap_function_wrapper(
             module=_MODULE,
             name="OpenAI.request",
-            wrapper=_Request(
-                tracer=tracer,
-                include_extra_attributes=include_extra_attributes,
-            ),
+            wrapper=_Request(tracer=tracer),
         )
         wrap_function_wrapper(
             module=_MODULE,
             name="AsyncOpenAI.request",
-            wrapper=_AsyncRequest(
-                tracer=tracer,
-                include_extra_attributes=include_extra_attributes,
-            ),
+            wrapper=_AsyncRequest(tracer=tracer),
         )
 
     def _uninstrument(self, **kwargs: Any) -> None:
-        pass
+        openai = import_module(_MODULE)
+        openai.OpenAI.request = self._original_request
+        openai.AsyncOpenAI.request = self._original_async_request
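A minimal usage sketch for the new save/restore behavior (the console exporter wiring here is illustrative, not something this commit adds):

from openinference.instrumentation.openai import OpenAIInstrumentor
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

tracer_provider = TracerProvider()
tracer_provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))

# _instrument() now stores the original request methods before wrapping them.
OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)
# ... calls through openai.OpenAI / openai.AsyncOpenAI are traced here ...
# _uninstrument() puts the saved, unwrapped methods back.
OpenAIInstrumentor().uninstrument()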
@@ -25,19 +25,19 @@
 
 def _get_extra_attributes_from_request(
     cast_to: type,
-    request_options: Mapping[str, Any],
+    request_parameters: Mapping[str, Any],
 ) -> Iterator[Tuple[str, AttributeValue]]:
-    if not isinstance(request_options, Mapping):
+    if not isinstance(request_parameters, Mapping):
         return
     if cast_to is ChatCompletion:
-        yield from _get_attributes_from_chat_completion_create_param(request_options)
+        yield from _get_attributes_from_chat_completion_create_param(request_parameters)
     elif cast_to is CreateEmbeddingResponse:
-        yield from _get_attributes_from_embedding_create_param(request_options)
+        yield from _get_attributes_from_embedding_create_param(request_parameters)
     elif cast_to is Completion:
-        yield from _get_attributes_from_completion_create_param(request_options)
+        yield from _get_attributes_from_completion_create_param(request_parameters)
     else:
         try:
-            yield SpanAttributes.LLM_INVOCATION_PARAMETERS, json.dumps(request_options)
+            yield SpanAttributes.LLM_INVOCATION_PARAMETERS, json.dumps(request_parameters)
         except Exception:
             logger.exception("Failed to serialize request options")
 
@@ -55,7 +55,9 @@ def _get_attributes_from_chat_completion_create_param(
     invocation_params.pop("tools", None)
     yield SpanAttributes.LLM_INVOCATION_PARAMETERS, json.dumps(invocation_params)
     if (input_messages := params.get("messages")) and isinstance(input_messages, Iterable):
-        for index, input_message in enumerate(input_messages):
+        # Use reversed() to get the last message first. This is because OTEL has a default limit of
+        # 128 attributes per span, and flattening increases the number of attributes very quickly.
+        for index, input_message in reversed(list(enumerate(input_messages))):
             for key, value in _get_attributes_from_message_param(input_message):
                 yield f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{key}", value
 
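The reversed() change is easy to see in isolation: enumerate() runs before reversed(), so the original indices are preserved and only the emission order flips. The most recent messages therefore become attributes first and are the last to be dropped if the span hits the attribute cap. A standalone illustration:

messages = [{"role": "user", "content": f"message {i}"} for i in range(3)]
for index, message in reversed(list(enumerate(messages))):
    # Indices still match the original positions: 2, 1, 0.
    print(index, message["content"])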
@@ -2,7 +2,7 @@
 import logging
 from functools import singledispatch
 from importlib import import_module
-from types import MappingProxyType, ModuleType
+from types import ModuleType
 from typing import (
     Any,
     Iterable,
@@ -43,16 +43,16 @@
 @singledispatch
 def _get_extra_attributes_from_response(
     response: Any,
-    request_options: Mapping[str, Any] = MappingProxyType({}),
+    request_parameters: Mapping[str, Any],
 ) -> Iterator[Tuple[str, AttributeValue]]:
-    # this is a fallback (for singledispatch)
+    # this is a fallback for @singledispatch
     yield from ()
 
 
 @_get_extra_attributes_from_response.register
 def _(
     completion: ChatCompletion,
-    request_options: Mapping[str, Any] = MappingProxyType({}),
+    request_parameters: Mapping[str, Any],
 ) -> Iterator[Tuple[str, AttributeValue]]:
     # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion.py#L40 # noqa: E501
     if model := getattr(completion, "model", None):
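For reference, a self-contained sketch of the @singledispatch pattern used above: the unregistered fallback yields nothing, and a handler is chosen by the runtime type of the first argument. (FakeChatCompletion and the attribute key are made up for illustration.)

from dataclasses import dataclass
from functools import singledispatch
from typing import Any, Iterator, Mapping, Tuple

@dataclass
class FakeChatCompletion:
    model: str

@singledispatch
def extract(response: Any, request_parameters: Mapping[str, Any]) -> Iterator[Tuple[str, str]]:
    # Fallback: unknown response types contribute no attributes.
    yield from ()

@extract.register
def _(response: FakeChatCompletion, request_parameters: Mapping[str, Any]) -> Iterator[Tuple[str, str]]:
    yield "llm.model_name", response.model

print(dict(extract(FakeChatCompletion("gpt-4"), {})))  # {'llm.model_name': 'gpt-4'}
print(dict(extract(object(), {})))  # {}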
@@ -71,14 +71,14 @@ def _(
 @_get_extra_attributes_from_response.register
 def _(
     completion: Completion,
-    request_options: Mapping[str, Any] = MappingProxyType({}),
+    request_parameters: Mapping[str, Any],
 ) -> Iterator[Tuple[str, AttributeValue]]:
     # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/completion.py#L13 # noqa: E501
     if model := getattr(completion, "model", None):
         yield SpanAttributes.LLM_MODEL_NAME, model
     if usage := getattr(completion, "usage", None):
         yield from _get_attributes_from_completion_usage(usage)
-    if model_prompt := request_options.get("prompt"):
+    if model_prompt := request_parameters.get("prompt"):
         # prompt: Required[Union[str, List[str], List[int], List[List[int]], None]]
         # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/completion_create_params.py#L38 # noqa: E501
         # FIXME: tokens (List[int], List[List[int]]) can't be decoded reliably because model
@@ -90,7 +90,7 @@ def _(
 @_get_extra_attributes_from_response.register
 def _(
     response: CreateEmbeddingResponse,
-    request_options: Mapping[str, Any] = MappingProxyType({}),
+    request_parameters: Mapping[str, Any],
 ) -> Iterator[Tuple[str, AttributeValue]]:
     # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/create_embedding_response.py#L20 # noqa: E501
     if usage := getattr(response, "usage", None):
@@ -104,7 +104,7 @@ def _(
                 continue
             for key, value in _get_attributes_from_embedding(embedding):
                 yield f"{SpanAttributes.EMBEDDING_EMBEDDINGS}.{index}.{key}", value
-    embedding_input = request_options.get("input")
+    embedding_input = request_parameters.get("input")
     for index, text in enumerate(_get_texts(embedding_input, model)):
         # input: Required[Union[str, List[str], List[int], List[List[int]]]]
         # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/embedding_create_params.py#L12 # noqa: E501
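The embedding loop flattens each vector and each input text into indexed span attributes, the same pattern as the chat messages above. A rough illustration (the "embedding.embeddings.<index>.embedding.text" key is assumed from the SpanAttributes/EmbeddingAttributes naming; verify against the semconv package):

texts = ["first document", "second document"]
attributes = {
    f"embedding.embeddings.{index}.embedding.text": text
    for index, text in enumerate(texts)
}
print(attributes)
# Two inputs already produce two attributes; vectors add more per index.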
(Diff for the remaining changed files not shown.)