Skip to content

Commit

Permalink
feat(js): openai tool calls (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeldking authored Jan 16, 2024
1 parent c6832d5 commit 2803ee0
Show file tree
Hide file tree
Showing 4 changed files with 393 additions and 17 deletions.
118 changes: 102 additions & 16 deletions js/packages/openinference-instrumentation-openai/src/instrumentation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import {
ChatCompletion,
ChatCompletionChunk,
ChatCompletionCreateParamsBase,
ChatCompletionMessage,
ChatCompletionMessageParam,
} from "openai/resources/chat/completions";
import { CompletionCreateParamsBase } from "openai/resources/completions";
import { Stream } from "openai/streaming";
Expand All @@ -33,6 +35,7 @@ import {
CreateEmbeddingResponse,
EmbeddingCreateParams,
} from "openai/resources";
import { assertUnreachable } from "./typeUtils";

const MODULE_NAME = "openai";

Expand Down Expand Up @@ -334,17 +337,64 @@ function getLLMInputMessagesAttributes(
body: ChatCompletionCreateParamsBase,
): Attributes {
return body.messages.reduce((acc, message, index) => {
const index_prefix = `${SemanticConventions.LLM_INPUT_MESSAGES}.${index}`;
acc[`${index_prefix}.${SemanticConventions.MESSAGE_CONTENT}`] = String(
message.content,
);
acc[`${index_prefix}.${SemanticConventions.MESSAGE_ROLE}`] = String(
message.role,
);
const messageAttributes = getChatCompletionInputMessageAttributes(message);
const indexPrefix = `${SemanticConventions.LLM_INPUT_MESSAGES}.${index}.`;
// Flatten the attributes on the index prefix
for (const [key, value] of Object.entries(messageAttributes)) {
acc[`${indexPrefix}${key}`] = value;
}
return acc;
}, {} as Attributes);
}

function getChatCompletionInputMessageAttributes(
message: ChatCompletionMessageParam,
): Attributes {
const role = message.role;
const attributes: Attributes = {
[SemanticConventions.MESSAGE_ROLE]: role,
};
// Add the content only if it is a string
if (typeof message.content === "string")
attributes[SemanticConventions.MESSAGE_CONTENT] = message.content;
switch (role) {
case "user":
// There's nothing to add for the user
break;
case "assistant":
if (message.tool_calls) {
message.tool_calls.forEach((toolCall, index) => {
// Make sure the tool call has a function
if (toolCall.function) {
const toolCallIndexPrefix = `${SemanticConventions.MESSAGE_TOOL_CALLS}.${index}.`;
attributes[
toolCallIndexPrefix + SemanticConventions.TOOL_CALL_FUNCTION_NAME
] = toolCall.function.name;
attributes[
toolCallIndexPrefix +
SemanticConventions.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
] = toolCall.function.arguments;
}
});
}
break;
case "function":
attributes[SemanticConventions.MESSAGE_FUNCTION_CALL_NAME] = message.name;
break;
case "tool":
// There's nothing to add for the tool. There is a tool_id, but there are no
// semantic conventions for it
break;
case "system":
// There's nothing to add for the system. Content is captured above
break;
default:
assertUnreachable(role);
break;
}
return attributes;
}

/**
* Converts the body of a completions request to input attributes
*/
Expand Down Expand Up @@ -401,15 +451,51 @@ function getChatCompletionLLMOutputMessagesAttributes(
return {};
}
return [choice.message].reduce((acc, message, index) => {
const indexPrefix = `${SemanticConventions.LLM_OUTPUT_MESSAGES}.${index}`;
acc[`${indexPrefix}.${SemanticConventions.MESSAGE_CONTENT}`] = String(
message.content,
);
acc[`${indexPrefix}.${SemanticConventions.MESSAGE_ROLE}`] = message.role;
const indexPrefix = `${SemanticConventions.LLM_OUTPUT_MESSAGES}.${index}.`;
const messageAttributes = getChatCompletionOutputMessageAttributes(message);
// Flatten the attributes on the index prefix
for (const [key, value] of Object.entries(messageAttributes)) {
acc[`${indexPrefix}${key}`] = value;
}
return acc;
}, {} as Attributes);
}

function getChatCompletionOutputMessageAttributes(
message: ChatCompletionMessage,
): Attributes {
const role = message.role;
const attributes: Attributes = {
[SemanticConventions.MESSAGE_ROLE]: role,
};
if (message.content) {
attributes[SemanticConventions.MESSAGE_CONTENT] = message.content;
}
if (message.tool_calls) {
message.tool_calls.forEach((toolCall, index) => {
const toolCallIndexPrefix = `${SemanticConventions.MESSAGE_TOOL_CALLS}.${index}.`;
// Double check that the tool call has a function
// NB: OpenAI only supports tool calls with functions right now but this may change
if (toolCall.function) {
attributes[
toolCallIndexPrefix + SemanticConventions.TOOL_CALL_FUNCTION_NAME
] = toolCall.function.name;
attributes[
toolCallIndexPrefix +
SemanticConventions.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
] = toolCall.function.arguments;
}
});
}
if (message.function_call) {
attributes[SemanticConventions.MESSAGE_FUNCTION_CALL_NAME] =
message.function_call.name;
attributes[SemanticConventions.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON] =
message.function_call.arguments;
}
return attributes;
}

/**
* Converts the completion result to output attributes
*/
Expand Down Expand Up @@ -444,8 +530,8 @@ function getEmbeddingTextAttributes(
typeof request.input[0] === "string"
) {
return request.input.reduce((acc, input, index) => {
const index_prefix = `${SemanticConventions.EMBEDDING_EMBEDDINGS}.${index}`;
acc[`${index_prefix}.${SemanticConventions.EMBEDDING_TEXT}`] = input;
const indexPrefix = `${SemanticConventions.EMBEDDING_EMBEDDINGS}.${index}.`;
acc[`${indexPrefix}${SemanticConventions.EMBEDDING_TEXT}`] = input;
return acc;
}, {} as Attributes);
}
Expand All @@ -460,8 +546,8 @@ function getEmbeddingEmbeddingsAttributes(
response: CreateEmbeddingResponse,
): Attributes {
return response.data.reduce((acc, embedding, index) => {
const index_prefix = `${SemanticConventions.EMBEDDING_EMBEDDINGS}.${index}`;
acc[`${index_prefix}.${SemanticConventions.EMBEDDING_VECTOR}`] =
const indexPrefix = `${SemanticConventions.EMBEDDING_EMBEDDINGS}.${index}.`;
acc[`${indexPrefix}${SemanticConventions.EMBEDDING_VECTOR}`] =
embedding.embedding;
return acc;
}, {} as Attributes);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/**
* Utility function that uses the type system to check if a switch statement is exhaustive.
* If the switch statement is not exhaustive, there will be a type error caught in typescript
*
* See https://stackoverflow.com/questions/39419170/how-do-i-check-that-a-switch-block-is-exhaustive-in-typescript for more details.
*/
export function assertUnreachable(_: never): never {
throw new Error("Unreachable");
}
Loading

0 comments on commit 2803ee0

Please sign in to comment.