From 682724ce436ef8ece5d821073e3845cc3a9d602d Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Thu, 21 Nov 2024 12:54:42 -0800 Subject: [PATCH] fix: add tool id for anthropic instrumentor and serialize `content` to string if it's not a string (#1129) --- .../instrumentation/anthropic/_wrappers.py | 32 +++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_wrappers.py b/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_wrappers.py index fa3129a87..191b96dd6 100644 --- a/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_wrappers.py +++ b/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_wrappers.py @@ -362,6 +362,11 @@ def _get_llm_input_messages(messages: List[Dict[str, str]]) -> Any: elif isinstance(content, list): for block in content: if isinstance(block, ToolUseBlock): + if tool_call_id := block.id: + yield ( + f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_index}.{TOOL_CALL_ID}", + tool_call_id, + ) yield ( f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_index}.{TOOL_CALL_FUNCTION_NAME}", block.name, @@ -376,6 +381,11 @@ def _get_llm_input_messages(messages: List[Dict[str, str]]) -> Any: elif isinstance(block, dict): block_type = block.get("type") if block_type == "tool_use": + if (tool_call_id := block.get("id")) is not None: + yield ( + f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_index}.{TOOL_CALL_ID}", + str(tool_call_id), + ) yield ( f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_index}.{TOOL_CALL_FUNCTION_NAME}", block.get("name"), @@ -385,12 +395,20 @@ def _get_llm_input_messages(messages: List[Dict[str, str]]) -> Any: safe_json_dumps(block.get("input")), ) tool_index += 1 - if block_type == "tool_result": - yield ( - f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", - block.get("content"), - ) - if block_type == "text": + elif block_type == "tool_result": + if (tool_call_id := block.get("tool_use_id")) is not None: + yield ( + f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALL_ID}", + str(tool_call_id), + ) + if (content := block.get("content")) is not None: + yield ( + f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", + content + if isinstance(content, str) + else safe_json_dumps(content), + ) + elif block_type == "text": yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", block.get("text") if role := messages[i].get("role"): @@ -489,6 +507,7 @@ def _validate_invocation_parameter(parameter: Any) -> bool: MESSAGE_FUNCTION_CALL_NAME = MessageAttributes.MESSAGE_FUNCTION_CALL_NAME MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS +MESSAGE_TOOL_CALL_ID = MessageAttributes.MESSAGE_TOOL_CALL_ID METADATA = SpanAttributes.METADATA OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE @@ -496,6 +515,7 @@ def _validate_invocation_parameter(parameter: Any) -> bool: RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS SESSION_ID = SpanAttributes.SESSION_ID TAG_TAGS = SpanAttributes.TAG_TAGS +TOOL_CALL_ID = ToolCallAttributes.TOOL_CALL_ID TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA