Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(groq): refactor groq, add groq tool call support #1133

Merged
merged 7 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import asyncio
import os

from groq import AsyncGroq, Groq
from groq.types.chat import ChatCompletionToolMessageParam
from phoenix.otel import register
cjunkin marked this conversation as resolved.
Show resolved Hide resolved

from openinference.instrumentation.groq import GroqInstrumentor


def test():
    """Run one synchronous Groq tool-call round trip.

    Forces the model to call the `get_weather` tool, feeds a canned tool
    result ("sunny") back into the conversation, and returns the model's
    follow-up chat completion.
    """
    groq_client = Groq(
        api_key=os.environ.get("GROQ_API_KEY"),
    )

    # JSON-schema description of the single tool the model is required to call.
    get_weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "finds the weather for a given city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'London'",
                    }
                },
                "required": ["city"],
            },
        },
    }

    conversation = [
        {"role": "system", "content": "Respond to the user's query using the correct tool."},
        {"role": "user", "content": "What's the weather like in San Francisco?"},
    ]
    first_reply = groq_client.chat.completions.create(
        model="mixtral-8x7b-32768",
        messages=conversation,
        temperature=0.0,
        tools=[get_weather_tool],
        tool_choice="required",
    )

    assistant_message = first_reply.choices[0].message
    assert (tool_calls := assistant_message.tool_calls)
    # Echo the assistant turn plus a fake tool result so the model can answer.
    conversation.append(assistant_message)
    conversation.append(
        ChatCompletionToolMessageParam(
            content="sunny", role="tool", tool_call_id=tool_calls[0].id
        ),
    )
    return groq_client.chat.completions.create(
        model="mixtral-8x7b-32768",
        messages=conversation,
    )


async def async_test():
    """Async twin of `test()`: the same tool-call round trip via AsyncGroq.

    Forces a `get_weather` tool call, appends a canned "sunny" tool result,
    and returns the follow-up chat completion.
    """
    groq_client = AsyncGroq(
        api_key=os.environ.get("GROQ_API_KEY"),
    )

    # JSON-schema description of the single tool the model is required to call.
    get_weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "finds the weather for a given city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'London'",
                    }
                },
                "required": ["city"],
            },
        },
    }

    conversation = [
        {"role": "system", "content": "Respond to the user's query using the correct tool."},
        {"role": "user", "content": "What's the weather like in San Francisco?"},
    ]
    first_reply = await groq_client.chat.completions.create(
        model="mixtral-8x7b-32768",
        messages=conversation,
        temperature=0.0,
        tools=[get_weather_tool],
        tool_choice="required",
    )

    assistant_message = first_reply.choices[0].message
    assert (tool_calls := assistant_message.tool_calls)
    # Echo the assistant turn plus a fake tool result so the model can answer.
    conversation.append(assistant_message)
    conversation.append(
        ChatCompletionToolMessageParam(
            content="sunny", role="tool", tool_call_id=tool_calls[0].id
        ),
    )
    return await groq_client.chat.completions.create(
        model="mixtral-8x7b-32768",
        messages=conversation,
    )


if __name__ == "__main__":
    # Register a Phoenix tracer and attach the Groq instrumentor before any
    # API calls are made, so both round trips below are traced.
    provider = register(project_name="groq_debug")
    GroqInstrumentor().instrument(tracer_provider=provider)

    sync_result = test()
    print("Response\n--------")
    print(sync_result)

    async_result = asyncio.run(async_test())
    print("\nAsync Response\n--------")
    print(async_result)
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from typing import Any, Collection

from opentelemetry import trace as trace_api
from opentelemetry.instrumentation.instrumentor import ( # type: ignore[attr-defined]
BaseInstrumentor,
)
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor # type: ignore
from wrapt import wrap_function_wrapper

from groq.resources.chat.completions import AsyncCompletions, Completions
Expand All @@ -17,6 +15,7 @@
from openinference.instrumentation.groq.version import __version__

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

_instruments = ("groq >= 0.9.0",)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
import logging
from typing import (
Any,
Iterable,
Iterator,
Mapping,
Tuple,
TypeVar,
)
from enum import Enum
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Tuple, TypeVar

from opentelemetry.util.types import AttributeValue

from groq.types.chat import ChatCompletionMessage, ChatCompletionMessageToolCall
from groq.types.chat.chat_completion_message_tool_call import Function
from openinference.instrumentation import safe_json_dumps
from openinference.instrumentation.groq._utils import _as_input_attributes, _io_value_and_type
from openinference.semconv.trace import (
MessageAttributes,
OpenInferenceSpanKindValues,
SpanAttributes,
ToolCallAttributes,
)

__all__ = ("_RequestAttributesExtractor",)
Expand Down Expand Up @@ -49,27 +46,115 @@ def get_extra_attributes_from_request(
if not isinstance(request_parameters, Mapping):
return
invocation_params = dict(request_parameters)
invocation_params.pop("messages", None)
invocation_params.pop("messages", None) # Remove LLM input messages
invocation_params.pop("functions", None)
invocation_params.pop("tools", None)

if isinstance((tools := invocation_params.pop("tools", None)), Iterable):
for i, tool in enumerate(tools):
yield f"llm.tools.{i}.tool.json_schema", safe_json_dumps(tool)

yield SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_params)

