Skip to content

Commit

Permalink
fix: add tool id for anthropic instrumentor and serialize content t…
Browse files Browse the repository at this point in the history
…o string if it's not a string (#1129)
  • Loading branch information
RogerHYang authored Nov 21, 2024
1 parent 8aa0012 commit 682724c
Showing 1 changed file with 26 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,11 @@ def _get_llm_input_messages(messages: List[Dict[str, str]]) -> Any:
elif isinstance(content, list):
for block in content:
if isinstance(block, ToolUseBlock):
if tool_call_id := block.id:
yield (
f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_index}.{TOOL_CALL_ID}",
tool_call_id,
)
yield (
f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_index}.{TOOL_CALL_FUNCTION_NAME}",
block.name,
Expand All @@ -376,6 +381,11 @@ def _get_llm_input_messages(messages: List[Dict[str, str]]) -> Any:
elif isinstance(block, dict):
block_type = block.get("type")
if block_type == "tool_use":
if (tool_call_id := block.get("id")) is not None:
yield (
f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_index}.{TOOL_CALL_ID}",
str(tool_call_id),
)
yield (
f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_index}.{TOOL_CALL_FUNCTION_NAME}",
block.get("name"),
Expand All @@ -385,12 +395,20 @@ def _get_llm_input_messages(messages: List[Dict[str, str]]) -> Any:
safe_json_dumps(block.get("input")),
)
tool_index += 1
if block_type == "tool_result":
yield (
f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}",
block.get("content"),
)
if block_type == "text":
elif block_type == "tool_result":
if (tool_call_id := block.get("tool_use_id")) is not None:
yield (
f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALL_ID}",
str(tool_call_id),
)
if (content := block.get("content")) is not None:
yield (
f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}",
content
if isinstance(content, str)
else safe_json_dumps(content),
)
elif block_type == "text":
yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", block.get("text")

if role := messages[i].get("role"):
Expand Down Expand Up @@ -489,13 +507,15 @@ def _validate_invocation_parameter(parameter: Any) -> bool:
MESSAGE_FUNCTION_CALL_NAME = MessageAttributes.MESSAGE_FUNCTION_CALL_NAME
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
MESSAGE_TOOL_CALL_ID = MessageAttributes.MESSAGE_TOOL_CALL_ID
METADATA = SpanAttributes.METADATA
OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS
SESSION_ID = SpanAttributes.SESSION_ID
TAG_TAGS = SpanAttributes.TAG_TAGS
TOOL_CALL_ID = ToolCallAttributes.TOOL_CALL_ID
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
Expand Down

0 comments on commit 682724c

Please sign in to comment.