From 411ad95ee4681e8f440cb6d92c347513803647ea Mon Sep 17 00:00:00 2001
From: Chris Park
Date: Fri, 22 Nov 2024 20:08:56 -0500
Subject: [PATCH 1/7] add support for groq tool call details

---
 .../examples/tool_call.py                     |  41 ++++
 .../groq/_request_attributes_extractor.py     |  15 +-
 .../instrumentation/groq/_wrappers.py         |  65 +++---
 .../tests/test_instrumentor.py                | 194 ++++++++++++++++--
 4 files changed, 258 insertions(+), 57 deletions(-)
 create mode 100644 python/instrumentation/openinference-instrumentation-groq/examples/tool_call.py

diff --git a/python/instrumentation/openinference-instrumentation-groq/examples/tool_call.py b/python/instrumentation/openinference-instrumentation-groq/examples/tool_call.py
new file mode 100644
index 000000000..a6d67b252
--- /dev/null
+++ b/python/instrumentation/openinference-instrumentation-groq/examples/tool_call.py
@@ -0,0 +1,41 @@
+import os
+
+from groq import Groq
+from phoenix.otel import register
+
+from openinference.instrumentation.groq import GroqInstrumentor
+
+
+def test():
+    tracer_provider = register(project_name="groq_debug")
+    GroqInstrumentor().instrument(tracer_provider=tracer_provider)
+
+    client = Groq(
+        api_key=os.environ.get("GROQ_API_KEY"),
+    )
+
+    hello_world = {
+        "type": "function",
+        "function": {
+            "name": "hello_world",
+            "description": ("Print 'Hello world!'"),
+            "parameters": {"input": "ex"},
+        },
+    }
+
+    prompt = "Be a helpful assistant"
+    msg = "say hello world"
+
+    chat = client.chat.completions.create(
+        model="mixtral-8x7b-32768",
+        messages=[{"role": "system", "content": prompt}, {"role": "user", "content": msg}],
+        temperature=0.0,
+        tools=[hello_world],
+        tool_choice="required",
+    )
+    return chat
+
+
+if __name__ == "__main__":
+    response = test()
+    print(response)
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py
index 8b09f1496..b90f4955f 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py
@@ -1,12 +1,5 @@
 import logging
-from typing import (
-    Any,
-    Iterable,
-    Iterator,
-    Mapping,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Iterable, Iterator, Mapping, Tuple, TypeVar
 
 from opentelemetry.util.types import AttributeValue
 
@@ -49,10 +42,10 @@ def get_extra_attributes_from_request(
         if not isinstance(request_parameters, Mapping):
             return
         invocation_params = dict(request_parameters)
-        invocation_params.pop("messages", None)
-        invocation_params.pop("functions", None)
-        invocation_params.pop("tools", None)
+        invocation_params.pop("messages", None)  # Remove LLM input messages
+
         yield SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_params)
+
         if (input_messages := request_parameters.get("messages")) and isinstance(
             input_messages, Iterable
         ):
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py
index ca3bf2263..70a8c9442 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py
@@ -141,30 +141,31 @@ def __call__(
         try:
             response = wrapped(*args, **kwargs)
         except Exception as exception:
+            response = exception
             span.record_exception(exception)
             status = trace_api.Status(
                 status_code=trace_api.StatusCode.ERROR,
                 description=f"{type(exception).__name__}: {exception}",
             )
             span.finish_tracing(status=status)
-        content = response.choices[0].message.content
-        span.set_attributes(
-            dict(
-                _flatten(
-                    {
-                        f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}": content,
-                        f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}": "assistant",
-                        SpanAttributes.OUTPUT_VALUE: response.choices[0].message.content,
-                        SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON,
-                        LLM_TOKEN_COUNT_COMPLETION: response.usage.completion_tokens,
-                        SpanAttributes.LLM_TOKEN_COUNT_PROMPT: response.usage.prompt_tokens,
-                        SpanAttributes.LLM_TOKEN_COUNT_TOTAL: response.usage.total_tokens,
-                    }
+        else:
+            content = response.choices[0].message.content
+            span.set_attributes(
+                dict(
+                    _flatten(
+                        {
+                            f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}": content,
+                            f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}": "assistant",
+                            SpanAttributes.OUTPUT_VALUE: response.choices[0].message.content,
+                            SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON,
+                            LLM_TOKEN_COUNT_COMPLETION: response.usage.completion_tokens,
+                            SpanAttributes.LLM_TOKEN_COUNT_PROMPT: response.usage.prompt_tokens,
+                            SpanAttributes.LLM_TOKEN_COUNT_TOTAL: response.usage.total_tokens,
+                        }
+                    )
                 )
             )
-        )
-        span.finish_tracing(status=trace_api.Status(trace_api.StatusCode.OK))
-        return response
+            span.finish_tracing(status=trace_api.Status(trace_api.StatusCode.OK))
 
 
 class _AsyncCompletionsWrapper(_WithTracer):
@@ -224,6 +225,7 @@ async def __call__(
         try:
             response = await wrapped(*args, **kwargs)
         except Exception as exception:
+            response = exception
             span.record_exception(exception)
             status = trace_api.Status(
                 status_code=trace_api.StatusCode.ERROR,
@@ -231,23 +233,24 @@ async def __call__(
             )
             span.finish_tracing(status=status)
             raise
-        content = response.choices[0].message.content
-        span.set_attributes(
-            dict(
-                _flatten(
-                    {
-                        f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}": content,
-                        f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}": "assistant",
-                        SpanAttributes.OUTPUT_VALUE: response.choices[0].message.content,
-                        SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON,
-                        LLM_TOKEN_COUNT_COMPLETION: response.usage.completion_tokens,
-                        SpanAttributes.LLM_TOKEN_COUNT_PROMPT: response.usage.prompt_tokens,
-                        SpanAttributes.LLM_TOKEN_COUNT_TOTAL: response.usage.total_tokens,
-                    }
+        else:
+            content = response.choices[0].message.content
+            span.set_attributes(
+                dict(
+                    _flatten(
+                        {
+                            f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}": content,
+                            f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}": "assistant",
+                            SpanAttributes.OUTPUT_VALUE: response.choices[0].message.content,
+                            SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON,
+                            LLM_TOKEN_COUNT_COMPLETION: response.usage.completion_tokens,
+                            SpanAttributes.LLM_TOKEN_COUNT_PROMPT: response.usage.prompt_tokens,
+                            SpanAttributes.LLM_TOKEN_COUNT_TOTAL: response.usage.total_tokens,
+                        }
+                    )
                 )
             )
-        )
-        span.finish_tracing(status=trace_api.Status(trace_api.StatusCode.OK))
+            span.finish_tracing(status=trace_api.Status(trace_api.StatusCode.OK))
         return response
diff --git a/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
index 23a38f7b7..50b6f8e71 100644
--- a/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
+++ b/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
@@ -1,15 +1,6 @@
 import asyncio
-from typing import (
-    Any,
-    Dict,
-    Generator,
-    List,
-    Mapping,
-    Optional,
-    Type,
-    Union,
-    cast,
-)
+import json
+from typing import Any, Dict, Generator, List, Mapping, Optional, Type, Union, cast
 
 import pytest
 from groq import AsyncGroq, Groq
@@ -21,6 +12,10 @@
     ChatCompletionMessage,
     Choice,
 )
+from groq.types.chat.chat_completion_message_tool_call import (
+    ChatCompletionMessageToolCall,
+    Function,
+)
 from groq.types.completion_usage import CompletionUsage
 from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace import TracerProvider
@@ -30,10 +25,7 @@
 
 from openinference.instrumentation import OITracer, using_attributes
 from openinference.instrumentation.groq import GroqInstrumentor
-from openinference.semconv.trace import (
-    MessageAttributes,
-    SpanAttributes,
-)
+from openinference.semconv.trace import MessageAttributes, SpanAttributes
 
 mock_completion = ChatCompletion(
     id="chat_comp_0",
@@ -62,6 +54,51 @@
     ),
 )
 
+mock_tool_completion = ChatCompletion(
+    id="chat_comp_0",
+    choices=[
+        Choice(
+            finish_reason="tool_calls",
+            index=0,
+            logprobs=None,
+            message=ChatCompletionMessage(
+                content="",
+                role="assistant",
+                function_call=None,
+                tool_calls=[
+                    ChatCompletionMessageToolCall(
+                        id="call_t760",
+                        function=Function(arguments="{}", name="hello_world"),
+                        type="function",
+                    )
+                ],
+            ),
+        )
+    ],
+    created=1722531851,
+    model="fake_model",
+    object="chat.completion",
+    system_fingerprint="fp0",
+    usage=CompletionUsage(
+        completion_tokens=379,
+        prompt_tokens=25,
+        total_tokens=404,
+        completion_time=0.616262398,
+        prompt_time=0.002549632,
+        queue_time=None,
+        total_time=0.6188120300000001,
+    ),
+)
+
+test_tool = {
+    "type": "function",
+    "function": {
+        "name": "hello_world",
+        "description": ("Print 'Hello world!'"),
+        "parameters": {"input": "ex"},
+    },
+}
+
 
 def _mock_post(
     self: Any,
@@ -250,6 +287,68 @@ def test_groq_instrumentation(
     )
 
 
+def test_groq_tool_call(
+    tracer_provider: TracerProvider,
+    in_memory_span_exporter: InMemorySpanExporter,
+    setup_groq_instrumentation: Any,
+    session_id: str,
+    user_id: str,
+    metadata: Dict[str, Any],
+    tags: List[str],
+    prompt_template: str,
+    prompt_template_version: str,
+    prompt_template_variables: Dict[str, Any],
+) -> None:
+    client = Groq(api_key="fake")
+    client.chat.completions._post = _mock_post  # type: ignore[assignment]
+
+    with using_attributes(
+        session_id=session_id,
+        user_id=user_id,
+        metadata=metadata,
+        tags=tags,
+        prompt_template=prompt_template,
+        prompt_template_version=prompt_template_version,
+        prompt_template_variables=prompt_template_variables,
+    ):
+        client.chat.completions.create(
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Print hello world",
+                }
+            ],
+            model="fake_model",
+            temperature=0.0,
+            tools=[test_tool],
+            tool_choice="required",
+        )
+    spans = in_memory_span_exporter.get_finished_spans()
+
+    assert spans[0].name == "Completions"
+    assert spans[0].attributes and spans[0].attributes.get("openinference.span.kind") == "LLM"
+
+    for span in spans:
+        att = span.attributes
+        _check_context_attributes(
+            att,
+        )
+
+    attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
+
+    invocation_params = json.loads(attributes.get(SpanAttributes.LLM_INVOCATION_PARAMETERS, {}))
+    assert invocation_params["model"] == "fake_model"
+    assert invocation_params["tool_choice"] == "required"
+    assert invocation_params["tools"][0] == {
+        "function": {
+            "description": "Print 'Hello world!'",
+            "name": "hello_world",
+            "parameters": {"input": "ex"},
+        },
+        "type": "function",
+    }
+
+
 def test_groq_async_instrumentation(
     tracer_provider: TracerProvider,
     in_memory_span_exporter: InMemorySpanExporter,
@@ -317,6 +416,71 @@ async def exec_comp() -> None:
     )
 
 
+def test_groq_async_tool_call(
+    tracer_provider: TracerProvider,
+    in_memory_span_exporter: InMemorySpanExporter,
+    setup_groq_instrumentation: Any,
+    session_id: str,
+    user_id: str,
+    metadata: Dict[str, Any],
+    tags: List[str],
+    prompt_template: str,
+    prompt_template_version: str,
+    prompt_template_variables: Dict[str, Any],
+) -> None:
+    client = AsyncGroq(api_key="fake")
+    client.chat.completions._post = _async_mock_post  # type: ignore[assignment]
+
+    async def exec_comp() -> None:
+        await client.chat.completions.create(
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Print hello world",
+                }
+            ],
+            model="fake_model",
+            temperature=0.0,
+            tools=[test_tool],
+            tool_choice="required",
+        )
+
+    with using_attributes(
+        session_id=session_id,
+        user_id=user_id,
+        metadata=metadata,
+        tags=tags,
+        prompt_template=prompt_template,
+        prompt_template_version=prompt_template_version,
+        prompt_template_variables=prompt_template_variables,
+    ):
+        asyncio.run(exec_comp())
+
+    spans = in_memory_span_exporter.get_finished_spans()
+
+    assert spans[0].name == "AsyncCompletions"
+    assert spans[0].attributes and spans[0].attributes.get("openinference.span.kind") == "LLM"
+
+    for span in spans:
+        att = span.attributes
+        _check_context_attributes(
+            att,
+        )
+
+    attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
+    invocation_params = json.loads(attributes.get(SpanAttributes.LLM_INVOCATION_PARAMETERS, {}))
+    assert invocation_params["model"] == "fake_model"
+    assert invocation_params["tool_choice"] == "required"
+    assert invocation_params["tools"][0] == {
+        "function": {
+            "description": "Print 'Hello world!'",
+            "name": "hello_world",
+            "parameters": {"input": "ex"},
+        },
+        "type": "function",
+    }
+
+
 def test_groq_uninstrumentation(
     tracer_provider: TracerProvider,
 ) -> None:

From a1125211b98947f95f188a33e2636fabb196fe14 Mon Sep 17 00:00:00 2001
From: Chris Park
Date: Mon, 25 Nov 2024 11:03:41 -0500
Subject: [PATCH 2/7] updated tests

---
 .../tests/test_instrumentor.py | 41 ++++++++++---------
 1 file changed, 21 insertions(+), 20 deletions(-)

diff --git a/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
index 50b6f8e71..264435683 100644
--- a/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
+++ b/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
@@ -7,6 +7,7 @@
 from groq._base_client import _StreamT
 from groq._types import Body, RequestFiles, RequestOptions, ResponseT
 from groq.resources.chat.completions import AsyncCompletions, Completions
+from groq.types.chat import ChatCompletionToolParam
 from groq.types.chat.chat_completion import (  # type: ignore[attr-defined]
     ChatCompletion,
     ChatCompletionMessage,
@@ -17,6 +18,7 @@
     Function,
 )
 from groq.types.completion_usage import CompletionUsage
+from groq.types.shared import FunctionDefinition
 from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import SimpleSpanProcessor
@@ -90,14 +92,12 @@
     ),
 )
 
-test_tool = {
-    "type": "function",
-    "function": {
-        "name": "hello_world",
-        "description": ("Print 'Hello world!'"),
-        "parameters": {"input": "ex"},
-    },
-}
+test_tool = ChatCompletionToolParam(
+    type="function",
+    function=FunctionDefinition(
+        name="hello_world", description=("Print 'Hello world!'"), parameters={"input": "ex"}
+    ),  # type: ignore
+)
 
 
 def _mock_post(
@@ -336,16 +336,16 @@ def test_groq_tool_call(
 
     attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
 
-    invocation_params = json.loads(attributes.get(SpanAttributes.LLM_INVOCATION_PARAMETERS, {}))
+    invocation_params: dict[str, Any] = json.loads(
+        attributes.get(SpanAttributes.LLM_INVOCATION_PARAMETERS, "{}")  # type: ignore
+    )
     assert invocation_params["model"] == "fake_model"
+    assert invocation_params["temperature"] == 0.0
     assert invocation_params["tool_choice"] == "required"
     assert invocation_params["tools"][0] == {
-        "function": {
-            "description": "Print 'Hello world!'",
-            "name": "hello_world",
-            "parameters": {"input": "ex"},
-        },
         "type": "function",
+        "function": "FunctionDefinition(name='hello_world', description=\"Print 'Hello world!'\", "
+        "parameters={'input': 'ex'})",
     }
 
 
@@ -468,16 +468,17 @@ async def exec_comp() -> None:
     )
 
     attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
-    invocation_params = json.loads(attributes.get(SpanAttributes.LLM_INVOCATION_PARAMETERS, {}))
+    invocation_params: dict[str, Any] = json.loads(
+        attributes.get(SpanAttributes.LLM_INVOCATION_PARAMETERS, "{}")  # type: ignore
+    )
+
     assert invocation_params["model"] == "fake_model"
+    assert invocation_params["temperature"] == 0.0
     assert invocation_params["tool_choice"] == "required"
     assert invocation_params["tools"][0] == {
-        "function": {
-            "description": "Print 'Hello world!'",
-            "name": "hello_world",
-            "parameters": {"input": "ex"},
-        },
         "type": "function",
+        "function": "FunctionDefinition(name='hello_world', description=\"Print 'Hello world!'\", "
+        "parameters={'input': 'ex'})",
     }

From 44675629d2a517360bb034a68809e31db279e62b Mon Sep 17 00:00:00 2001
From: Chris Park
Date: Mon, 25 Nov 2024 11:10:27 -0500
Subject: [PATCH 3/7] updated tests

---
 .../tests/test_instrumentor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
index 264435683..d904bb8cb 100644
--- a/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
+++ b/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
@@ -336,7 +336,7 @@ def test_groq_tool_call(
 
     attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
 
-    invocation_params: dict[str, Any] = json.loads(
+    invocation_params = json.loads(
         attributes.get(SpanAttributes.LLM_INVOCATION_PARAMETERS, "{}")  # type: ignore
     )
     assert invocation_params["model"] == "fake_model"
@@ -468,7 +468,7 @@ async def exec_comp() -> None:
     )
 
     attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
-    invocation_params: dict[str, Any] = json.loads(
+    invocation_params = json.loads(
        attributes.get(SpanAttributes.LLM_INVOCATION_PARAMETERS, "{}")  # type: ignore
     )

From 061daa216ebe65a77ba16b1ad1937f49dda3535a Mon Sep 17 00:00:00 2001
From: Chris Park
Date: Mon, 9 Dec 2024 16:45:30 -0500
Subject: [PATCH 4/7] groq tool support

---
 .../examples/chat_completions_with_tool.py    | 119 +++++++++
 .../examples/tool_call.py                     |  41 ---
 .../instrumentation/groq/__init__.py          |   5 +-
 .../groq/_request_attributes_extractor.py     |  71 ++++-
 .../groq/_response_attributes_extractor.py    |  98 +++++++
 .../instrumentation/groq/_utils.py            |  53 +++-
 .../instrumentation/groq/_wrappers.py         |  88 +++----
 .../tests/test_instrumentor.py                |   6 +-
 .../tests/test_tool_calls.py                  | 245 ++++++++++++++++++
 9 files changed, 602 insertions(+), 124 deletions(-)
 create mode 100644 python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py
 delete mode 100644 python/instrumentation/openinference-instrumentation-groq/examples/tool_call.py
 create mode 100644 python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py
 create mode 100644 python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py

diff --git a/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py b/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py
new file mode 100644
index 000000000..5ab79754d
--- /dev/null
+++ b/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py
@@ -0,0 +1,119 @@
+import asyncio
+import os
+
+from groq import AsyncGroq, Groq
+from groq.types.chat import ChatCompletionToolParam
+from phoenix.otel import register
+
+from openinference.instrumentation.groq import GroqInstrumentor
+
+
+def test():
+    client = Groq(
+        api_key=os.environ.get("GROQ_API_KEY"),
+    )
+
+    weather_function = {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "finds the weather for a given city",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string",
+                        "description": "The city to find the weather for, e.g. 'London'",
+                    }
+                },
+                "required": ["city"],
+            },
+        },
+    }
+
+    sys_prompt = "Respond to the user's query using the correct tool."
+    user_msg = "What's the weather like in San Francisco?"
+
+    messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_msg}]
+    response = client.chat.completions.create(
+        model="mixtral-8x7b-32768",
+        messages=messages,
+        temperature=0.0,
+        tools=[weather_function],
+        tool_choice="required",
+    )
+
+    message = response.choices[0].message
+    assert (tool_calls := message.tool_calls)
+    tool_call_id = tool_calls[0].id
+    messages.append(message)
+    messages.append(
+        ChatCompletionToolParam(content="sunny", role="tool", tool_call_id=tool_call_id),
+    )
+    final_response = client.chat.completions.create(
+        model="mixtral-8x7b-32768",
+        messages=messages,
+    )
+    return final_response
+
+
+async def async_test():
+    client = AsyncGroq(
+        api_key=os.environ.get("GROQ_API_KEY"),
+    )
+
+    weather_function = {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "finds the weather for a given city",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string",
+                        "description": "The city to find the weather for, e.g. 'London'",
+                    }
+                },
+                "required": ["city"],
+            },
+        },
+    }
+
+    sys_prompt = "Respond to the user's query using the correct tool."
+    user_msg = "What's the weather like in San Francisco?"
+
+    messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_msg}]
+    response = await client.chat.completions.create(
+        model="mixtral-8x7b-32768",
+        messages=messages,
+        temperature=0.0,
+        tools=[weather_function],
+        tool_choice="required",
+    )
+
+    message = response.choices[0].message
+    assert (tool_calls := message.tool_calls)
+    tool_call_id = tool_calls[0].id
+    messages.append(message)
+    messages.append(
+        ChatCompletionToolParam(content="sunny", role="tool", tool_call_id=tool_call_id),
+    )
+    final_response = await client.chat.completions.create(
+        model="mixtral-8x7b-32768",
+        messages=messages,
+    )
+    return final_response
+
+
+if __name__ == "__main__":
+    tracer_provider = register(project_name="groq_debug")
+    GroqInstrumentor().instrument(tracer_provider=tracer_provider)
+
+    response = test()
+    print("Response\n--------")
+    print(response)
+
+    async_response = asyncio.run(async_test())
+    print("\nAsync Response\n--------")
+    print(async_response)
diff --git a/python/instrumentation/openinference-instrumentation-groq/examples/tool_call.py b/python/instrumentation/openinference-instrumentation-groq/examples/tool_call.py
deleted file mode 100644
index a6d67b252..000000000
--- a/python/instrumentation/openinference-instrumentation-groq/examples/tool_call.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import os
-
-from groq import Groq
-from phoenix.otel import register
-
-from openinference.instrumentation.groq import GroqInstrumentor
-
-
-def test():
-    tracer_provider = register(project_name="groq_debug")
-    GroqInstrumentor().instrument(tracer_provider=tracer_provider)
-
-    client = Groq(
-        api_key=os.environ.get("GROQ_API_KEY"),
-    )
-
-    hello_world = {
-        "type": "function",
-        "function": {
-            "name": "hello_world",
-            "description": ("Print 'Hello world!'"),
-            "parameters": {"input": "ex"},
-        },
-    }
-
-    prompt = "Be a helpful assistant"
-    msg = "say hello world"
-
-    chat = client.chat.completions.create(
-        model="mixtral-8x7b-32768",
-        messages=[{"role": "system", "content": prompt}, {"role": "user", "content": msg}],
-        temperature=0.0,
-        tools=[hello_world],
-        tool_choice="required",
-    )
-    return chat
-
-
-if __name__ == "__main__":
-    response = test()
-    print(response)
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/__init__.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/__init__.py
index 8ea2b2530..858c9aa70 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/__init__.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/__init__.py
@@ -3,8 +3,8 @@
 from typing import Any, Collection
 
 from opentelemetry import trace as trace_api
-from opentelemetry.instrumentation.instrumentor import (  # type: ignore[attr-defined]
-    BaseInstrumentor,
+from opentelemetry.instrumentation.instrumentor import (
+    BaseInstrumentor,  # type: ignore[attr-defined]
 )
 from wrapt import wrap_function_wrapper
 
@@ -17,6 +17,7 @@
 from openinference.instrumentation.groq.version import __version__
 
 logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
 
 _instruments = ("groq >= 0.9.0",)
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py
index b90f4955f..3d677cca2 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py
@@ -1,4 +1,5 @@
 import logging
+from enum import Enum
 from typing import Any, Iterable, Iterator, Mapping, Tuple, TypeVar
 
 from opentelemetry.util.types import AttributeValue
@@ -9,6 +10,7 @@
     MessageAttributes,
     OpenInferenceSpanKindValues,
     SpanAttributes,
+    ToolCallAttributes,
 )
 
 __all__ = ("_RequestAttributesExtractor",)
@@ -43,26 +45,73 @@ def get_extra_attributes_from_request(
             return
         invocation_params = dict(request_parameters)
         invocation_params.pop("messages", None)  # Remove LLM input messages
+        invocation_params.pop("functions", None)
+
+        if isinstance((tools := invocation_params.pop("tools", None)), Iterable):
+            for i, tool in enumerate(tools):
+                yield f"llm.tools.{i}.tool.json_schema", safe_json_dumps(tool)
 
         yield SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_params)
 
         if (input_messages := request_parameters.get("messages")) and isinstance(
             input_messages, Iterable
         ):
-            # Use reversed() to get the last message first. This is because OTEL has a default
-            # limit of 128 attributes per span, and flattening increases the number of
-            # attributes very quickly.
             for index, input_message in reversed(list(enumerate(input_messages))):
-                if role := input_message.get("role"):
-                    yield (
-                        f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{MessageAttributes.MESSAGE_ROLE}",
-                        role,
-                    )
-                if content := input_message.get("content"):
+                # Use reversed() to get the last message first. This is because OTEL has a default
+                # limit of 128 attributes per span, and flattening increases the number of
+                # attributes very quickly.
+                for key, value in self._get_attributes_from_message_param(input_message):
+                    yield f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{key}", value
+
+    def _get_attributes_from_message_param(
+        self,
+        message: Mapping[str, Any],
+    ) -> Iterator[Tuple[str, AttributeValue]]:
+        if not hasattr(message, "get"):
+            return
+        if role := message.get("role"):
+            yield (
+                MessageAttributes.MESSAGE_ROLE,
+                role.value if isinstance(role, Enum) else role,
+            )
+        if content := message.get("content"):
+            yield (
+                MessageAttributes.MESSAGE_CONTENT,
+                content,
+            )
+        if name := message.get("name"):
+            yield MessageAttributes.MESSAGE_NAME, name
+        if (function_call := message.get("function_call")) and hasattr(function_call, "get"):
+            if function_name := function_call.get("name"):
+                yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, function_name
+            if function_arguments := function_call.get("arguments"):
+                yield (
+                    MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON,
+                    function_arguments,
+                )
+        if (tool_calls := message.get("tool_calls")) and isinstance(tool_calls, Iterable):
+            for index, tool_call in enumerate(tool_calls):
+                if not hasattr(tool_call, "get"):
+                    continue
+                if (tool_call_id := tool_call.get("id")) is not None:
                     yield (
-                        f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{MessageAttributes.MESSAGE_CONTENT}",
-                        content,
+                        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+                        f"{ToolCallAttributes.TOOL_CALL_ID}",
+                        tool_call_id,
                     )
+                if (function := tool_call.get("function")) and hasattr(function, "get"):
+                    if name := function.get("name"):
+                        yield (
+                            f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+                            f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
+                            name,
+                        )
+                    if arguments := function.get("arguments"):
+                        yield (
+                            f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+                            f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
+                            arguments,
+                        )
 
 
 T = TypeVar("T", bound=type)
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py
new file mode 100644
index 000000000..e9599314e
--- /dev/null
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py
@@ -0,0 +1,98 @@
+import logging
+from typing import Any, Iterable, Iterator, Mapping, Tuple
+
+from opentelemetry.util.types import AttributeValue
+
+from openinference.instrumentation.groq._utils import _as_output_attributes, _io_value_and_type
+from openinference.semconv.trace import MessageAttributes, SpanAttributes, ToolCallAttributes
+
+__all__ = ("_ResponseAttributesExtractor",)
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+
+class _ResponseAttributesExtractor:
+    __slots__ = ()
+
+    def get_attributes(self, response: Any) -> Iterator[Tuple[str, AttributeValue]]:
+        yield from _as_output_attributes(
+            _io_value_and_type(response),
+        )
+
+    def get_extra_attributes(
+        self,
+        response: Any,
+        request_parameters: Mapping[str, Any],
+    ) -> Iterator[Tuple[str, AttributeValue]]:
+        yield from self._get_attributes_from_chat_completion(
+            completion=response,
+            request_parameters=request_parameters,
+        )
+
+    def _get_attributes_from_chat_completion(
+        self,
+        completion: Any,
+        request_parameters: Mapping[str, Any],
+    ) -> Iterator[Tuple[str, AttributeValue]]:
+        if model := getattr(completion, "model", None):
+            yield SpanAttributes.LLM_MODEL_NAME, model
+        if usage := getattr(completion, "usage", None):
+            yield from self._get_attributes_from_completion_usage(usage)
+        if (choices := getattr(completion, "choices", None)) and isinstance(choices, Iterable):
+            for choice in choices:
+                if (index := getattr(choice, "index", None)) is None:
+                    continue
+                if message := getattr(choice, "message", None):
+                    for key, value in self._get_attributes_from_chat_completion_message(message):
+                        yield f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{index}.{key}", value
+
+    def _get_attributes_from_chat_completion_message(
+        self,
+        message: object,
+    ) -> Iterator[Tuple[str, AttributeValue]]:
+        if role := getattr(message, "role", None):
+            yield MessageAttributes.MESSAGE_ROLE, role
+        if content := getattr(message, "content", None):
+            yield MessageAttributes.MESSAGE_CONTENT, content
+        if function_call := getattr(message, "function_call", None):
+            if name := getattr(function_call, "name", None):
+                yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, name
+            if arguments := getattr(function_call, "arguments", None):
+                yield MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, arguments
+        if (tool_calls := getattr(message, "tool_calls", None)) and isinstance(
+            tool_calls, Iterable
+        ):
+            for index, tool_call in enumerate(tool_calls):
+                if (tool_call_id := getattr(tool_call, "id", None)) is not None:
+                    yield (
+                        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+                        f"{ToolCallAttributes.TOOL_CALL_ID}",
+                        tool_call_id,
+                    )
+                if function := getattr(tool_call, "function", None):
+                    if name := getattr(function, "name", None):
+                        yield (
+                            (
+                                f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+                                f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}"
+                            ),
+                            name,
+                        )
+                    if arguments := getattr(function, "arguments", None):
+                        yield (
+                            f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
+                            f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
+                            arguments,
+                        )
+
+    def _get_attributes_from_completion_usage(
+        self,
+        usage: object,
+    ) -> Iterator[Tuple[str, AttributeValue]]:
+        if (total_tokens := getattr(usage, "total_tokens", None)) is not None:
+            yield SpanAttributes.LLM_TOKEN_COUNT_TOTAL, total_tokens
+        if (prompt_tokens := getattr(usage, "prompt_tokens", None)) is not None:
+            yield SpanAttributes.LLM_TOKEN_COUNT_PROMPT, prompt_tokens
+        if (completion_tokens := getattr(usage, "completion_tokens", None)) is not None:
+            yield SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, completion_tokens
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_utils.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_utils.py
index 468c47701..39b9e825e 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_utils.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_utils.py
@@ -1,18 +1,12 @@
 import logging
 import warnings
-from typing import (
-    Any,
-    Iterator,
-    Mapping,
-    NamedTuple,
-    Optional,
-    Sequence,
-    Tuple,
-)
-
-from opentelemetry.util.types import AttributeValue
+from typing import Any, Iterable, Iterator, Mapping, NamedTuple, Optional, Sequence, Tuple
+
+from opentelemetry import trace as trace_api
+from opentelemetry.util.types import Attributes, AttributeValue
 
 from openinference.instrumentation import safe_json_dumps
+from openinference.instrumentation.groq._with_span import _WithSpan
 from openinference.semconv.trace import OpenInferenceMimeTypeValues, SpanAttributes
 
 logger = logging.getLogger(__name__)
@@ -55,3 +49,40 @@ def _as_input_attributes(
     # It's assumed to be TEXT by default, so we can skip to save one attribute.
     if value_and_type.type is not OpenInferenceMimeTypeValues.TEXT:
         yield SpanAttributes.INPUT_MIME_TYPE, value_and_type.type.value
+
+
+def _as_output_attributes(
+    value_and_type: Optional[_ValueAndType],
+) -> Iterator[Tuple[str, AttributeValue]]:
+    if not value_and_type:
+        return
+    yield SpanAttributes.OUTPUT_VALUE, value_and_type.value
+    # It's assumed to be TEXT by default, so we can skip to save one attribute.
+    if value_and_type.type is not OpenInferenceMimeTypeValues.TEXT:
+        yield SpanAttributes.OUTPUT_MIME_TYPE, value_and_type.type.value
+
+
+def _finish_tracing(
+    with_span: _WithSpan,
+    attributes: Iterable[Tuple[str, AttributeValue]],
+    extra_attributes: Iterable[Tuple[str, AttributeValue]],
+    status: Optional[trace_api.Status] = None,
+) -> None:
+    try:
+        attributes: Attributes = dict(attributes)
+    except Exception:
+        logger.exception("Failed to get attributes")
+        attributes = None
+    try:
+        extra_attributes: Attributes = dict(extra_attributes)
+    except Exception:
+        logger.exception("Failed to get extra attributes")
+        extra_attributes = None
+    try:
+        with_span.finish_tracing(
+            status=status,
+            attributes=attributes,
+            extra_attributes=extra_attributes,
+        )
+    except Exception:
+        logger.exception("Failed to finish tracing")
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py
index 70a8c9442..7dc4426d8 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py
@@ -1,4 +1,5 @@
 import json
+import logging
 from abc import ABC
 from contextlib import contextmanager
 from enum import Enum
@@ -14,15 +15,21 @@
 from openinference.instrumentation.groq._request_attributes_extractor import (
     _RequestAttributesExtractor,
 )
+from openinference.instrumentation.groq._response_attributes_extractor import (
+    _ResponseAttributesExtractor,
+)
+from openinference.instrumentation.groq._utils import _finish_tracing
 from openinference.instrumentation.groq._with_span import _WithSpan
 from openinference.semconv.trace import (
     EmbeddingAttributes,
     MessageAttributes,
-    OpenInferenceMimeTypeValues,
     OpenInferenceSpanKindValues,
     SpanAttributes,
 )
 
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
 
 def _flatten(mapping: Mapping[str, Any]) -> Iterator[Tuple[str, AttributeValue]]:
     for key, value in mapping.items():
@@ -111,6 +118,7 @@ class _CompletionsWrapper(_WithTracer):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self._request_extractor = _RequestAttributesExtractor()
+        self._response_extractor = _ResponseAttributesExtractor()
 
     def __call__(
         self,
@@ -141,31 +149,26 @@ def __call__(
         try:
             response = wrapped(*args, **kwargs)
         except Exception as exception:
-            response = exception
             span.record_exception(exception)
             status = trace_api.Status(
                 status_code=trace_api.StatusCode.ERROR,
                 description=f"{type(exception).__name__}: {exception}",
             )
             span.finish_tracing(status=status)
-        else:
-            content = response.choices[0].message.content
-            span.set_attributes(
-                dict(
-                    _flatten(
-                        {
-                            f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}": content,
-                            f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}": "assistant",
-                            SpanAttributes.OUTPUT_VALUE: response.choices[0].message.content,
-                            SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON,
-                            LLM_TOKEN_COUNT_COMPLETION: response.usage.completion_tokens,
-                            SpanAttributes.LLM_TOKEN_COUNT_PROMPT: response.usage.prompt_tokens,
-                            SpanAttributes.LLM_TOKEN_COUNT_TOTAL: response.usage.total_tokens,
-                        }
-                    )
-                )
+            raise
+        try:
+            _finish_tracing(
+                status=trace_api.Status(status_code=trace_api.StatusCode.OK),
+                with_span=span,
+                attributes=self._response_extractor.get_attributes(response=response),
+                extra_attributes=self._response_extractor.get_extra_attributes(
+                    response=response, request_parameters=request_parameters
+                ),
             )
-            span.finish_tracing(status=trace_api.Status(trace_api.StatusCode.OK))
+        except Exception:
+            logger.exception(f"Failed to finalize response of type {type(response)}")
+            span.finish_tracing()
+
+        return response
 
 
 class _AsyncCompletionsWrapper(_WithTracer):
@@ -177,6 +180,7 @@ class _AsyncCompletionsWrapper(_WithTracer):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self._request_extractor = _RequestAttributesExtractor()
+        self._response_extractor = _ResponseAttributesExtractor()
 
     async def __call__(
         self,
@@ -195,8 +199,6 @@ async def __call__(
             invocation_parameters.update(arg)
         invocation_parameters.update(kwargs)
         request_parameters = _parse_args(signature(wrapped), *args, **kwargs)
-        llm_invocation_params = kwargs
-        llm_messages = dict(kwargs).pop("messages", None)
 
         span_name = "AsyncCompletions"
         with self._start_as_current_span(
@@ -207,25 +209,9 @@ async def __call__(
                 request_parameters
            ),
        ) as span:
-            span.set_attributes(
-                dict(
-                    _flatten(
-                        {
-                            SpanAttributes.OPENINFERENCE_SPAN_KIND: LLM,
-                            SpanAttributes.LLM_INPUT_MESSAGES: llm_messages,
-                            SpanAttributes.LLM_INVOCATION_PARAMETERS: safe_json_dumps(
-                                llm_invocation_params
-                            ),
-                            SpanAttributes.LLM_MODEL_NAME: llm_invocation_params.get("model"),
-                            SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON,
-                        }
-                    )
-                )
-            )
             try:
                 response = await wrapped(*args, **kwargs)
             except Exception as exception:
-                response = exception
                 span.record_exception(exception)
                 status = trace_api.Status(
                     status_code=trace_api.StatusCode.ERROR,
@@ -233,24 +219,18 @@ async def __call__(
                 )
                 span.finish_tracing(status=status)
                 raise
-            else:
-                content = response.choices[0].message.content
-                span.set_attributes(
-                    dict(
-                        _flatten(
-                            {
-                                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}": content,
-                                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}": "assistant",
-                                SpanAttributes.OUTPUT_VALUE: response.choices[0].message.content,
-                                SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON,
-                                LLM_TOKEN_COUNT_COMPLETION: response.usage.completion_tokens,
-                                SpanAttributes.LLM_TOKEN_COUNT_PROMPT: response.usage.prompt_tokens,
-                                SpanAttributes.LLM_TOKEN_COUNT_TOTAL: response.usage.total_tokens,
-                            }
-                        )
-                    )
+            try:
+                _finish_tracing(
+                    status=trace_api.Status(status_code=trace_api.StatusCode.OK),
+                    with_span=span,
+                    attributes=self._response_extractor.get_attributes(response=response),
+                    extra_attributes=self._response_extractor.get_extra_attributes(
+                        response=response, request_parameters=request_parameters
+                    ),
                 )
-                span.finish_tracing(status=trace_api.Status(trace_api.StatusCode.OK))
+            except Exception:
+                logger.exception(f"Failed to finalize response of type {type(response)}")
+                span.finish_tracing()
             return response
diff --git a/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
index d904bb8cb..cfda542a3 100644
--- a/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
+++ b/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
@@ -398,6 +398,7 @@ async def exec_comp() -> None:
     )
 
     attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
+    print(attributes)
     assert (
         attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}"]
         == "user"
@@ -475,11 +476,6 @@ async def exec_comp() -> None:
     assert invocation_params["model"] == "fake_model"
     assert invocation_params["temperature"] == 0.0
     assert invocation_params["tool_choice"] == "required"
-    assert invocation_params["tools"][0] == {
-        "type": "function",
-        "function": "FunctionDefinition(name='hello_world', description=\"Print 'Hello world!'\", "
-        "parameters={'input': 'ex'})",
-    }
 
 
 def test_groq_uninstrumentation(
diff --git a/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py b/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py
new file mode 100644
index 000000000..18c87143f
--- /dev/null
+++ b/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py
@@ -0,0 +1,245 @@
+import json
+from typing import Any, Optional, Type, Union, cast
+
+import pytest
+from groq import Groq
+from groq._base_client import _StreamT
+from groq._types import Body, RequestFiles, RequestOptions, ResponseT
+from groq.types.chat import (
+    ChatCompletion,
+    ChatCompletionMessage,
+    ChatCompletionMessageToolCall,
+    ChatCompletionToolParam,
+)
+from groq.types.chat.chat_completion import Choice, CompletionUsage
+from groq.types.chat.chat_completion_message_tool_call import Function
+from opentelemetry import trace as trace_api
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+
+
+def create_mock_tool_completion(messages):
+    last_user_message = next(msg for msg in messages[::-1] if msg.get("role") == "user")
+    city = last_user_message["content"].split(" in ")[-1].split("?")[0].strip()
+
+    # Create tool calls with dynamically generated IDs
+    tool_calls = [
+        ChatCompletionMessageToolCall(
+            id=f"call_{62136355 + i}",  # Use a base ID and increment
+            function=Function(arguments=json.dumps({"city": city}), name=tool_name),
+            type="function",
+        )
+        for i, tool_name in enumerate(["get_weather", "get_population"])
+    ]
+
+    return ChatCompletion(
+        id="chat_comp_0",
+        choices=[
+            Choice(
+                finish_reason="tool_calls",
+                index=0,
+                logprobs=None,
+                message=ChatCompletionMessage(
+                    content="", role="assistant", function_call=None, tool_calls=tool_calls
+                ),
+            )
+        ],
+        created=1722531851,
+        model="fake_groq_model",
+        object="chat.completion",
+        system_fingerprint="fp0",
+        usage=CompletionUsage(
+            completion_tokens=379,
+            prompt_tokens=25,
+            total_tokens=404,
+            completion_time=0.616262398,
+            prompt_time=0.002549632,
+            queue_time=None,
+            total_time=0.6188120300000001,
+        ),
+    )
+
+
+def _mock_post(
+    self: Any,
+    path: str = "fake/url",
+    *,
+    cast_to: Type[ResponseT],
+    body: Optional[Body] = None,
+    options: RequestOptions = {},
+    files: Optional[RequestFiles] = None,
+    stream: bool = False,
+    stream_cls: Optional[Type[_StreamT]] = None,
+) -> Union[ResponseT, _StreamT]:
+    # Extract messages from the request body
+    messages = body.get("messages", []) if body else []
+
+    # Create a mock completion based on the messages
+    mock_completion = create_mock_tool_completion(messages)
+
+    return cast(ResponseT, mock_completion)
+
+
+@pytest.fixture()
+def in_memory_span_exporter() -> InMemorySpanExporter:
+    return InMemorySpanExporter()
+
+
+@pytest.fixture()
+def tracer_provider(in_memory_span_exporter: InMemorySpanExporter) -> TracerProvider:
+    resource = Resource(attributes={})
+    tracer_provider = TracerProvider(resource=resource)
+    tracer_provider.add_span_processor(SimpleSpanProcessor(in_memory_span_exporter))
+    return tracer_provider
+
+
+@pytest.mark.vcr(
+    decode_compressed_response=True,
+    before_record_request=lambda _: _.headers.clear() or _,
+    before_record_response=lambda _: {**_, "headers": {}},
+)
+def test_tool_calls(
+    in_memory_span_exporter: InMemorySpanExporter,
+    tracer_provider: trace_api.TracerProvider,
+) -> None:
+    client = Groq(api_key="fake")
+    client.chat.completions._post = _mock_post  # type: ignore[assignment]
+
+    input_tools = [
+        ChatCompletionToolParam(
+            type="function",
+            function={
+                "name": "get_weather",
+                "description": "finds the weather for a given city",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "city": {
+                            "type": "string",
+                            "description": "The city to find the weather for, e.g. 'London'",
+                        }
+                    },
+                    "required": ["city"],
+                },
+            },
+        ),
+        ChatCompletionToolParam(
+            type="function",
+            function={
+                "name": "get_population",
+                "description": "finds the population for a given city",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "city": {
+                            "type": "string",
+                            "description": "The city to find the population for, e.g. 'London'",
+                        }
+                    },
+                    "required": ["city"],
+                },
+            },
+        ),
+    ]
+    client.chat.completions.create(
+        extra_headers={"Accept-Encoding": "gzip"},
+        model="fake_groq_model",
+        tools=input_tools,
+        messages=[
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "id": "call_62136355",
+                        "type": "function",
+                        "function": {"name": "get_weather", "arguments": '{"city": "New York"}'},
+                    },
+                    {
+                        "id": "call_62136356",
+                        "type": "function",
+                        "function": {"name": "get_population", "arguments": '{"city": "New York"}'},
+                    },
+                ],
+            },
+            {
+                "role": "tool",
+                "tool_call_id": "call_62136355",
+                "content": '{"city": "New York", "weather": "fine"}',
+            },
+            {
+                "role": "tool",
+                "tool_call_id": "call_62136356",
+                "content": '{"city": "New York", "weather": "large"}',
+            },
+            {
+                "role": "assistant",
+                "content": "In New York the weather is fine and the population is large.",
+            },
+            {
+                "role": "user",
+                "content": "What's the weather and population in San Francisco?",
+            },
+        ],
+    )
+    spans = in_memory_span_exporter.get_finished_spans()
+    assert len(spans) == 1
+    span = spans[0]
+    attributes = dict(span.attributes or {})
+    for i in range(len(input_tools)):
+        json_schema = attributes.pop(f"llm.tools.{i}.tool.json_schema")
+        assert isinstance(json_schema, str)
+        assert json.loads(json_schema)
+    assert (
+        attributes.pop("llm.input_messages.0.message.tool_calls.0.tool_call.id") == "call_62136355"
+    )
+    assert (
+        attributes.pop("llm.input_messages.0.message.tool_calls.0.tool_call.function.name")
+        == "get_weather"
+    )
+    assert (
+        attributes.pop("llm.input_messages.0.message.tool_calls.0.tool_call.function.arguments")
+        == '{"city": "New York"}'
+    )
+    assert (
+        attributes.pop("llm.input_messages.0.message.tool_calls.1.tool_call.id") == "call_62136356"
+    )
+    assert (
+        attributes.pop("llm.input_messages.0.message.tool_calls.1.tool_call.function.name")
+        == "get_population"
+    )
+    assert (
+        attributes.pop("llm.input_messages.0.message.tool_calls.1.tool_call.function.arguments")
+        == '{"city": "New York"}'
+    )
+    assert attributes.pop("llm.input_messages.1.message.role") == "tool"
+    assert attributes.pop("llm.input_messages.1.message.tool_call_id") == "call_62136355"
+    assert (
+        attributes.pop("llm.input_messages.1.message.content")
+        == '{"city": "New York", "weather": "fine"}'
+    )
+    assert attributes.pop("llm.input_messages.2.message.role") == "tool"
+    assert attributes.pop("llm.input_messages.2.message.tool_call_id") == "call_62136356"
+    assert (
+        attributes.pop("llm.input_messages.2.message.content")
+        == '{"city": "New York", "weather": "large"}'
+    )
+    assert attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.id")
+    assert (
+        attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.function.name")
+        == "get_weather"
+    )
+    assert (
+        attributes.pop("llm.output_messages.0.message.tool_calls.0.tool_call.function.arguments")
+        == '{"city": "San Francisco"}'
+    )
+    assert attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.id")
+    assert (
+        attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.function.name")
+        == "get_population"
+    )
+    assert (
+        attributes.pop("llm.output_messages.0.message.tool_calls.1.tool_call.function.arguments")
+        == '{"city": "San Francisco"}'
+    )

From 6db5715f8a144354b90288bc3b737ebae40480c2 Mon Sep 17 00:00:00 2001
From: Chris Park
Date: Wed, 11 Dec 2024 15:21:06 -0500
Subject: [PATCH 5/7] overhaul tests, fix groq input tool extraction

---
 .../examples/chat_completions_with_tool.py    |   6 +-
 .../instrumentation/groq/__init__.py          |   4 +-
 .../groq/_request_attributes_extractor.py     |  49 +++-
 .../groq/_response_attributes_extractor.py    |   2 +
 .../instrumentation/groq/_utils.py            |  12 +-
 .../tests/conftest.py                         |  31 +++
 .../tests/test_instrumentor.py                | 205 +--------------
 .../tests/test_tool_calls.py                  | 247 +++++++++++-------
 8 files changed, 243 insertions(+), 313 deletions(-)
 create mode 100644 python/instrumentation/openinference-instrumentation-groq/tests/conftest.py

diff --git a/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py b/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py
index 5ab79754d..cd4f3bb32 100644
--- a/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py
+++ b/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py
@@ -2,7 +2,7 @@
 import os
 
 from groq import AsyncGroq, Groq
-from groq.types.chat import ChatCompletionToolParam
+from groq.types.chat import ChatCompletionToolMessageParam
 from phoenix.otel import register
 
 from openinference.instrumentation.groq import GroqInstrumentor
@@ -48,7 +48,7 @@ def test():
     tool_call_id = tool_calls[0].id
     messages.append(message)
     messages.append(
-        ChatCompletionToolParam(content="sunny", role="tool", tool_call_id=tool_call_id),
+        ChatCompletionToolMessageParam(content="sunny", role="tool", tool_call_id=tool_call_id),
     )
     final_response = client.chat.completions.create(
         model="mixtral-8x7b-32768",
@@ -97,7 +97,7 @@ async def async_test():
     tool_call_id = tool_calls[0].id
     messages.append(message)
     messages.append(
-        ChatCompletionToolParam(content="sunny", role="tool", tool_call_id=tool_call_id),
+        ChatCompletionToolMessageParam(content="sunny", role="tool", tool_call_id=tool_call_id),
     )
     final_response = await client.chat.completions.create(
         model="mixtral-8x7b-32768",
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/__init__.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/__init__.py
index 858c9aa70..0989ade35 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/__init__.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/__init__.py
@@ -3,9 +3,7 @@
 from typing import Any, Collection
 
 from opentelemetry import trace as trace_api
-from opentelemetry.instrumentation.instrumentor import (
-    BaseInstrumentor,  # type: ignore[attr-defined]
-)
+from opentelemetry.instrumentation.instrumentor import BaseInstrumentor  # type: ignore
 from wrapt import wrap_function_wrapper
 
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py
index 3d677cca2..035552ff1 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py
@@ -1,9 +1,11 @@
 import logging
 from enum import Enum
-from typing import Any, Iterable, Iterator, Mapping, Tuple, TypeVar
+from typing import Any, Dict, Iterable, Iterator, List, Mapping, Tuple, TypeVar
 
 from opentelemetry.util.types import AttributeValue
 
+from groq.types.chat import ChatCompletionMessage, ChatCompletionMessageToolCall
+from groq.types.chat.chat_completion_message_tool_call import Function
 from openinference.instrumentation import safe_json_dumps
 from openinference.instrumentation.groq._utils import _as_input_attributes, _io_value_and_type
@@ -68,19 +70,29 @@ def _get_attributes_from_message_param(
         message: Mapping[str, Any],
     ) -> Iterator[Tuple[str, AttributeValue]]:
         if not hasattr(message, "get"):
-            return
+            if isinstance(message, ChatCompletionMessage):
+                message = self._cast_chat_completion_to_mapping(message)
+            else:
+                return
         if role := message.get("role"):
             yield (
                 MessageAttributes.MESSAGE_ROLE,
                 role.value if isinstance(role, Enum) else role,
             )
+
         if content := message.get("content"):
             yield (
                 MessageAttributes.MESSAGE_CONTENT,
                 content,
             )
+
         if name := message.get("name"):
             yield MessageAttributes.MESSAGE_NAME, name
+
+        if tool_call_id := message.get("tool_call_id"):
+            yield MessageAttributes.MESSAGE_TOOL_CALL_ID, tool_call_id
+
+        # Deprecated by Groq
         if (function_call := message.get("function_call")) and hasattr(function_call, "get"):
             if function_name := function_call.get("name"):
                 yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, function_name
@@ -89,7 +101,8 @@ def _get_attributes_from_message_param(
                     MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON,
                     function_arguments,
                 )
-        if (tool_calls := message.get("tool_calls")) and isinstance(tool_calls, Iterable):
+
+        if (tool_calls := message.get("tool_calls"),) and isinstance(tool_calls, Iterable):
             for index, tool_call in enumerate(tool_calls):
                 if not hasattr(tool_call, "get"):
                     continue
@@ -113,6 +126,36 @@ def _get_attributes_from_message_param(
                             arguments,
                         )
 
+    def _cast_chat_completion_to_mapping(self, message: ChatCompletionMessage) -> Mapping[str, Any]:
+        try:
+            casted_message = dict(message)
+            if (tool_calls := casted_message.get("tool_calls")) and isinstance(
+                tool_calls, Iterable
+            ):
+                casted_tool_calls: List[Dict[str, Any]] = []
+                for tool_call in tool_calls:
+                    if isinstance(tool_call, ChatCompletionMessageToolCall):
+                        tool_call_dict = dict(tool_call)
+
+                        if (function := tool_call_dict.get("function")) and isinstance(
+                            function, Function
+                        ):
+                            tool_call_dict["function"] = dict(function)
+
+                        casted_tool_calls.append(tool_call_dict)
+                    else:
+                        logger.debug(f"Skipping tool_call of unexpected type: {type(tool_call)}")
+
+                casted_message["tool_calls"] = casted_tool_calls
+
+            return casted_message
+
+        except Exception as e:
+            logger.exception(
+                f"Failed to convert ChatCompletionMessage to mapping for {message}: {e}"
+            )
+            return {}
+
 
 T = TypeVar("T", bound=type)
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py
index e9599314e..1f0bc642e 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py
@@ -65,12 +65,14 @@ def _get_attributes_from_chat_completion_message(
         ):
             for index, tool_call in enumerate(tool_calls):
                 if (tool_call_id := getattr(tool_call, "id", None)) is not None:
+                    # https://github.com/groq/groq-python/blob/fa2e13b5747d18aeb478700f1e5426af2fd087a1/src/groq/types/chat/chat_completion_tool_message_param.py#L17  # noqa: E501
                     yield (
                         f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
                         f"{ToolCallAttributes.TOOL_CALL_ID}",
                         tool_call_id,
                     )
                 if function := getattr(tool_call, "function", None):
+                    # https://github.com/groq/groq-python/blob/fa2e13b5747d18aeb478700f1e5426af2fd087a1/src/groq/types/chat/chat_completion_message_tool_call.py#L10  # noqa: E501
                     if name := getattr(function, "name", None):
                         yield (
                             (
diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_utils.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_utils.py
index 39b9e825e..a15b3263c 100644
--- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_utils.py
+++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_utils.py
@@ -3,7 +3,7 @@
 from typing import Any, Iterable, Iterator, Mapping, NamedTuple, Optional, Sequence, Tuple
 
 from opentelemetry import trace as trace_api
-from opentelemetry.util.types import Attributes, AttributeValue
+from opentelemetry.util.types import AttributeValue
 
 from openinference.instrumentation import safe_json_dumps
 from openinference.instrumentation.groq._with_span import _WithSpan
@@ -69,20 +69,18 @@ def _finish_tracing(
     status: Optional[trace_api.Status] = None,
 ) -> None:
     try:
-        attributes: Attributes = dict(attributes)
+        attributes_dict = dict(attributes)
     except Exception:
         logger.exception("Failed to get attributes")
-        attributes = None
     try:
-        extra_attributes: Attributes = dict(extra_attributes)
+        extra_attributes_dict = dict(extra_attributes)
     except Exception:
         logger.exception("Failed to get extra attributes")
-        extra_attributes = None
     try:
         with_span.finish_tracing(
             status=status,
-            attributes=attributes,
-            extra_attributes=extra_attributes,
+            attributes=attributes_dict,
+            extra_attributes=extra_attributes_dict,
         )
     except Exception:
         logger.exception("Failed to finish tracing")
diff --git a/python/instrumentation/openinference-instrumentation-groq/tests/conftest.py b/python/instrumentation/openinference-instrumentation-groq/tests/conftest.py
new file mode 100644
index 000000000..e7ca37256
--- /dev/null
+++ b/python/instrumentation/openinference-instrumentation-groq/tests/conftest.py
@@ -0,0 +1,31 @@
+from typing import Generator
+
+import pytest
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+
+from openinference.instrumentation.groq import GroqInstrumentor
+
+
+@pytest.fixture()
+def in_memory_span_exporter() -> InMemorySpanExporter:
+    return InMemorySpanExporter()
+
+
+@pytest.fixture()
+def tracer_provider(in_memory_span_exporter: InMemorySpanExporter) -> TracerProvider:
+    resource = Resource(attributes={})
+    tracer_provider = TracerProvider(resource=resource)
+    tracer_provider.add_span_processor(SimpleSpanProcessor(in_memory_span_exporter))
+    return tracer_provider
+
+
+@pytest.fixture()
+def setup_groq_instrumentation(
+    tracer_provider: TracerProvider,
+) -> Generator[None, None, None]:
+    GroqInstrumentor().instrument(tracer_provider=tracer_provider)
+    yield
+    GroqInstrumentor().uninstrument()
diff --git a/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
index cfda542a3..3a1381ee3 100644
--- a/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
+++ b/python/instrumentation/openinference-instrumentation-groq/tests/test_instrumentor.py
@@ -1,27 +1,18 @@
 import asyncio
-import json
-from typing import Any, Dict, Generator, List, Mapping, Optional, Type, Union, cast
+from typing import Any, Dict, List, Mapping, Optional, Type, Union, cast
 
 import pytest
 from groq import AsyncGroq, Groq
 from groq._base_client import _StreamT
 from groq._types import Body, RequestFiles, RequestOptions, ResponseT
 from groq.resources.chat.completions import AsyncCompletions, Completions
-from groq.types.chat import ChatCompletionToolParam
 from groq.types.chat.chat_completion import (  # type: ignore[attr-defined]
     ChatCompletion,
     ChatCompletionMessage,
     Choice,
 )
-from groq.types.chat.chat_completion_message_tool_call import (
-    ChatCompletionMessageToolCall,
-    Function,
-)
 from groq.types.completion_usage import CompletionUsage
-from groq.types.shared import FunctionDefinition
-from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace import TracerProvider
-from opentelemetry.sdk.trace.export import SimpleSpanProcessor
 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
 from opentelemetry.util.types import AttributeValue
 
@@ -29,7 +20,7 @@
 from openinference.instrumentation import OITracer, using_attributes
 from openinference.instrumentation.groq import GroqInstrumentor
 from openinference.semconv.trace import MessageAttributes, SpanAttributes
 
-mock_completion = ChatCompletion(
+MOCK_COMPLETION = ChatCompletion(
     id="chat_comp_0",
     choices=[
         Choice(
@@ -56,49 +47,6 @@
     ),
 )
 
-mock_tool_completion = ChatCompletion(
-    id="chat_comp_0",
-    choices=[
-        Choice(
-            finish_reason="tool_calls",
-            index=0,
-            logprobs=None,
-            message=ChatCompletionMessage(
-                content="",
-                role="assistant",
-                function_call=None,
-                tool_calls=[
-                    ChatCompletionMessageToolCall(
-                        id="call_t760",
-                        function=Function(arguments="{}", name="hello_world"),
-                        type="function",
-                    )
-                ],
-            ),
-        )
-    ],
-    created=1722531851,
-    model="fake_model",
-    object="chat.completion",
system_fingerprint="fp0", - usage=CompletionUsage( - completion_tokens=379, - prompt_tokens=25, - total_tokens=404, - completion_time=0.616262398, - prompt_time=0.002549632, - queue_time=None, - total_time=0.6188120300000001, - ), -) - -test_tool = ChatCompletionToolParam( - type="function", - function=FunctionDefinition( - name="hello_world", description=("Print 'Hello world!'"), parameters={"input": "ex"} - ), # type: ignore -) - def _mock_post( self: Any, @@ -117,7 +65,7 @@ def _mock_post( ) return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)) """ - return cast(ResponseT, mock_completion) + return cast(ResponseT, MOCK_COMPLETION) async def _async_mock_post( @@ -137,7 +85,7 @@ async def _async_mock_post( ) return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)) """ - return cast(ResponseT, mock_completion) + return cast(ResponseT, MOCK_COMPLETION) @pytest.fixture() @@ -190,28 +138,6 @@ def prompt_template_variables() -> Dict[str, Any]: } -@pytest.fixture() -def in_memory_span_exporter() -> InMemorySpanExporter: - return InMemorySpanExporter() - - -@pytest.fixture() -def tracer_provider(in_memory_span_exporter: InMemorySpanExporter) -> TracerProvider: - resource = Resource(attributes={}) - tracer_provider = TracerProvider(resource=resource) - tracer_provider.add_span_processor(SimpleSpanProcessor(in_memory_span_exporter)) - return tracer_provider - - -@pytest.fixture() -def setup_groq_instrumentation( - tracer_provider: TracerProvider, -) -> Generator[None, None, None]: - GroqInstrumentor().instrument(tracer_provider=tracer_provider) - yield - GroqInstrumentor().uninstrument() - - def _check_context_attributes( attributes: Any, ) -> None: @@ -287,68 +213,6 @@ def test_groq_instrumentation( ) -def test_groq_tool_call( - tracer_provider: TracerProvider, - in_memory_span_exporter: InMemorySpanExporter, - setup_groq_instrumentation: Any, - session_id: str, - user_id: str, - metadata: Dict[str, Any], - tags: List[str], - prompt_template: str, - prompt_template_version: str, - prompt_template_variables: Dict[str, Any], -) -> None: - client = Groq(api_key="fake") - client.chat.completions._post = _mock_post # type: ignore[assignment] - - with using_attributes( - session_id=session_id, - user_id=user_id, - metadata=metadata, - tags=tags, - prompt_template=prompt_template, - prompt_template_version=prompt_template_version, - prompt_template_variables=prompt_template_variables, - ): - client.chat.completions.create( - messages=[ - { - "role": "user", - "content": "Print hello world", - } - ], - model="fake_model", - temperature=0.0, - tools=[test_tool], - tool_choice="required", - ) - spans = in_memory_span_exporter.get_finished_spans() - - assert spans[0].name == "Completions" - assert spans[0].attributes and spans[0].attributes.get("openinference.span.kind") == "LLM" - - for span in spans: - att = span.attributes - _check_context_attributes( - att, - ) - - attributes = dict(cast(Mapping[str, AttributeValue], span.attributes)) - - invocation_params = json.loads( - attributes.get(SpanAttributes.LLM_INVOCATION_PARAMETERS, "{}") # type: ignore - ) - assert invocation_params["model"] == "fake_model" - assert invocation_params["temperature"] == 0.0 - assert invocation_params["tool_choice"] == "required" - assert invocation_params["tools"][0] == { - "type": "function", - "function": "FunctionDefinition(name='hello_world', description=\"Print 'Hello world!'\", " - "parameters={'input': 'ex'})", - } - - def 
test_groq_async_instrumentation( tracer_provider: TracerProvider, in_memory_span_exporter: InMemorySpanExporter, @@ -417,67 +281,6 @@ async def exec_comp() -> None: ) -def test_groq_async_tool_call( - tracer_provider: TracerProvider, - in_memory_span_exporter: InMemorySpanExporter, - setup_groq_instrumentation: Any, - session_id: str, - user_id: str, - metadata: Dict[str, Any], - tags: List[str], - prompt_template: str, - prompt_template_version: str, - prompt_template_variables: Dict[str, Any], -) -> None: - client = AsyncGroq(api_key="fake") - client.chat.completions._post = _async_mock_post # type: ignore[assignment] - - async def exec_comp() -> None: - await client.chat.completions.create( - messages=[ - { - "role": "user", - "content": "Print hello world", - } - ], - model="fake_model", - temperature=0.0, - tools=[test_tool], - tool_choice="required", - ) - - with using_attributes( - session_id=session_id, - user_id=user_id, - metadata=metadata, - tags=tags, - prompt_template=prompt_template, - prompt_template_version=prompt_template_version, - prompt_template_variables=prompt_template_variables, - ): - asyncio.run(exec_comp()) - - spans = in_memory_span_exporter.get_finished_spans() - - assert spans[0].name == "AsyncCompletions" - assert spans[0].attributes and spans[0].attributes.get("openinference.span.kind") == "LLM" - - for span in spans: - att = span.attributes - _check_context_attributes( - att, - ) - - attributes = dict(cast(Mapping[str, AttributeValue], span.attributes)) - invocation_params = json.loads( - attributes.get(SpanAttributes.LLM_INVOCATION_PARAMETERS, "{}") # type: ignore - ) - - assert invocation_params["model"] == "fake_model" - assert invocation_params["temperature"] == 0.0 - assert invocation_params["tool_choice"] == "required" - - def test_groq_uninstrumentation( tracer_provider: TracerProvider, ) -> None: diff --git a/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py b/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py index 18c87143f..4cc9a365d 100644 --- a/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py +++ b/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py @@ -1,40 +1,41 @@ import json -from typing import Any, Optional, Type, Union, cast +from typing import Any, Dict, List, Optional, Type, Union, cast import pytest from groq import Groq from groq._base_client import _StreamT from groq._types import Body, RequestFiles, RequestOptions, ResponseT -from groq.types.chat import ( - ChatCompletion, - ChatCompletionMessage, +from groq.types import CompletionUsage +from groq.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionToolParam +from groq.types.chat.chat_completion import Choice +from groq.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall, - ChatCompletionToolParam, + Function, ) -from groq.types.chat.chat_completion import Choice, CompletionUsage -from groq.types.chat.chat_completion_message_tool_call import Function -from opentelemetry import trace as trace_api -from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from openinference.instrumentation import using_attributes -def create_mock_tool_completion(messages): - last_user_message = next(msg for msg in messages[::-1] if 
msg.get("role") == "user") - city = last_user_message["content"].split(" in ")[-1].split("?")[0].strip() - # Create tool calls with dynamically generated IDs - tool_calls = [ - ChatCompletionMessageToolCall( - id=f"call_{62136355 + i}", # Use a base ID and increment - function=Function(arguments=json.dumps({"city": city}), name=tool_name), - type="function", - ) - for i, tool_name in enumerate(["get_weather", "get_population"]) - ] - - return ChatCompletion( +def _mock_post( + self: Any, + path: str = "fake/url", + *, + cast_to: Type[ResponseT], + body: Optional[Body] = None, + options: RequestOptions = {}, + files: Optional[RequestFiles] = None, + stream: bool = False, + stream_cls: Optional[Type[_StreamT]] = None, +) -> Union[ResponseT, _StreamT]: + """ + opts = FinalRequestOptions.construct( + method="post", url=path, json_data=body, files=to_httpx_files(files), **options + ) + return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)) + """ + mock_tool_completion = ChatCompletion( id="chat_comp_0", choices=[ Choice( @@ -42,12 +43,30 @@ def create_mock_tool_completion(messages): index=0, logprobs=None, message=ChatCompletionMessage( - content="", role="assistant", function_call=None, tool_calls=tool_calls + content="", + role="assistant", + function_call=None, + tool_calls=[ + ChatCompletionMessageToolCall( + id="call_62136357", + function=Function( + arguments='{"city": "San Francisco"}', name="get_weather" + ), + type="function", + ), + ChatCompletionMessageToolCall( + id="call_62136358", + function=Function( + arguments='{"city": "San Francisco"}', name="get_population" + ), + type="function", + ), + ], ), ) ], created=1722531851, - model="fake_groq_model", + model="test_groq_model", object="chat.completion", system_fingerprint="fp0", usage=CompletionUsage( @@ -60,51 +79,72 @@ def create_mock_tool_completion(messages): total_time=0.6188120300000001, ), ) + return cast(ResponseT, mock_tool_completion) -def _mock_post( - self: Any, - path: str = "fake/url", - *, - cast_to: Type[ResponseT], - body: Optional[Body] = None, - options: RequestOptions = {}, - files: Optional[RequestFiles] = None, - stream: bool = False, - stream_cls: Optional[Type[_StreamT]] = None, -) -> Union[ResponseT, _StreamT]: - # Extract messages from the request body - messages = body.get("messages", []) if body else [] +@pytest.fixture() +def session_id() -> str: + return "my-test-session-id" - # Create a mock completion based on the messages - mock_completion = create_mock_tool_completion(messages) - return cast(ResponseT, mock_completion) +@pytest.fixture() +def user_id() -> str: + return "my-test-user-id" @pytest.fixture() -def in_memory_span_exporter() -> InMemorySpanExporter: - return InMemorySpanExporter() +def metadata() -> Dict[str, Any]: + return { + "test-int": 1, + "test-str": "string", + "test-list": [1, 2, 3], + "test-dict": { + "key-1": "val-1", + "key-2": "val-2", + }, + } @pytest.fixture() -def tracer_provider(in_memory_span_exporter: InMemorySpanExporter) -> TracerProvider: - resource = Resource(attributes={}) - tracer_provider = TracerProvider(resource=resource) - tracer_provider.add_span_processor(SimpleSpanProcessor(in_memory_span_exporter)) - return tracer_provider +def tags() -> List[str]: + return ["tag-1", "tag-2"] + + +@pytest.fixture +def prompt_template() -> str: + return ( + "This is a test prompt template with int {var_int}, " + "string {var_string}, and list {var_list}" + ) + + +@pytest.fixture +def prompt_template_version() -> str: + return "v1.0" + + 
+@pytest.fixture +def prompt_template_variables() -> Dict[str, Any]: + return { + "var_int": 1, + "var_str": "2", + "var_list": [1, 2, 3], + } -@pytest.mark.vcr( - decode_compressed_response=True, - before_record_request=lambda _: _.headers.clear() or _, - before_record_response=lambda _: {**_, "headers": {}}, -) def test_tool_calls( + tracer_provider: TracerProvider, in_memory_span_exporter: InMemorySpanExporter, - tracer_provider: trace_api.TracerProvider, + setup_groq_instrumentation: Any, + session_id: str, + user_id: str, + metadata: Dict[str, Any], + tags: List[str], + prompt_template: str, + prompt_template_version: str, + prompt_template_variables: Dict[str, Any], ) -> None: - client = Groq(api_key="fake") + client = Groq(api_key="fake-api-key") client.chat.completions._post = _mock_post # type: ignore[assignment] input_tools = [ @@ -143,46 +183,61 @@ def test_tool_calls( }, ), ] - client.chat.completions.create( - extra_headers={"Accept-Encoding": "gzip"}, - model="fake_groq_model", - tools=input_tools, - messages=[ - { - "role": "assistant", - "tool_calls": [ - { - "id": "call_62136355", - "type": "function", - "function": {"name": "get_weather", "arguments": '{"city": "New York"}'}, - }, - { - "id": "call_62136356", - "type": "function", - "function": {"name": "get_population", "arguments": '{"city": "New York"}'}, - }, - ], - }, - { - "role": "tool", - "tool_call_id": "call_62136355", - "content": '{"city": "New York", "weather": "fine"}', - }, - { - "role": "tool", - "tool_call_id": "call_62136356", - "content": '{"city": "New York", "weather": "large"}', - }, - { - "role": "assistant", - "content": "In New York the weather is fine and the population is large.", - }, - { - "role": "user", - "content": "What's the weather and population in San Francisco?", - }, - ], - ) + + with using_attributes( + session_id=session_id, + user_id=user_id, + metadata=metadata, + tags=tags, + prompt_template=prompt_template, + prompt_template_version=prompt_template_version, + prompt_template_variables=prompt_template_variables, + ): + client.chat.completions.create( + model="test_groq_model", + tools=input_tools, + messages=[ + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_62136355", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "New York"}', + }, + }, + { + "id": "call_62136356", + "type": "function", + "function": { + "name": "get_population", + "arguments": '{"city": "New York"}', + }, + }, + ], + }, + { + "role": "tool", + "tool_call_id": "call_62136355", + "content": '{"city": "New York", "weather": "fine"}', + }, + { + "role": "tool", + "tool_call_id": "call_62136356", + "content": '{"city": "New York", "weather": "large"}', + }, + { + "role": "assistant", + "content": "In New York the weather is fine and the population is large.", + }, + { + "role": "user", + "content": "What's the weather and population in San Francisco?", + }, + ], + ) spans = in_memory_span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0] From d580c585fc1df3b4e5fd6d22bf4bb12effefb0ae Mon Sep 17 00:00:00 2001 From: Chris Park Date: Fri, 20 Dec 2024 10:42:50 -0700 Subject: [PATCH 6/7] Update python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com> --- .../tests/test_tool_calls.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py 
b/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py index 4cc9a365d..f79635354 100644 --- a/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py +++ b/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py @@ -184,15 +184,6 @@ def test_tool_calls( ), ] - with using_attributes( - session_id=session_id, - user_id=user_id, - metadata=metadata, - tags=tags, - prompt_template=prompt_template, - prompt_template_version=prompt_template_version, - prompt_template_variables=prompt_template_variables, - ): client.chat.completions.create( model="test_groq_model", tools=input_tools, From 8ae68223f1d1d596e4d605cf07ba8c1ae3d75f40 Mon Sep 17 00:00:00 2001 From: Chris Park Date: Fri, 20 Dec 2024 15:27:35 -0700 Subject: [PATCH 7/7] update attr extraction, remove NOT_GIVEN values, fix tests --- .../examples/chat_completions_with_tool.py | 9 +- .../groq/_request_attributes_extractor.py | 75 +++++----------- .../groq/_response_attributes_extractor.py | 2 - .../instrumentation/groq/_wrappers.py | 5 +- .../tests/test_tool_calls.py | 88 +++++++++---------- 5 files changed, 74 insertions(+), 105 deletions(-) diff --git a/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py b/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py index cd4f3bb32..d4ed6ebc6 100644 --- a/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py +++ b/python/instrumentation/openinference-instrumentation-groq/examples/chat_completions_with_tool.py @@ -3,7 +3,9 @@ from groq import AsyncGroq, Groq from groq.types.chat import ChatCompletionToolMessageParam -from phoenix.otel import register +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk import trace as trace_sdk +from opentelemetry.sdk.trace.export import SimpleSpanProcessor from openinference.instrumentation.groq import GroqInstrumentor @@ -107,7 +109,10 @@ async def async_test(): if __name__ == "__main__": - tracer_provider = register(project_name="groq_debug") + endpoint = "http://0.0.0.0:6006/v1/traces" + tracer_provider = trace_sdk.TracerProvider() + tracer_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint))) + GroqInstrumentor().instrument(tracer_provider=tracer_provider) response = test() diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py index 035552ff1..a00eb18e8 100644 --- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py +++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_request_attributes_extractor.py @@ -1,11 +1,9 @@ import logging from enum import Enum -from typing import Any, Dict, Iterable, Iterator, List, Mapping, Tuple, TypeVar +from typing import Any, Iterable, Iterator, Mapping, Tuple, TypeVar from opentelemetry.util.types import AttributeValue -from groq.types.chat import ChatCompletionMessage, ChatCompletionMessageToolCall -from groq.types.chat.chat_completion_message_tool_call import Function from openinference.instrumentation import safe_json_dumps from openinference.instrumentation.groq._utils import 
_as_input_attributes, _io_value_and_type from openinference.semconv.trace import ( @@ -69,96 +67,65 @@ def _get_attributes_from_message_param( self, message: Mapping[str, Any], ) -> Iterator[Tuple[str, AttributeValue]]: - if not hasattr(message, "get"): - if isinstance(message, ChatCompletionMessage): - message = self._cast_chat_completion_to_mapping(message) - else: - return - if role := message.get("role"): + if role := get_attribute(message, "role"): yield ( MessageAttributes.MESSAGE_ROLE, role.value if isinstance(role, Enum) else role, ) - - if content := message.get("content"): + if content := get_attribute(message, "content"): yield ( MessageAttributes.MESSAGE_CONTENT, content, ) - - if name := message.get("name"): + if name := get_attribute(message, "name"): yield MessageAttributes.MESSAGE_NAME, name - if tool_call_id := message.get("tool_call_id"): + if tool_call_id := get_attribute(message, "tool_call_id"): yield MessageAttributes.MESSAGE_TOOL_CALL_ID, tool_call_id # Deprecated by Groq - if (function_call := message.get("function_call")) and hasattr(function_call, "get"): - if function_name := function_call.get("name"): + if function_call := get_attribute(message, "function_call"): + if function_name := get_attribute(function_call, "name"): yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, function_name - if function_arguments := function_call.get("arguments"): + if function_arguments := get_attribute(function_call, "arguments"): yield ( MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, function_arguments, ) - if (tool_calls := message.get("tool_calls"),) and isinstance(tool_calls, Iterable): + if (tool_calls := get_attribute(message, "tool_calls")) and isinstance( + tool_calls, Iterable + ): for index, tool_call in enumerate(tool_calls): - if not hasattr(tool_call, "get"): - continue - if (tool_call_id := tool_call.get("id")) is not None: + if (tool_call_id := get_attribute(tool_call, "id")) is not None: yield ( f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." f"{ToolCallAttributes.TOOL_CALL_ID}", tool_call_id, ) - if (function := tool_call.get("function")) and hasattr(function, "get"): - if name := function.get("name"): + if function := get_attribute(tool_call, "function"): + if name := get_attribute(function, "name"): yield ( f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", name, ) - if arguments := function.get("arguments"): + if arguments := get_attribute(function, "arguments"): yield ( f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." 
f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", arguments, ) - def _cast_chat_completion_to_mapping(self, message: ChatCompletionMessage) -> Mapping[str, Any]: - try: - casted_message = dict(message) - if (tool_calls := casted_message.get("tool_calls")) and isinstance( - tool_calls, Iterable - ): - casted_tool_calls: List[Dict[str, Any]] = [] - for tool_call in tool_calls: - if isinstance(tool_call, ChatCompletionMessageToolCall): - tool_call_dict = dict(tool_call) - - if (function := tool_call_dict.get("function")) and isinstance( - function, Function - ): - tool_call_dict["function"] = dict(function) - - casted_tool_calls.append(tool_call_dict) - else: - logger.debug(f"Skipping tool_call of unexpected type: {type(tool_call)}") - - casted_message["tool_calls"] = casted_tool_calls - - return casted_message - - except Exception as e: - logger.exception( - f"Failed to convert ChatCompletionMessage to mapping for {message}: {e}" - ) - return {} - T = TypeVar("T", bound=type) def is_iterable_of(lst: Iterable[object], tp: T) -> bool: return isinstance(lst, Iterable) and all(isinstance(x, tp) for x in lst) + + +def get_attribute(obj: Any, attr_name: str, default: Any = None) -> Any: + if isinstance(obj, dict): + return obj.get(attr_name, default) + return getattr(obj, attr_name, default) diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py index 1f0bc642e..dddfb3f85 100644 --- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py +++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_response_attributes_extractor.py @@ -13,8 +13,6 @@ class _ResponseAttributesExtractor: - __slots__ = () - def get_attributes(self, response: Any) -> Iterator[Tuple[str, AttributeValue]]: yield from _as_output_attributes( _io_value_and_type(response), diff --git a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py index 7dc4426d8..374428b8c 100644 --- a/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py +++ b/python/instrumentation/openinference-instrumentation-groq/src/openinference/instrumentation/groq/_wrappers.py @@ -11,6 +11,7 @@ from opentelemetry.trace import INVALID_SPAN from opentelemetry.util.types import AttributeValue +from groq import NOT_GIVEN from openinference.instrumentation import get_attributes_from_context, safe_json_dumps from openinference.instrumentation.groq._request_attributes_extractor import ( _RequestAttributesExtractor, @@ -93,11 +94,11 @@ def _parse_args( ) -> Dict[str, Any]: bound_signature = signature.bind(*args, **kwargs) bound_signature.apply_defaults() - bound_arguments = bound_signature.arguments + bound_arguments = bound_signature.arguments # Defaults empty to NOT_GIVEN request_data: Dict[str, Any] = {} for key, value in bound_arguments.items(): try: - if value is not None: + if value is not None and value is not NOT_GIVEN: try: # ensure the value is JSON-serializable safe_json_dumps(value) diff --git a/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py 
b/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py index f79635354..2b3b092fe 100644 --- a/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py +++ b/python/instrumentation/openinference-instrumentation-groq/tests/test_tool_calls.py @@ -15,8 +15,6 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from openinference.instrumentation import using_attributes - def _mock_post( self: Any, @@ -184,51 +182,51 @@ def test_tool_calls( ), ] - client.chat.completions.create( - model="test_groq_model", - tools=input_tools, - messages=[ - { - "role": "assistant", - "tool_calls": [ - { - "id": "call_62136355", - "type": "function", - "function": { - "name": "get_weather", - "arguments": '{"city": "New York"}', - }, + client.chat.completions.create( + model="test_groq_model", + tools=input_tools, + messages=[ + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_62136355", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "New York"}', }, - { - "id": "call_62136356", - "type": "function", - "function": { - "name": "get_population", - "arguments": '{"city": "New York"}', - }, + }, + { + "id": "call_62136356", + "type": "function", + "function": { + "name": "get_population", + "arguments": '{"city": "New York"}', }, - ], - }, - { - "role": "tool", - "tool_call_id": "call_62136355", - "content": '{"city": "New York", "weather": "fine"}', - }, - { - "role": "tool", - "tool_call_id": "call_62136356", - "content": '{"city": "New York", "weather": "large"}', - }, - { - "role": "assistant", - "content": "In New York the weather is fine and the population is large.", - }, - { - "role": "user", - "content": "What's the weather and population in San Francisco?", - }, - ], - ) + }, + ], + }, + { + "role": "tool", + "tool_call_id": "call_62136355", + "content": '{"city": "New York", "weather": "fine"}', + }, + { + "role": "tool", + "tool_call_id": "call_62136356", + "content": '{"city": "New York", "weather": "large"}', + }, + { + "role": "assistant", + "content": "In New York the weather is fine and the population is large.", + }, + { + "role": "user", + "content": "What's the weather and population in San Francisco?", + }, + ], + ) spans = in_memory_span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0]
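
The get_attribute helper introduced in the final commit is what lets the request extractor read fields uniformly from plain dict message params and from Pydantic model instances such as ChatCompletionMessage, replacing the earlier _cast_chat_completion_to_mapping conversion. A minimal sketch of that behavior, using the real groq types but illustrative message values:

from groq.types.chat import ChatCompletionMessage

def get_attribute(obj, attr_name, default=None):
    # Dicts go through .get(); model objects fall back to getattr().
    if isinstance(obj, dict):
        return obj.get(attr_name, default)
    return getattr(obj, attr_name, default)

plain = {"role": "user", "content": "hi"}
typed = ChatCompletionMessage(role="assistant", content="hello")

assert get_attribute(plain, "role") == "user"
assert get_attribute(typed, "role") == "assistant"
assert get_attribute(typed, "tool_calls") is None  # absent fields yield the default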
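The NOT_GIVEN check added to _parse_args matters because groq marks omitted optional arguments with a sentinel rather than None; without the extra comparison the sentinel would leak into llm.invocation_parameters. A sketch of the filtering, with illustrative argument values:

from groq import NOT_GIVEN

bound_arguments = {"model": "test_groq_model", "temperature": NOT_GIVEN, "n": None}

# Keep only arguments the caller actually supplied.
request_data = {
    key: value
    for key, value in bound_arguments.items()
    if value is not None and value is not NOT_GIVEN
}
assert request_data == {"model": "test_groq_model"}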
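For the assistant message with two tool calls in test_tool_calls, the extractor flattens each call into an indexed attribute group, and the tool-role messages carry their tool_call_id. Roughly the following keys should appear on the span, assuming the usual OpenInference key spellings; real assertions should go through the MessageAttributes and ToolCallAttributes constants rather than hard-coded strings:

expected_subset = {
    "llm.input_messages.0.message.role": "assistant",
    "llm.input_messages.0.message.tool_calls.0.tool_call.id": "call_62136355",
    "llm.input_messages.0.message.tool_calls.0.tool_call.function.name": "get_weather",
    "llm.input_messages.0.message.tool_calls.0.tool_call.function.arguments": '{"city": "New York"}',
    "llm.input_messages.0.message.tool_calls.1.tool_call.id": "call_62136356",
    "llm.input_messages.0.message.tool_calls.1.tool_call.function.name": "get_population",
    "llm.input_messages.1.message.role": "tool",
    "llm.input_messages.1.message.tool_call_id": "call_62136355",
}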
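The conftest fixtures compose into the usual in-memory tracing harness: setup_groq_instrumentation instruments (and uninstruments) around each test, and in_memory_span_exporter exposes the finished spans. A sketch of a test wired through the mocked transport, reusing the _mock_post defined above and the fixture names from conftest:

from groq import Groq

def test_completion_emits_llm_span(
    setup_groq_instrumentation: None,
    in_memory_span_exporter,
) -> None:
    client = Groq(api_key="fake-api-key")
    client.chat.completions._post = _mock_post  # type: ignore[assignment]
    client.chat.completions.create(
        model="test_groq_model",
        messages=[{"role": "user", "content": "hi"}],
    )
    spans = in_memory_span_exporter.get_finished_spans()
    assert spans and spans[0].name == "Completions"
    assert spans[0].attributes and spans[0].attributes.get("openinference.span.kind") == "LLM"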
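using_attributes, which the instrumentation tests exercise and the later commits drop from the tool-call test, attaches session, user, metadata, tag, and prompt-template context to every span started inside the block. The pattern itself, shown with illustrative values:

from openinference.instrumentation import using_attributes

with using_attributes(
    session_id="my-test-session-id",
    user_id="my-test-user-id",
    tags=["tag-1", "tag-2"],
):
    ...  # any instrumented Groq call here inherits these context attributes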