From 061daa216ebe65a77ba16b1ad1937f49dda3535a Mon Sep 17 00:00:00 2001
From: Chris Park
Date: Mon, 9 Dec 2024 16:45:30 -0500
Subject: [PATCH] feat(groq): add tool call support

Extract tool definitions and tool-call messages from requests, move
response attribute extraction into a dedicated
_ResponseAttributesExtractor, and cover the new attributes with tests.

---
 .../examples/chat_completions_with_tool.py    | 119 +++++++++
 .../examples/tool_call.py                     |  41 ---
 .../instrumentation/groq/__init__.py          |   5 +-
 .../groq/_request_attributes_extractor.py     |  71 ++++-
 .../groq/_response_attributes_extractor.py    |  98 +++++++
 .../instrumentation/groq/_utils.py            |  53 +++-
 .../instrumentation/groq/_wrappers.py         |  88 +++----
 .../tests/test_instrumentor.py                |   6 +-
 .../tests/test_tool_calls.py                  | 245 ++++++++++++++++++
 9 files changed, 602 insertions(+), 124 deletions(-)
 create mode 100644 python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py
 delete mode 100644 python/instrumentation/openinference-instrumentation-groq/examples/tool_call.py
 create mode 100644 python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py
 create mode 100644 python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py

diff --git a/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py b/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py
new file mode 100644
index 000000000..5ab79754d
--- /dev/null
+++ b/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py
@@ -0,0 +1,119 @@
+import asyncio
+import os
+
+from groq import AsyncGroq, Groq
+from groq.types.chat import ChatCompletionToolMessageParam
+from phoenix.otel import register
+
+from openinference.instrumentation.groq import GroqInstrumentor
+
+
+def test():
+    client = Groq(
+        api_key=os.environ.get("GROQ_API_KEY"),
+    )
+
+    weather_function = {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "finds the weather for a given city",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string",
+                        "description": "The city to find the weather for, e.g. 'London'",
+                    }
+                },
+                "required": ["city"],
+            },
+        },
+    }
+
+    sys_prompt = "Respond to the user's query using the correct tool."
+    user_msg = "What's the weather like in San Francisco?"
+
+    messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_msg}]
+    response = client.chat.completions.create(
+        model="mixtral-8x7b-32768",
+        messages=messages,
+        temperature=0.0,
+        tools=[weather_function],
+        tool_choice="required",
+    )
+
+    message = response.choices[0].message
+    assert (tool_calls := message.tool_calls)
+    tool_call_id = tool_calls[0].id
+    messages.append(message)
+    # Answer the tool call with a tool message that echoes the call's id.
+    messages.append(
+        ChatCompletionToolMessageParam(content="sunny", role="tool", tool_call_id=tool_call_id),
+    )
+    final_response = client.chat.completions.create(
+        model="mixtral-8x7b-32768",
+        messages=messages,
+    )
+    return final_response
+
+
+async def async_test():
+    client = AsyncGroq(
+        api_key=os.environ.get("GROQ_API_KEY"),
+    )
+
+    weather_function = {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "finds the weather for a given city",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string",
+                        "description": "The city to find the weather for, e.g. 'London'",
+                    }
+                },
+                "required": ["city"],
+            },
+        },
+    }
+
+    sys_prompt = "Respond to the user's query using the correct tool."
+    user_msg = "What's the weather like in San Francisco?"
+
+    messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_msg}]
+    response = await client.chat.completions.create(
+        model="mixtral-8x7b-32768",
+        messages=messages,
+        temperature=0.0,
+        tools=[weather_function],
+        tool_choice="required",
+    )
+
+    message = response.choices[0].message
+    assert (tool_calls := message.tool_calls)
+    tool_call_id = tool_calls[0].id
+    messages.append(message)
+    # Answer the tool call with a tool message that echoes the call's id.
+    messages.append(
+        ChatCompletionToolMessageParam(content="sunny", role="tool", tool_call_id=tool_call_id),
+    )
+    final_response = await client.chat.completions.create(
+        model="mixtral-8x7b-32768",
+        messages=messages,
+    )
+    return final_response
+
+
+if __name__ == "__main__":
+    tracer_provider = register(project_name="groq_debug")
+    GroqInstrumentor().instrument(tracer_provider=tracer_provider)
+
+    response = test()
+    print("Response\n--------")
+    print(response)
+
+    async_response = asyncio.run(async_test())
+    print("\nAsync Response\n--------")
+    print(async_response)
diff --git a/python/instrumentation/openinference-instrumentation-groq/examples/tool_call.py b/python/instrumentation/openinference-instrumentation-groq/examples/tool_call.py
deleted file mode 100644
index a6d67b252..000000000
--- a/python/instrumentation/openinference-instrumentation-groq/examples/tool_call.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import os
-
-from groq import Groq
-from phoenix.otel import register
-
-from openinference.instrumentation.groq import GroqInstrumentor
-
-
-def test():
-    tracer_provider = register(project_name="groq_debug")
-    GroqInstrumentor().instrument(tracer_provider=tracer_provider)
-
-    client = Groq(
-        api_key=os.environ.get("GROQ_API_KEY"),
-    )
-
-    hello_world = {
-        "type": "function",
-        "function": {
-            "name": "hello_world",
-            "description": ("Print 'Hello world!'"),
-            "parameters": {"input": "ex"},
-        },
-    }
-
-    prompt = "Be a helpful assistant"
-    msg = "say hello world"
-
-    chat = client.chat.completions.create(
-        model="mixtral-8x7b-32768",
-        messages=[{"role": "system", "content": prompt}, {"role": "user", "content": msg}],
-        temperature=0.0,
-        tools=[hello_world],
-        tool_choice="required",
-    )
-    return chat
-
-
-if __name__ == "__main__":
-    response = test()
-    print(response)
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/__init__.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/__init__.py
index 8ea2b2530..858c9aa70 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/__init__.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/__init__.py
@@ -3,8 +3,8 @@
 from typing import Any, Collection
 
 from opentelemetry import trace as trace_api