if (input_messages := request_parameters.get("messages")) and isinstance(
input_messages, Iterable
):
# Use reversed() to get the last message first. This is because OTEL has a default
# limit of 128 attributes per span, and flattening increases the number of
# attributes very quickly.
for index, input_message in reversed(list(enumerate(input_messages))):
if role := input_message.get("role"):
yield (
f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{MessageAttributes.MESSAGE_ROLE}",
role,
)
if content := input_message.get("content"):
# Use reversed() to get the last message first. This is because OTEL has a default
# limit of 128 attributes per span, and flattening increases the number of
# attributes very quickly.
for key, value in self._get_attributes_from_message_param(input_message):
yield f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{key}", value

def _get_attributes_from_message_param(
    self,
    message: Mapping[str, Any],
) -> Iterator[Tuple[str, AttributeValue]]:
    """Flatten one input chat message into OpenInference message attributes.

    Accepts either a plain mapping (the usual request dict) or a
    ``ChatCompletionMessage`` model (assistant turns echoed back into
    ``messages``), which is first normalized to a mapping. Yields
    ``(attribute_key, value)`` pairs; unknown shapes are silently skipped.
    """
    if not hasattr(message, "get"):
        # Assistant responses appended back into `messages` are model
        # objects, not dicts — normalize them so `.get()` works below.
        if isinstance(message, ChatCompletionMessage):
            message = self._cast_chat_completion_to_mapping(message)
        else:
            return
    if role := message.get("role"):
        yield (
            MessageAttributes.MESSAGE_ROLE,
            role.value if isinstance(role, Enum) else role,
        )

    if content := message.get("content"):
        yield (
            MessageAttributes.MESSAGE_CONTENT,
            content,
        )

    if name := message.get("name"):
        yield MessageAttributes.MESSAGE_NAME, name

    if tool_call_id := message.get("tool_call_id"):
        yield MessageAttributes.MESSAGE_TOOL_CALL_ID, tool_call_id

    # Deprecated by Groq in favor of tool_calls; still emitted when present.
    if (function_call := message.get("function_call")) and hasattr(function_call, "get"):
        if function_name := function_call.get("name"):
            yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, function_name
        if function_arguments := function_call.get("arguments"):
            yield (
                MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON,
                function_arguments,
            )

    # BUG FIX: the original wrote `(tool_calls := message.get("tool_calls"),)`
    # — the trailing comma built a one-element tuple, which is always truthy,
    # so the isinstance check ran against None and the guard never short-
    # circuited when "tool_calls" was absent.
    if (tool_calls := message.get("tool_calls")) and isinstance(tool_calls, Iterable):
        for index, tool_call in enumerate(tool_calls):
            if not hasattr(tool_call, "get"):
                continue
            if (tool_call_id := tool_call.get("id")) is not None:
                yield (
                    f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
                    f"{ToolCallAttributes.TOOL_CALL_ID}",
                    tool_call_id,
                )
            if (function := tool_call.get("function")) and hasattr(function, "get"):
                if name := function.get("name"):
                    yield (
                        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
                        f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
                        name,
                    )
                if arguments := function.get("arguments"):
                    yield (
                        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
                        f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
                        arguments,
                    )

def _cast_chat_completion_to_mapping(self, message: ChatCompletionMessage) -> Mapping[str, Any]:
    """Convert a ``ChatCompletionMessage`` pydantic model into a plain dict.

    Nested ``tool_calls`` entries and their ``function`` payloads are also
    converted to dicts so callers can use ``.get()`` uniformly. Best-effort:
    any failure is logged and an empty mapping is returned — instrumentation
    must never break the instrumented call.
    """
    try:
        casted_message = dict(message)
        if (tool_calls := casted_message.get("tool_calls")) and isinstance(
            tool_calls, Iterable
        ):
            casted_tool_calls: List[Dict[str, Any]] = []
            for tool_call in tool_calls:
                if isinstance(tool_call, ChatCompletionMessageToolCall):
                    tool_call_dict = dict(tool_call)

                    if (function := tool_call_dict.get("function")) and isinstance(
                        function, Function
                    ):
                        tool_call_dict["function"] = dict(function)

                    casted_tool_calls.append(tool_call_dict)
                else:
                    # Idiom fix: lazy %-style args instead of an f-string so the
                    # message is only formatted when DEBUG logging is enabled.
                    logger.debug("Skipping tool_call of unexpected type: %s", type(tool_call))

            casted_message["tool_calls"] = casted_tool_calls

        return casted_message

    except Exception:
        # logger.exception already records the active traceback, so the
        # original's embedded `{e}` was redundant.
        logger.exception("Failed to convert ChatCompletionMessage to mapping for %s", message)
        return {}


