groq tool support
cjunkin committed Dec 9, 2024
1 parent 4467562 commit 061daa2
Showing 9 changed files with 602 additions and 124 deletions.
@@ -0,0 +1,119 @@
import asyncio
import os

from groq import AsyncGroq, Groq
from groq.types.chat import ChatCompletionToolMessageParam
from phoenix.otel import register

from openinference.instrumentation.groq import GroqInstrumentor


def test():
    client = Groq(
        api_key=os.environ.get("GROQ_API_KEY"),
    )

    # Tool schema the model may call.
    weather_function = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "finds the weather for a given city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'London'",
                    }
                },
                "required": ["city"],
            },
        },
    }

    sys_prompt = "Respond to the user's query using the correct tool."
    user_msg = "What's the weather like in San Francisco?"

    messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_msg}]
    response = client.chat.completions.create(
        model="mixtral-8x7b-32768",
        messages=messages,
        temperature=0.0,
        tools=[weather_function],
        tool_choice="required",
    )

    message = response.choices[0].message
    assert (tool_calls := message.tool_calls)
    tool_call_id = tool_calls[0].id
    messages.append(message)
    # Send the tool's result back to the model as a tool-role message.
    messages.append(
        ChatCompletionToolMessageParam(content="sunny", role="tool", tool_call_id=tool_call_id),
    )
    final_response = client.chat.completions.create(
        model="mixtral-8x7b-32768",
        messages=messages,
    )
    return final_response


async def async_test():
    client = AsyncGroq(
        api_key=os.environ.get("GROQ_API_KEY"),
    )

    # Same flow as test(), using the async client.
    weather_function = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "finds the weather for a given city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'London'",
                    }
                },
                "required": ["city"],
            },
        },
    }

    sys_prompt = "Respond to the user's query using the correct tool."
    user_msg = "What's the weather like in San Francisco?"

    messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_msg}]
    response = await client.chat.completions.create(
        model="mixtral-8x7b-32768",
        messages=messages,
        temperature=0.0,
        tools=[weather_function],
        tool_choice="required",
    )

    message = response.choices[0].message
    assert (tool_calls := message.tool_calls)
    tool_call_id = tool_calls[0].id
    messages.append(message)
    messages.append(
        ChatCompletionToolMessageParam(content="sunny", role="tool", tool_call_id=tool_call_id),
    )
    final_response = await client.chat.completions.create(
        model="mixtral-8x7b-32768",
        messages=messages,
    )
    return final_response


if __name__ == "__main__":
    tracer_provider = register(project_name="groq_debug")
    GroqInstrumentor().instrument(tracer_provider=tracer_provider)

    response = test()
    print("Response\n--------")
    print(response)

    async_response = asyncio.run(async_test())
    print("\nAsync Response\n--------")
    print(async_response)
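
To check what the instrumentor actually records without running a Phoenix server, one option is to attach the OpenTelemetry SDK's in-memory exporter and print the flattened span attributes. This is a minimal sketch, assuming opentelemetry-sdk is installed and reusing the test() function above:

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from openinference.instrumentation.groq import GroqInstrumentor

# Collect spans in memory instead of exporting them to a collector.
exporter = InMemorySpanExporter()
tracer_provider = TracerProvider()
tracer_provider.add_span_processor(SimpleSpanProcessor(exporter))
GroqInstrumentor().instrument(tracer_provider=tracer_provider)

test()  # the synchronous example above
for span in exporter.get_finished_spans():
    for key, value in (span.attributes or {}).items():
        if "tool" in key:  # show only the new tool-related attributes
            print(key, "=", value)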

This file was deleted.

@@ -3,8 +3,8 @@
 from typing import Any, Collection
 
 from opentelemetry import trace as trace_api
-from opentelemetry.instrumentation.instrumentor import (  # type: ignore[attr-defined]
-    BaseInstrumentor,
+from opentelemetry.instrumentation.instrumentor import (
+    BaseInstrumentor,  # type: ignore[attr-defined]
 )
 from wrapt import wrap_function_wrapper
@@ -17,6 +17,7 @@
 from openinference.instrumentation.groq.version import __version__
 
 logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
 
 _instruments = ("groq >= 0.9.0",)

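Attaching a NullHandler here is the standard library-logging pattern: it silences the "no handlers could be found" warning in applications that never configure logging, while records still propagate to whatever handlers the host application does install.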
@@ -1,4 +1,5 @@
 import logging
+from enum import Enum
 from typing import Any, Iterable, Iterator, Mapping, Tuple, TypeVar
 
 from opentelemetry.util.types import AttributeValue
@@ -9,6 +10,7 @@
     MessageAttributes,
     OpenInferenceSpanKindValues,
     SpanAttributes,
+    ToolCallAttributes,
 )
 
 __all__ = ("_RequestAttributesExtractor",)
@@ -43,26 +45,73 @@ def get_extra_attributes_from_request(
             return
         invocation_params = dict(request_parameters)
         invocation_params.pop("messages", None)  # Remove LLM input messages
+        invocation_params.pop("functions", None)
+
+        if isinstance((tools := invocation_params.pop("tools", None)), Iterable):
+            for i, tool in enumerate(tools):
+                yield f"llm.tools.{i}.tool.json_schema", safe_json_dumps(tool)
+
         yield SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_params)
 
         if (input_messages := request_parameters.get("messages")) and isinstance(
             input_messages, Iterable
         ):
             # Use reversed() to get the last message first. This is because OTEL has a default
             # limit of 128 attributes per span, and flattening increases the number of
             # attributes very quickly.
             for index, input_message in reversed(list(enumerate(input_messages))):
-                if role := input_message.get("role"):
-                    yield (
-                        f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{MessageAttributes.MESSAGE_ROLE}",
-                        role,
-                    )
-                if content := input_message.get("content"):
-                    yield (
-                        f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{MessageAttributes.MESSAGE_CONTENT}",
-                        content,
-                    )
+                for key, value in self._get_attributes_from_message_param(input_message):
+                    yield f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{key}", value
+
+    def _get_attributes_from_message_param(
+        self,
+        message: Mapping[str, Any],
+    ) -> Iterator[Tuple[str, AttributeValue]]:
+        if not hasattr(message, "get"):
+            return
+        if role := message.get("role"):
+            yield (
+                MessageAttributes.MESSAGE_ROLE,
+                role.value if isinstance(role, Enum) else role,
+            )
+        if content := message.get("content"):
+            yield (
+                MessageAttributes.MESSAGE_CONTENT,
+                content,
+            )
+        if name := message.get("name"):
+            yield MessageAttributes.MESSAGE_NAME, name
+        if (function_call := message.get("function_call")) and hasattr(function_call, "get"):
+            if function_name := function_call.get("name"):
+                yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, function_name
+            if function_arguments := function_call.get("arguments"):
+                yield (
+                    MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON,
+                    function_arguments,
+                )
+        if (tool_calls := message.get("tool_calls")) and isinstance(tool_calls, Iterable):
+            for index, tool_call in enumerate(tool_calls):
+                if not hasattr(tool_call, "get"):
+                    continue
+                if (tool_call_id := tool_call.get("id")) is not None:
+                    yield (
+                        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+                        f"{ToolCallAttributes.TOOL_CALL_ID}",
+                        tool_call_id,
+                    )
+                if (function := tool_call.get("function")) and hasattr(function, "get"):
+                    if name := function.get("name"):
+                        yield (
+                            f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+                            f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
+                            name,
+                        )
+                    if arguments := function.get("arguments"):
+                        yield (
+                            f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+                            f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
+                            arguments,
+                        )
 
 
 T = TypeVar("T", bound=type)
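
To make the flattening concrete, here is a hypothetical dump of what get_extra_attributes_from_request yields for the tool-call request in the test script above. The key names follow the openinference semantic conventions used in the code; the JSON values are abbreviated and the exact serialization may differ. A single message carrying tool calls contributes several attributes on its own, which is why messages are walked last-first under OTEL's default cap of 128 attributes per span.

# Hypothetical attribute dump (keys per the extractor above, values abbreviated):
#
# llm.tools.0.tool.json_schema           '{"type": "function", "function": {"name": "get_weather", ...}}'
# llm.invocation_parameters              '{"model": "mixtral-8x7b-32768", "temperature": 0.0, "tool_choice": "required"}'
# llm.input_messages.1.message.role      "user"
# llm.input_messages.1.message.content   "What's the weather like in San Francisco?"
# llm.input_messages.0.message.role      "system"
# llm.input_messages.0.message.content   "Respond to the user's query using the correct tool."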
@@ -0,0 +1,98 @@
import logging
from typing import Any, Iterable, Iterator, Mapping, Tuple

from opentelemetry.util.types import AttributeValue

from openinference.instrumentation.groq._utils import _as_output_attributes, _io_value_and_type
from openinference.semconv.trace import MessageAttributes, SpanAttributes, ToolCallAttributes

__all__ = ("_ResponseAttributesExtractor",)

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class _ResponseAttributesExtractor:
    __slots__ = ()

    def get_attributes(self, response: Any) -> Iterator[Tuple[str, AttributeValue]]:
        yield from _as_output_attributes(
            _io_value_and_type(response),
        )

    def get_extra_attributes(
        self,
        response: Any,
        request_parameters: Mapping[str, Any],
    ) -> Iterator[Tuple[str, AttributeValue]]:
        yield from self._get_attributes_from_chat_completion(
            completion=response,
            request_parameters=request_parameters,
        )

    def _get_attributes_from_chat_completion(
        self,
        completion: Any,
        request_parameters: Mapping[str, Any],
    ) -> Iterator[Tuple[str, AttributeValue]]:
        if model := getattr(completion, "model", None):
            yield SpanAttributes.LLM_MODEL_NAME, model
        if usage := getattr(completion, "usage", None):
            yield from self._get_attributes_from_completion_usage(usage)
        if (choices := getattr(completion, "choices", None)) and isinstance(choices, Iterable):
            for choice in choices:
                if (index := getattr(choice, "index", None)) is None:
                    continue
                if message := getattr(choice, "message", None):
                    for key, value in self._get_attributes_from_chat_completion_message(message):
                        yield f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{index}.{key}", value

    def _get_attributes_from_chat_completion_message(
        self,
        message: object,
    ) -> Iterator[Tuple[str, AttributeValue]]:
        if role := getattr(message, "role", None):
            yield MessageAttributes.MESSAGE_ROLE, role
        if content := getattr(message, "content", None):
            yield MessageAttributes.MESSAGE_CONTENT, content
        if function_call := getattr(message, "function_call", None):
            if name := getattr(function_call, "name", None):
                yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, name
            if arguments := getattr(function_call, "arguments", None):
                yield MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, arguments
        if (tool_calls := getattr(message, "tool_calls", None)) and isinstance(
            tool_calls, Iterable
        ):
            for index, tool_call in enumerate(tool_calls):
                if (tool_call_id := getattr(tool_call, "id", None)) is not None:
                    yield (
                        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
                        f"{ToolCallAttributes.TOOL_CALL_ID}",
                        tool_call_id,
                    )
                if function := getattr(tool_call, "function", None):
                    if name := getattr(function, "name", None):
                        yield (
                            (
                                f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
                                f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}"
                            ),
                            name,
                        )
                    if arguments := getattr(function, "arguments", None):
                        yield (
                            f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
                            f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
                            arguments,
                        )

    def _get_attributes_from_completion_usage(
        self,
        usage: object,
    ) -> Iterator[Tuple[str, AttributeValue]]:
        if (total_tokens := getattr(usage, "total_tokens", None)) is not None:
            yield SpanAttributes.LLM_TOKEN_COUNT_TOTAL, total_tokens
        if (prompt_tokens := getattr(usage, "prompt_tokens", None)) is not None:
            yield SpanAttributes.LLM_TOKEN_COUNT_PROMPT, prompt_tokens
        if (completion_tokens := getattr(usage, "completion_tokens", None)) is not None:
            yield SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, completion_tokens
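
The response extractor mirrors the request extractor but reads fields with getattr instead of Mapping.get, since chat completions arrive as typed model objects rather than dicts. For the first completion in the test script, where tool_choice="required" forces a tool call, the extracted span attributes would look roughly like the following hypothetical dump (the tool-call id and arguments are illustrative):

# Hypothetical attribute dump for a tool-call completion:
#
# llm.model_name                                                           "mixtral-8x7b-32768"
# llm.token_count.prompt / llm.token_count.completion / llm.token_count.total   integer token counts from usage
# llm.output_messages.0.message.role                                       "assistant"
# llm.output_messages.0.message.tool_calls.0.tool_call.id                  "call_abc123"
# llm.output_messages.0.message.tool_calls.0.tool_call.function.name       "get_weather"
# llm.output_messages.0.message.tool_calls.0.tool_call.function.arguments  '{"city": "San Francisco"}'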
