fix(langchain-js): add token counts for ChatBedrock models #1123

Merged 2 commits on Nov 18, 2024
5 changes: 5 additions & 0 deletions js/.changeset/mean-walls-talk.md
@@ -0,0 +1,5 @@
---
"@arizeai/openinference-instrumentation-langchain": patch
---

fix: add support for capturing ChatBedrock token counts
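
For orientation (not part of the diff itself), here is a minimal sketch of the two non-normalized ChatBedrock output shapes this patch adds handling for; the field names mirror the test fixtures added below:

```ts
// Sketch only: ChatBedrock-style run outputs, mirroring this PR's test fixtures.

// Shape 1: token usage attached to the generation's message metadata.
const outputsWithUsageMetadata = {
  generations: [
    [
      {
        message: {
          usage_metadata: { input_tokens: 20, output_tokens: 10, total_tokens: 30 },
        },
      },
    ],
  ],
};

// Shape 2: token usage reported under llmOutput.usage (total_tokens may be absent,
// in which case the instrumentation sums input and output tokens).
const outputsWithLlmOutputUsage = {
  llmOutput: {
    usage: { input_tokens: 20, output_tokens: 10 },
  },
};
```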
@@ -478,13 +478,43 @@ function getTokenCount(maybeCount: unknown) {
* Formats the token counts of a langchain run into OpenInference attributes.
* @param outputs - The outputs of a langchain run
* @returns The OpenInference attributes for the token counts
*
* @see https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/language_models/chat_models.ts#L403 for how token counts get added to outputs
*/
function formatTokenCounts(
outputs: Run["outputs"],
): TokenCountAttributes | null {
if (!isObject(outputs)) {
return null;
}
const firstGeneration = getFirstOutputGeneration(outputs);

/**
* Some community models have non-standard output structures and report token counts in different places, notably ChatBedrock.
* @see https://github.com/langchain-ai/langchainjs/blob/a173e300ef9ada416220876a2739e024b3a7f268/libs/langchain-community/src/chat_models/bedrock/web.ts
*/
// Generations is an array of arrays containing messages
const maybeGenerationComponent =
firstGeneration != null ? firstGeneration[0] : null;
const maybeMessage = isObject(maybeGenerationComponent)
? maybeGenerationComponent.message
: null;
const usageMetadata = isObject(maybeMessage)
? maybeMessage.usage_metadata
: null;
if (isObject(usageMetadata)) {
return {
[SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]: getTokenCount(
usageMetadata.output_tokens,
),
[SemanticConventions.LLM_TOKEN_COUNT_PROMPT]: getTokenCount(
usageMetadata.input_tokens,
),
[SemanticConventions.LLM_TOKEN_COUNT_TOTAL]: getTokenCount(
usageMetadata.total_tokens,
),
};
}
const llmOutput = outputs.llmOutput;
if (!isObject(llmOutput)) {
return null;
@@ -519,6 +549,33 @@ function formatTokenCounts(
),
};
}

/**
* In some cases community models have a different output structure due to the way they extend the base model.
* Notably, ChatBedrock may report tokens in this format instead of the normalized one.
* @see https://github.com/langchain-ai/langchainjs/blob/a173e300ef9ada416220876a2739e024b3a7f268/libs/langchain-community/src/chat_models/bedrock/web.ts for ChatBedrock
* and
* @see https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/language_models/chat_models.ts#L403 for normalization
*/
if (isObject(llmOutput.usage)) {
const maybePromptTokens = getTokenCount(llmOutput.usage.input_tokens);
const maybeCompletionTokens = getTokenCount(llmOutput.usage.output_tokens);
let maybeTotalTokens = getTokenCount(llmOutput.usage.total_tokens);
if (maybeTotalTokens == null) {
maybeTotalTokens =
isNumber(maybePromptTokens) && isNumber(maybeCompletionTokens)
? maybePromptTokens + maybeCompletionTokens
: undefined;
}
return {
[SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]: getTokenCount(
maybeCompletionTokens,
),
[SemanticConventions.LLM_TOKEN_COUNT_PROMPT]:
getTokenCount(maybePromptTokens),
[SemanticConventions.LLM_TOKEN_COUNT_TOTAL]: maybeTotalTokens,
};
}
return null;
}
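
As a rough usage sketch (assuming the package-internal helper exercised by the tests below — its import path is not shown in this diff — and the SemanticConventions constants from @arizeai/openinference-semantic-conventions), passing a ChatBedrock-style output through the formatter yields the OpenInference token-count attributes:

```ts
import { SemanticConventions } from "@arizeai/openinference-semantic-conventions";

// Assumed signature for the package-internal helper used in the tests below.
declare function safelyFormatTokenCounts(
  outputs: Record<string, unknown>,
): Record<string, number | undefined> | null;

// ChatBedrock-style output without total_tokens: the total is derived by summing.
const attributes = safelyFormatTokenCounts({
  llmOutput: { usage: { input_tokens: 20, output_tokens: 10 } },
});
// Per the tests below, attributes should equal:
// {
//   [SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]: 10,
//   [SemanticConventions.LLM_TOKEN_COUNT_PROMPT]: 20,
//   [SemanticConventions.LLM_TOKEN_COUNT_TOTAL]: 30,
// }
```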

@@ -589,7 +646,7 @@ function formatMetadata(run: Run) {
return null;
}
return {
metadata: safelyJSONStringify(run.extra.metadata),
[SemanticConventions.METADATA]: safelyJSONStringify(run.extra.metadata),
};
}

@@ -605,6 +605,66 @@ describe("formatTokenCounts", () => {
[SemanticConventions.LLM_TOKEN_COUNT_TOTAL]: 30,
});
});

it("should add token counts from generations message usage metadata if present", () => {
const outputs = {
generations: [
[
{
message: {
usage_metadata: {
input_tokens: 20,
output_tokens: 10,
total_tokens: 30,
},
},
},
],
],
};
const result = safelyFormatTokenCounts(outputs);
expect(result).toEqual({
[SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]: 10,
[SemanticConventions.LLM_TOKEN_COUNT_PROMPT]: 20,
[SemanticConventions.LLM_TOKEN_COUNT_TOTAL]: 30,
});
});

it("should add token counts from llmOutput.usage if present adding together input and output tokens if total tokens is not present", () => {
const outputs = {
llmOutput: {
usage: {
input_tokens: 20,
output_tokens: 10,
},
},
};
const result = safelyFormatTokenCounts(outputs);
expect(result).toEqual({
[SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]: 10,
[SemanticConventions.LLM_TOKEN_COUNT_PROMPT]: 20,
[SemanticConventions.LLM_TOKEN_COUNT_TOTAL]: 30,
});
});

it("should add token counts from llmOutput.usage if present using total_tokens if present", () => {
const outputs = {
llmOutput: {
usage: {
input_tokens: 20,
output_tokens: 10,
// incorrect total tokens to show that this number is prioritized over addition
total_tokens: 35,
},
},
};
const result = safelyFormatTokenCounts(outputs);
expect(result).toEqual({
[SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]: 10,
[SemanticConventions.LLM_TOKEN_COUNT_PROMPT]: 20,
[SemanticConventions.LLM_TOKEN_COUNT_TOTAL]: 35,
});
});
});

describe("formatFunctionCalls", () => {