# NOTE(review): TypeVar bound to `type`, so it can only stand in for class
# objects — presumably used by a class decorator elsewhere in this module;
# confirm at its use site (not visible in this diff hunk).
T = TypeVar("T", bound=type)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import logging
from typing import Any, Iterable, Iterator, Mapping, Tuple

from opentelemetry.util.types import AttributeValue

from openinference.instrumentation.groq._utils import _as_output_attributes, _io_value_and_type
from openinference.semconv.trace import MessageAttributes, SpanAttributes, ToolCallAttributes

# Public surface of this module.
__all__ = ("_ResponseAttributesExtractor",)

logger = logging.getLogger(__name__)
# Library convention: a NullHandler keeps this logger silent unless the
# embedding application configures logging handlers itself.
logger.addHandler(logging.NullHandler())


class _ResponseAttributesExtractor:
    """Translate a Groq chat-completion response into OpenInference span attributes."""

    __slots__ = ()

    def get_attributes(self, response: Any) -> Iterator[Tuple[str, AttributeValue]]:
        """Yield the top-level output value/MIME-type attributes for the response."""
        yield from _as_output_attributes(_io_value_and_type(response))

    def get_extra_attributes(
        self,
        response: Any,
        request_parameters: Mapping[str, Any],
    ) -> Iterator[Tuple[str, AttributeValue]]:
        """Yield the detailed per-completion attributes (model, usage, choices)."""
        yield from self._get_attributes_from_chat_completion(
            completion=response,
            request_parameters=request_parameters,
        )

    def _get_attributes_from_chat_completion(
        self,
        completion: Any,
        request_parameters: Mapping[str, Any],
    ) -> Iterator[Tuple[str, AttributeValue]]:
        """Yield model name, token usage, and one flattened message per choice."""
        model_name = getattr(completion, "model", None)
        if model_name:
            yield SpanAttributes.LLM_MODEL_NAME, model_name
        token_usage = getattr(completion, "usage", None)
        if token_usage:
            yield from self._get_attributes_from_completion_usage(token_usage)
        choices = getattr(completion, "choices", None)
        if not (choices and isinstance(choices, Iterable)):
            return
        for choice in choices:
            choice_index = getattr(choice, "index", None)
            if choice_index is None:
                continue
            choice_message = getattr(choice, "message", None)
            if choice_message:
                for key, value in self._get_attributes_from_chat_completion_message(
                    choice_message
                ):
                    yield f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{choice_index}.{key}", value

    def _get_attributes_from_chat_completion_message(
        self,
        message: object,
    ) -> Iterator[Tuple[str, AttributeValue]]:
        """Yield role/content plus any (deprecated) function_call or tool_calls."""
        if msg_role := getattr(message, "role", None):
            yield MessageAttributes.MESSAGE_ROLE, msg_role
        if msg_content := getattr(message, "content", None):
            yield MessageAttributes.MESSAGE_CONTENT, msg_content
        if legacy_call := getattr(message, "function_call", None):
            # function_call is the deprecated predecessor of tool_calls.
            if legacy_name := getattr(legacy_call, "name", None):
                yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, legacy_name
            if legacy_args := getattr(legacy_call, "arguments", None):
                yield MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, legacy_args
        tool_calls = getattr(message, "tool_calls", None)
        if not (tool_calls and isinstance(tool_calls, Iterable)):
            return
        for position, tool_call in enumerate(tool_calls):
            call_id = getattr(tool_call, "id", None)
            if call_id is not None:
                # https://github.com/groq/groq-python/blob/fa2e13b5747d18aeb478700f1e5426af2fd087a1/src/groq/types/chat/chat_completion_tool_message_param.py#L17 # noqa: E501
                yield (
                    f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{position}."
                    f"{ToolCallAttributes.TOOL_CALL_ID}",
                    call_id,
                )
            called_function = getattr(tool_call, "function", None)
            if called_function:
                # https://github.com/groq/groq-python/blob/fa2e13b5747d18aeb478700f1e5426af2fd087a1/src/groq/types/chat/chat_completion_message_tool_call.py#L10 # noqa: E501
                if fn_name := getattr(called_function, "name", None):
                    yield (
                        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{position}."
                        f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
                        fn_name,
                    )
                if fn_args := getattr(called_function, "arguments", None):
                    yield (
                        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{position}."
                        f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
                        fn_args,
                    )

    def _get_attributes_from_completion_usage(
        self,
        usage: object,
    ) -> Iterator[Tuple[str, AttributeValue]]:
        """Yield token counts; each is emitted only when present (not None)."""
        for usage_field, semconv_key in (
            ("total_tokens", SpanAttributes.LLM_TOKEN_COUNT_TOTAL),
            ("prompt_tokens", SpanAttributes.LLM_TOKEN_COUNT_PROMPT),
            ("completion_tokens", SpanAttributes.LLM_TOKEN_COUNT_COMPLETION),
        ):
            count = getattr(usage, usage_field, None)
            if count is not None:
                yield semconv_key, count
Loading
Loading