-from opentelemetry.instrumentation.instrumentor import (  # type: ignore[attr-defined]
-    BaseInstrumentor,
+from opentelemetry.instrumentation.instrumentor import (
+    BaseInstrumentor,  # type: ignore[attr-defined]
 )
 from wrapt import wrap_function_wrapper
 
@@ -17,6 +17,7 @@
 from openinference.instrumentation.groq.version import __version__
 
 logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
 
 _instruments = ("groq >= 0.9.0",)
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py
index b90f4955f..3d677cca2 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py
@@ -1,4 +1,5 @@
 import logging
+from enum import Enum
 from typing import Any, Iterable, Iterator, Mapping, Tuple, TypeVar
 
 from opentelemetry.util.types import AttributeValue
@@ -9,6 +10,7 @@
     MessageAttributes,
     OpenInferenceSpanKindValues,
     SpanAttributes,
+    ToolCallAttributes,
 )
 
 __all__ = ("_RequestAttributesExtractor",)
@@ -43,26 +45,73 @@ def get_extra_attributes_from_request(
             return
         invocation_params = dict(request_parameters)
         invocation_params.pop("messages", None)  # Remove LLM input messages
+        invocation_params.pop("functions", None)
+
+        if isinstance((tools := invocation_params.pop("tools", None)), Iterable):
+            for i, tool in enumerate(tools):
+                yield f"llm.tools.{i}.tool.json_schema", safe_json_dumps(tool)
+
         yield SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_params)
 
         if (input_messages := request_parameters.get("messages")) and isinstance(
             input_messages, Iterable
         ):
             # Use reversed() to get the last message first. This is because OTEL has a default
             # limit of 128 attributes per span, and flattening increases the number of
             # attributes very quickly.
             for index, input_message in reversed(list(enumerate(input_messages))):
-                if role := input_message.get("role"):
-                    yield (
-                        f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{MessageAttributes.MESSAGE_ROLE}",
-                        role,
-                    )
-                if content := input_message.get("content"):
-                    yield (
-                        f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{MessageAttributes.MESSAGE_CONTENT}",
-                        content,
-                    )
+                for key, value in self._get_attributes_from_message_param(input_message):
+                    yield f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{key}", value
+
+    def _get_attributes_from_message_param(
+        self,
+        message: Mapping[str, Any],
+    ) -> Iterator[Tuple[str, AttributeValue]]:
+        if not hasattr(message, "get"):
+            return
+        if role := message.get("role"):
+            yield (
+                MessageAttributes.MESSAGE_ROLE,
+                role.value if isinstance(role, Enum) else role,
+            )
+        if content := message.get("content"):
+            yield (
+                MessageAttributes.MESSAGE_CONTENT,
+                content,
+            )
+        if name := message.get("name"):
+            yield MessageAttributes.MESSAGE_NAME, name
+        if tool_call_id := message.get("tool_call_id"):
+            # Tool (role="tool") messages carry the id of the call they answer.
+            yield MessageAttributes.MESSAGE_TOOL_CALL_ID, tool_call_id
+        if (function_call := message.get("function_call")) and hasattr(function_call, "get"):
+            if function_name := function_call.get("name"):
+                yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, function_name
+            if function_arguments := function_call.get("arguments"):
+                yield (
+                    MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON,
+                    function_arguments,
+                )
+        if (tool_calls := message.get("tool_calls")) and isinstance(tool_calls, Iterable):
+            for index, tool_call in enumerate(tool_calls):
+                if not hasattr(tool_call, "get"):
+                    continue
+                if (tool_call_id := tool_call.get("id")) is not None:
+                    yield (
+                        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+                        f"{ToolCallAttributes.TOOL_CALL_ID}",
+                        tool_call_id,
+                    )
+                if (function := tool_call.get("function")) and hasattr(function, "get"):
+                    if name := function.get("name"):
+                        yield (
+                            f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+ f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", + name, + ) + if arguments := function.get("arguments"): + yield ( + f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." + f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", + arguments, + ) T = TypeVar("T", bound=type) diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py new file mode 100644 index 000000000..e9599314e --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py @@ -0,0 +1,98 @@ +import logging +from typing import Any, Iterable, Iterator, Mapping, Tuple + +from opentelemetry.util.types import AttributeValue + +from openinference.instrumentation.groq._utils import _as_output_attributes, _io_value_and_type +from openinference.semconv.trace import MessageAttributes, SpanAttributes, ToolCallAttributes + +__all__ = ("_ResponseAttributesExtractor",) + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class _ResponseAttributesExtractor: + __slots__ = () + + def get_attributes(self, response: Any) -> Iterator[Tuple[str, AttributeValue]]: + yield from _as_output_attributes( + _io_value_and_type(response), + ) + + def get_extra_attributes( + self, + response: Any, + request_parameters: Mapping[str, Any], + ) -> Iterator[Tuple[str, AttributeValue]]: + yield from self._get_attributes_from_chat_completion( + completion=response, + request_parameters=request_parameters, + ) + + def _get_attributes_from_chat_completion( + self, + completion: Any, + request_parameters: Mapping[str, Any], + ) -> Iterator[Tuple[str, AttributeValue]]: + if model := getattr(completion, "model", None): + yield SpanAttributes.LLM_MODEL_NAME, model + if usage := getattr(completion, "usage", None): + yield from self._get_attributes_from_completion_usage(usage) + if (choices := getattr(completion, "choices", None)) and isinstance(choices, Iterable): + for choice in choices: + if (index := getattr(choice, "index", None)) is None: + continue + if message := getattr(choice, "message", None): + for key, value in self._get_attributes_from_chat_completion_message(message): + yield f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{index}.{key}", value + + def _get_attributes_from_chat_completion_message( + self, + message: object, + ) -> Iterator[Tuple[str, AttributeValue]]: + if role := getattr(message, "role", None): + yield MessageAttributes.MESSAGE_ROLE, role + if content := getattr(message, "content", None): + yield MessageAttributes.MESSAGE_CONTENT, content + if function_call := getattr(message, "function_call", None): + if name := getattr(function_call, "name", None): + yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, name + if arguments := getattr(function_call, "arguments", None): + yield MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, arguments + if (tool_calls := getattr(message, "tool_calls", None)) and isinstance( + tool_calls, Iterable + ): + for index, tool_call in enumerate(tool_calls): + if (tool_call_id := getattr(tool_call, "id", None)) is not None: + yield ( + f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." 
+ f"{ToolCallAttributes.TOOL_CALL_ID}", + tool_call_id, + ) + if function := getattr(tool_call, "function", None): + if name := getattr(function, "name", None): + yield ( + ( + f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." + f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}" + ), + name, + ) + if arguments := getattr(function, "arguments", None): + yield ( + f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." + f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", + arguments, + ) + + def _get_attributes_from_completion_usage( + self, + usage: object, + ) -> Iterator[Tuple[str, AttributeValue]]: + if (total_tokens := getattr(usage, "total_tokens", None)) is not None: + yield SpanAttributes.LLM_TOKEN_COUNT_TOTAL, total_tokens + if (prompt_tokens := getattr(usage, "prompt_tokens", None)) is not None: + yield SpanAttributes.LLM_TOKEN_COUNT_PROMPT, prompt_tokens + if (completion_tokens := getattr(usage, "completion_tokens", None)) is not None: + yield SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, completion_tokens diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_utils.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_utils.py index 468c47701..39b9e825e 100644 --- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_utils.py +++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_utils.py @@ -1,18 +1,12 @@ import logging import warnings -from typing import ( - Any, - Iterator, - Mapping, - NamedTuple, - Optional, - Sequence, - Tuple, -) - -from opentelemetry.util.types import AttributeValue +from typing import Any, Iterable, Iterator, Mapping, NamedTuple, Optional, Sequence, Tuple + +from opentelemetry import trace as trace_api +from opentelemetry.util.types import Attributes, AttributeValue from openinference.instrumentation import safe_json_dumps +from openinference.instrumentation.groq._with_span import _WithSpan from openinference.semconv.trace import OpenInferenceMimeTypeValues, SpanAttributes logger = logging.getLogger(__name__) @@ -55,3 +49,40 @@ def _as_input_attributes( # It's assumed to be TEXT by default, so we can skip to save one attribute. if value_and_type.type is not OpenInferenceMimeTypeValues.TEXT: yield SpanAttributes.INPUT_MIME_TYPE, value_and_type.type.value + + +def _as_output_attributes( + value_and_type: Optional[_ValueAndType], +) -> Iterator[Tuple[str, AttributeValue]]: + if not value_and_type: + return + yield SpanAttributes.OUTPUT_VALUE, value_and_type.value + # It's assumed to be TEXT by default, so we can skip to save one attribute. 
+    if value_and_type.type is not OpenInferenceMimeTypeValues.TEXT:
+        yield SpanAttributes.OUTPUT_MIME_TYPE, value_and_type.type.value
+
+
+def _finish_tracing(
+    with_span: _WithSpan,
+    attributes: Iterable[Tuple[str, AttributeValue]],
+    extra_attributes: Iterable[Tuple[str, AttributeValue]],
+    status: Optional[trace_api.Status] = None,
+) -> None:
+    try:
+        attributes: Attributes = dict(attributes)
+    except Exception:
+        logger.exception("Failed to get attributes")
+        attributes = None
+    try:
+        extra_attributes: Attributes = dict(extra_attributes)
+    except Exception:
+        logger.exception("Failed to get extra attributes")
+        extra_attributes = None
+    try:
+        with_span.finish_tracing(
+            status=status,
+            attributes=attributes,
+            extra_attributes=extra_attributes,
+        )
+    except Exception:
+        logger.exception("Failed to finish tracing")
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py
index 70a8c9442..7dc4426d8 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py
@@ -1,4 +1,5 @@
 import json
+import logging
 from abc import ABC
 from contextlib import contextmanager
 from enum import Enum
@@ -14,15 +15,21 @@
 from openinference.instrumentation.groq._request_attributes_extractor import (
     _RequestAttributesExtractor,
 )
+from openinference.instrumentation.groq._response_attributes_extractor import (
+    _ResponseAttributesExtractor,
+)
+from openinference.instrumentation.groq._utils import _finish_tracing
 from openinference.instrumentation.groq._with_span import _WithSpan
 from openinference.semconv.trace import (
     EmbeddingAttributes,
     MessageAttributes,
-    OpenInferenceMimeTypeValues,
     OpenInferenceSpanKindValues,
     SpanAttributes,
 )
 
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
 
 def _flatten(mapping: Mapping[str, Any]) -> Iterator[Tuple[str, AttributeValue]]:
     for key, value in mapping.items():
@@ -111,6 +118,7 @@ class _CompletionsWrapper(_WithTracer):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self._request_extractor = _RequestAttributesExtractor()
+        self._response_extractor = _ResponseAttributesExtractor()
 
     def __call__(
         self,
@@ -141,31 +149,26 @@ def __call__(
         try:
             response = wrapped(*args, **kwargs)
         except Exception as exception:
-            response = exception
             span.record_exception(exception)
             status = trace_api.Status(
                 status_code=trace_api.StatusCode.ERROR,
                 description=f"{type(exception).__name__}: {exception}",
             )
             span.finish_tracing(status=status)
-        else:
-            content = response.choices[0].message.content
-            span.set_attributes(
-                dict(
-                    _flatten(
-                        {
-                            f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}": content,
-                            f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}": "assistant",
-                            SpanAttributes.OUTPUT_VALUE: response.choices[0].message.content,
-                            SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON,
-                            LLM_TOKEN_COUNT_COMPLETION: response.usage.completion_tokens,
-                            SpanAttributes.LLM_TOKEN_COUNT_PROMPT: response.usage.prompt_tokens,
-                            SpanAttributes.LLM_TOKEN_COUNT_TOTAL: response.usage.total_tokens,
-                        }
-                    )
-                )
-            )
-            span.finish_tracing(status=trace_api.Status(trace_api.StatusCode.OK))
+            raise
+        try:
+            _finish_tracing(
+                status=trace_api.Status(status_code=trace_api.StatusCode.OK),
+                with_span=span,
+                attributes=self._response_extractor.get_attributes(response=response),
+                extra_attributes=self._response_extractor.get_extra_attributes(
+                    response=response, request_parameters=request_parameters
+                ),
+            )
+        except Exception:
+            logger.exception(f"Failed to finalize response of type {type(response)}")
+            span.finish_tracing()
+        return response
 
 
 class _AsyncCompletionsWrapper(_WithTracer):
@@ -177,6 +180,7 @@ class _AsyncCompletionsWrapper(_WithTracer):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self._request_extractor = _RequestAttributesExtractor()
+        self._response_extractor = _ResponseAttributesExtractor()
 
     async def __call__(
         self,
@@ -195,8 +199,6 @@ async def __call__(
             invocation_parameters.update(arg)
         invocation_parameters.update(kwargs)
         request_parameters = _parse_args(signature(wrapped), *args, **kwargs)
-        llm_invocation_params = kwargs
-        llm_messages = dict(kwargs).pop("messages", None)
 
         span_name = "AsyncCompletions"
         with self._start_as_current_span(
@@ -207,25 +209,9 @@ async def __call__(
                 request_parameters
             ),
         ) as span:
-            span.set_attributes(
-                dict(
-                    _flatten(
-                        {
-                            SpanAttributes.OPENINFERENCE_SPAN_KIND: LLM,
-                            SpanAttributes.LLM_INPUT_MESSAGES: llm_messages,
-                            SpanAttributes.LLM_INVOCATION_PARAMETERS: safe_json_dumps(
-                                llm_invocation_params
-                            ),
-                            SpanAttributes.LLM_MODEL_NAME: llm_invocation_params.get("model"),
-                            SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON,
-                        }
-                    )
-                )
-            )
             try:
                 response = await wrapped(*args, **kwargs)
             except Exception as exception:
-                response = exception
                 span.record_exception(exception)
                 status = trace_api.Status(
                     status_code=trace_api.StatusCode.ERROR,
@@ -233,24 +219,18 @@ async def __call__(
                     description=f"{type(exception).__name__}: {exception}",
                 )
                 span.finish_tracing(status=status)
                 raise
-            else:
-                content = response.choices[0].message.content
-                span.set_attributes(
-                    dict(
-                        _flatten(
-                            {
-                                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}": content,
-                                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}": "assistant",
-                                SpanAttributes.OUTPUT_VALUE: response.choices[0].message.content,
-                                SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON,
-                                LLM_TOKEN_COUNT_COMPLETION: response.usage.completion_tokens,
-                                SpanAttributes.LLM_TOKEN_COUNT_PROMPT: response.usage.prompt_tokens,
-                                SpanAttributes.LLM_TOKEN_COUNT_TOTAL: response.usage.total_tokens,
-                            }
-                        )
-                    )
+            try:
+                _finish_tracing(
+                    status=trace_api.Status(status_code=trace_api.StatusCode.OK),
+                    with_span=span,
+                    attributes=self._response_extractor.get_attributes(response=response),
+                    extra_attributes=self._response_extractor.get_extra_attributes(
+                        response=response, request_parameters=request_parameters
+                    ),
                 )
-                span.finish_tracing(status=trace_api.Status(trace_api.StatusCode.OK))
+            except Exception:
+                logger.exception(f"Failed to finalize response of type {type(response)}")
+                span.finish_tracing()
             return response
diff --git a/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
index d904bb8cb..cfda542a3 100644
--- a/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
+++ b/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}"] == "user" @@ -475,11 +476,6 @@ async def exec_comp() -> None: assert invocation_params["model"] == "fake_model" assert invocation_params["temperature"] == 0.0 assert invocation_params["tool_choice"] == "required" - assert invocation_params["tools"][0] == { - "type": "function", - "function": "FunctionDefinition(name='hello_world', description=\"Print 'Hello world!'\", " - "parameters={'input': 'ex'})", - } def test_groq_uninstrumentation( diff --git a/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py b/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py new file mode 100644 index 000000000..18c87143f --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py @@ -0,0 +1,245 @@ +import json +from typing import Any, Optional, Type, Union, cast + +import pytest +from groq import Groq +from groq._base_client import _StreamT +from groq._types import Body, RequestFiles, RequestOptions, ResponseT +from groq.types.chat import ( + ChatCompletion, + ChatCompletionMessage, + ChatCompletionMessageToolCall, + ChatCompletionToolParam, +) +from groq.types.chat.chat_completion import Choice, CompletionUsage +from groq.types.chat.chat_completion_message_tool_call import Function +from opentelemetry import trace as trace_api +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + +def create_mock_tool_completion(messages): + last_user_message = next(msg for msg in messages[::-1] if msg.get("role") == "user") + city = last_user_message["content"].split(" in ")[-1].split("?")[0].strip() + + # Create tool calls with dynamically generated IDs + tool_calls = [ + ChatCompletionMessageToolCall( + id=f"call_{62136355 + i}", # Use a base ID and increment + function=Function(arguments=json.dumps({"city": city}), name=tool_name), + type="function", + ) + for i, tool_name in enumerate(["get_weather", "get_population"]) + ] + + return ChatCompletion( + id="chat_comp_0", + choices=[ + Choice( + finish_reason="tool_calls", + index=0, + logprobs=None, + message=ChatCompletionMessage( + content="", role="assistant", function_call=None, tool_calls=tool_calls + ), + ) + ], + created=1722531851, + model="fake_groq_model", + object="chat.completion", + system_fingerprint="fp0", + usage=CompletionUsage( + completion_tokens=379, + prompt_tokens=25, + total_tokens=404, + completion_time=0.616262398, + prompt_time=0.002549632, + queue_time=None, + total_time=0.6188120300000001, + ), + ) + + +def _mock_post( + self: Any, + path: str = "fake/url", + *, + cast_to: Type[ResponseT], + body: Optional[Body] = None, + options: RequestOptions = {}, + files: Optional[RequestFiles] = None, + stream: bool = False, + stream_cls: Optional[Type[_StreamT]] = None, +) -> Union[ResponseT, _StreamT]: + # Extract messages from the request body + messages = body.get("messages", []) if body else [] + + # Create a mock completion based on the messages + mock_completion = create_mock_tool_completion(messages) + + return cast(ResponseT, mock_completion) + + +@pytest.fixture() +def in_memory_span_exporter() -> InMemorySpanExporter: + return InMemorySpanExporter() + + +@pytest.fixture() +def tracer_provider(in_memory_span_exporter: InMemorySpanExporter) -> 
TracerProvider: + resource = Resource(attributes={}) + tracer_provider = TracerProvider(resource=resource) + tracer_provider.add_span_processor(SimpleSpanProcessor(in_memory_span_exporter)) + return tracer_provider + + +@pytest.mark.vcr( + decode_compressed_response=True, + before_record_request=lambda _: _.headers.clear() or _, + before_record_response=lambda _: {**_, "headers": {}}, +) +def test_tool_calls( + in_memory_span_exporter: InMemorySpanExporter, + tracer_provider: trace_api.TracerProvider, +) -> None: + client = Groq(api_key="fake") + client.chat.completions._post = _mock_post # type: ignore[assignment] + + input_tools = [ + ChatCompletionToolParam( + type="function", + function={ + "name": "get_weather", + "description": "finds the weather for a given city", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'London'", + } + }, + "required": ["city"], + }, + }, + ), + ChatCompletionToolParam( + type="function", + function={ + "name": "get_population", + "description": "finds the population for a given city", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the population for, e.g. 'London'", + } + }, + "required": ["city"], + }, + }, + ), + ] + client.chat.completions.create( + extra_headers={"Accept-Encoding": "gzip"}, + model="fake_groq_model", + tools=input_tools, + messages=[ + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_62136355", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"city": "New York"}'}, + }, + { + "id": "call_62136356", + "type": "function", + "function": {"name": "get_population", "arguments": '{"city": "New York"}'}, + }, + ], + }, + { + "role": "tool", + "tool_call_id": "call_62136355", + "content": '{"city": "New York", "weather": "fine"}', + }, + { + "role": "tool", + "tool_call_id": "call_62136356", + "content": '{"city": "New York", "weather": "large"}', + }, + { + "role": "assistant", + "content": "In New York the weather is fine and the population is large.", + }, + { + "role": "user", + "content": "What's the weather and population in San Francisco?", + }, + ], + ) + spans = in_memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + span = spans[0] + attributes = dict(span.attributes or {}) + for i in range(len(input_tools)): + json_schema = attributes.pop(f"llm.tools.{i}.tool.json_schema") + assert isinstance(json_schema, str) + assert json.loads(json_schema) + assert ( + attributes.pop("llm.input_messages.0.message.tool_calls.0.tool_call.id") == "call_62136355" + ) + assert ( + attributes.pop("llm.input_messages.0.message.tool_calls.0.tool_call.function.name") + == "get_weather" + ) + assert ( + attributes.pop("llm.input_messages.0.message.tool_calls.0.tool_call.function.arguments") + == '{"city": "New York"}' + ) + assert ( + attributes.pop("llm.input_messages.0.message.tool_calls.1.tool_call.id") == "call_62136356" + ) + assert ( + attributes.pop("llm.input_messages.0.message.tool_calls.1.tool_call.function.name") + == "get_population" + ) + assert ( + attributes.pop("llm.input_messages.0.message.tool_calls.1.tool_call.function.arguments") + == '{"city": "New York"}' + ) + assert attributes.pop("llm.input_messages.1.message.role") == "tool" + assert attributes.pop("llm.input_messages.1.message.tool_call_id") == "call_62136355" + assert ( + attributes.pop("llm.input_messages.1.message.content") + == '{"city": "New York", 
"weather": "fine"}' + ) + assert attributes.pop("llm.input_messages.2.message.role") == "tool" + assert attributes.pop("llm.input_messages.2.message.tool_call_id") == "call_62136356" + assert ( + attributes.pop("llm.input_messages.2.message.content") + == '{"city": "New York", "weather": "large"}' + ) + assert attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.id") + assert ( + attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.function.name") + == "get_weather" + ) + assert ( + attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.function.arguments") + == '{"city": "San Francisco"}' + ) + assert attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.id") + assert ( + attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.function.name") + == "get_population" + ) + assert ( + attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.function.arguments") + == '{"city": "San Francisco"}' + )