
feat(js): OpenAI embeddings Instrumentation #34

Merged · 4 commits · Jan 5, 2024
6 changes: 6 additions & 0 deletions js/.changeset/old-apples-attack.md
@@ -0,0 +1,6 @@
---
"@arizeai/openinference-instrumentation-openai": patch
"@arizeai/openinference-semantic-conventions": patch
---

Add OpenAI Embeddings semantic attributes and instrumentation
@@ -2,4 +2,5 @@
module.exports = {
preset: "ts-jest",
testEnvironment: "node",
prettierPath: null,
Contributor Author:
jest doesn't support prettier 3

};
@@ -26,6 +26,10 @@ import {
ChatCompletionCreateParamsBase,
} from "openai/resources/chat/completions";
import { Stream } from "openai/streaming";
import {
CreateEmbeddingResponse,
EmbeddingCreateParams,
} from "openai/resources";

const MODULE_NAME = "openai";

@@ -80,7 +84,7 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
const span = instrumentation.tracer.startSpan(
`OpenAI Chat Completions`,
{
kind: SpanKind.CLIENT,
kind: SpanKind.INTERNAL,
attributes: {
[SemanticConventions.OPENINFERENCE_SPAN_KIND]:
OpenInferenceSpanKind.LLM,
@@ -106,6 +110,11 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
// Push the error to the span
if (error) {
span.recordException(error);
span.setStatus({
code: SpanStatusCode.ERROR,
message: error.message,
});
span.end();
}
},
);
@@ -115,6 +124,12 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
span.setAttributes({
[SemanticConventions.OUTPUT_VALUE]: JSON.stringify(result),
[SemanticConventions.OUTPUT_MIME_TYPE]: MimeType.JSON,
// Override the model from the value sent by the server
[SemanticConventions.LLM_MODEL_NAME]: isChatCompletionResponse(
result,
)
? result.model
: body.model,
...getLLMOutputMessagesAttributes(result),
...getUsageAttributes(result),
});
@@ -127,6 +142,75 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
};
},
);

// Patch embeddings
type EmbeddingsCreateType =
typeof module.OpenAI.Embeddings.prototype.create;
this._wrap(
module.OpenAI.Embeddings.prototype,
"create",
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(original: EmbeddingsCreateType): any => {
return function patchedEmbeddingCreate(
this: unknown,
...args: Parameters<typeof module.OpenAI.Embeddings.prototype.create>
) {
const body = args[0];
const { input } = body;
const isStringInput = typeof input == "string";
const span = instrumentation.tracer.startSpan(`OpenAI Embeddings`, {
kind: SpanKind.INTERNAL,
attributes: {
[SemanticConventions.OPENINFERENCE_SPAN_KIND]:
OpenInferenceSpanKind.EMBEDDING,
[SemanticConventions.EMBEDDING_MODEL_NAME]: body.model,
Contributor:
Probably want to capture model from the response instead, because the one in the request can be part of invocation_parameters. Also, for Azure this is the deployment name, although last I checked, Azure is actually returning the wrong model names in embeddings response...

Contributor Author:
I need to double-check as a follow-up on Azure. It's a bit different in JS; they have a supplementary package. I'll file a follow-up ticket for that.
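As a standalone sketch of the reviewer's point (the helper name here is hypothetical, not part of the PR), the model attribute could prefer the server-reported value and fall back to the request body:

```typescript
// Hypothetical helper illustrating the reviewer's suggestion: prefer the
// model name reported by the server, since the request's `model` may be an
// alias (or, on Azure, a deployment name).
interface HasModel {
  model?: string;
}

function resolveModelName(
  request: HasModel,
  response: HasModel | undefined,
): string | undefined {
  // The response's model is authoritative when present.
  return response?.model ?? request.model;
}
```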

[SemanticConventions.INPUT_VALUE]: isStringInput
? input
: JSON.stringify(input),
[SemanticConventions.INPUT_MIME_TYPE]: isStringInput
? MimeType.TEXT
: MimeType.JSON,
...getEmbeddingTextAttributes(body),
},
});
const execContext = trace.setSpan(context.active(), span);
const execPromise = safeExecuteInTheMiddle<
ReturnType<EmbeddingsCreateType>
>(
() => {
return context.with(execContext, () => {
return original.apply(this, args);
});
},
(error) => {
// Push the error to the span
if (error) {
span.recordException(error);
Contributor:
Recording exceptions doesn't actually update the status, because a span can record multiple exceptions such as retry errors and still turn out OK.

Suggested change:
span.recordException(error);
span.setStatus({ code: SpanStatusCode.ERROR });
span.end();
Contributor Author:
True, I was thinking of leaving it minimal - I'll at least set the status code. Might leave off the span end for now; not sure if that's necessary. Will look into it a bit more.

Contributor Author:
Oh actually I can put it below. I'll do that.

span.setStatus({
code: SpanStatusCode.ERROR,
message: error.message,
});
span.end();
}
},
);
const wrappedPromise = execPromise.then((result) => {
if (result) {
// Record the results
span.setAttributes({
// Do not record the output data as it can be large
...getEmbeddingEmbeddingsAttributes(result),
});
}
span.setStatus({ code: SpanStatusCode.OK });
span.end();
return result;
});
return context.bind(execContext, wrappedPromise);
};
},
);

module.openInferencePatched = true;
return module;
}
@@ -136,9 +220,19 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
private unpatch(moduleExports: typeof openai, moduleVersion?: string) {
diag.debug(`Removing patch for ${MODULE_NAME}@${moduleVersion}`);
this._unwrap(moduleExports.OpenAI.Chat.Completions.prototype, "create");
this._unwrap(moduleExports.OpenAI.Embeddings.prototype, "create");
}
}

/**
* type-guard that checks if the response is a chat completion response
*/
function isChatCompletionResponse(
response: Stream<ChatCompletionChunk> | ChatCompletion,
): response is ChatCompletion {
return "choices" in response;
}
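A minimal standalone sketch of how this kind of guard narrows the union, using simplified stand-ins for the OpenAI SDK's `Stream` and `ChatCompletion` types (the real types are richer):

```typescript
// Simplified stand-ins for the OpenAI SDK types, just to show the narrowing.
type ChatCompletionLike = {
  choices: Array<{ message: { content: string } }>;
  model: string;
};
type StreamLike = { [Symbol.asyncIterator](): AsyncIterator<unknown> };

function isChatCompletionLike(
  response: StreamLike | ChatCompletionLike,
): response is ChatCompletionLike {
  // A streaming response is an async iterable; only the non-streaming
  // completion object carries a `choices` array, so its presence is enough.
  return "choices" in response;
}
```

Inside an `if (isChatCompletionLike(result))` branch, TypeScript then lets you read `result.model` without a cast.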

/**
* Converts the body of the request to LLM input messages
*/
@@ -204,3 +298,43 @@ function getLLMOutputMessagesAttributes(
}
return {};
}

/**
* Converts the embedding result payload to embedding attributes
*/
function getEmbeddingTextAttributes(
request: EmbeddingCreateParams,
): Attributes {
if (typeof request.input == "string") {
return {
[`${SemanticConventions.EMBEDDING_EMBEDDINGS}.0.${SemanticConventions.EMBEDDING_TEXT}`]:
request.input,
};
} else if (
Array.isArray(request.input) &&
request.input.length > 0 &&
typeof request.input[0] == "string"
) {
return request.input.reduce((acc, input, index) => {
const index_prefix = `${SemanticConventions.EMBEDDING_EMBEDDINGS}.${index}`;
acc[`${index_prefix}.${SemanticConventions.EMBEDDING_TEXT}`] = input;
return acc;
}, {} as Attributes);
}
// Ignore other cases where input is a number or an array of numbers
return {};
}
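The flattening above can be sketched standalone; the convention strings are hard-coded here instead of imported from the semantic-conventions package, and string inputs are treated as a one-element list:

```typescript
// Standalone sketch of the text-attribute flattening; the attribute key
// fragments mirror the semantic-convention values used in the PR's tests.
const EMBEDDING_EMBEDDINGS = "embedding.embeddings";
const EMBEDDING_TEXT = "embedding.text";

function embeddingTextAttributes(
  input: string | string[],
): Record<string, string> {
  // Normalize a single string to a one-element list, then index each text.
  const texts = typeof input === "string" ? [input] : input;
  return texts.reduce(
    (acc, text, index) => {
      acc[`${EMBEDDING_EMBEDDINGS}.${index}.${EMBEDDING_TEXT}`] = text;
      return acc;
    },
    {} as Record<string, string>,
  );
}
```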

/**
* Converts the embedding result payload to embedding attributes
*/
function getEmbeddingEmbeddingsAttributes(
response: CreateEmbeddingResponse,
): Attributes {
return response.data.reduce((acc, embedding, index) => {
const index_prefix = `${SemanticConventions.EMBEDDING_EMBEDDINGS}.${index}`;
acc[`${index_prefix}.${SemanticConventions.EMBEDDING_VECTOR}`] =
embedding.embedding;
return acc;
}, {} as Attributes);
}
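The response-side flattening follows the same pattern; a standalone sketch with the response shape reduced to the fields actually used:

```typescript
// Standalone sketch of flattening response vectors into span attributes;
// `EmbeddingDatum` is a pared-down stand-in for the SDK's embedding object.
const EMBEDDINGS_PREFIX = "embedding.embeddings";
const VECTOR_SUFFIX = "embedding.vector";

interface EmbeddingDatum {
  index: number;
  embedding: number[];
}

function embeddingVectorAttributes(
  data: EmbeddingDatum[],
): Record<string, number[]> {
  return data.reduce(
    (acc, datum, index) => {
      acc[`${EMBEDDINGS_PREFIX}.${index}.${VECTOR_SUFFIX}`] = datum.embedding;
      return acc;
    },
    {} as Record<string, number[]>,
  );
}
```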
@@ -21,18 +21,22 @@ describe("OpenAIInstrumentation", () => {

instrumentation.setTracerProvider(tracerProvider);
tracerProvider.addSpanProcessor(new SimpleSpanProcessor(memoryExporter));
// @ts-expect-error the moduleExports property is private. This is needed to make the test work with auto-mocking
instrumentation._modules[0].moduleExports = OpenAI;

beforeEach(() => {
// @ts-expect-error the moduleExports property is private. This is needed to make the test work with auto-mocking
instrumentation._modules[0].moduleExports = OpenAI;
beforeAll(() => {
instrumentation.enable();
openai = new OpenAI.OpenAI({
apiKey: `fake-api-key`,
});
});
afterAll(() => {
instrumentation.disable();
});
beforeEach(() => {
memoryExporter.reset();
});
afterEach(() => {
instrumentation.disable();
jest.clearAllMocks();
});
it("is patched", () => {
@@ -85,7 +89,7 @@ describe("OpenAIInstrumentation", () => {
"llm.input_messages.0.message.content": "Say this is a test",
"llm.input_messages.0.message.role": "user",
"llm.invocation_parameters": "{"model":"gpt-3.5-turbo"}",
"llm.model_name": "gpt-3.5-turbo",
"llm.model_name": "gpt-3.5-turbo-0613",
"llm.output_messages.0.message.content": "This is a test.",
"llm.output_messages.0.message.role": "assistant",
"llm.token_count.completion": 5,
@@ -97,4 +101,39 @@
}
`);
});
it("creates a span for embedding create", async () => {
const response = {
object: "list",
data: [{ object: "embedding", index: 0, embedding: [1, 2, 3] }],
};
// Mock out the embedding create endpoint
jest.spyOn(openai, "post").mockImplementation(
// @ts-expect-error the response type is not correct - this is just for testing
async (): Promise<unknown> => {
return response;
},
);
await openai.embeddings.create({
input: "A happy moment",
model: "text-embedding-ada-002",
});
const spans = memoryExporter.getFinishedSpans();
expect(spans.length).toBe(1);
const span = spans[0];
expect(span.name).toBe("OpenAI Embeddings");
expect(span.attributes).toMatchInlineSnapshot(`
{
"embedding.embeddings.0.embedding.text": "A happy moment",
"embedding.embeddings.0.embedding.vector": [
1,
2,
3,
],
"embedding.model_name": "text-embedding-ada-002",
"input.mime_type": "text/plain",
"input.value": "A happy moment",
"openinference.span.kind": "embedding",
}
`);
});
});
@@ -210,6 +210,12 @@ export const EMBEDDING_MODEL_NAME =
export const EMBEDDING_VECTOR =
`${SemanticAttributePrefixes.embedding}.${EmbeddingAttributePostfixes.vector}` as const;

/**
* The embedding list root
*/
export const EMBEDDING_EMBEDDINGS =
`${SemanticAttributePrefixes.embedding}.${EmbeddingAttributePostfixes.embeddings}` as const;
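The `as const` template literals compose into plain dotted strings; a standalone sketch, with the prefix/postfix values mirroring those the test snapshots imply:

```typescript
// Standalone sketch of how the semantic-convention constants compose; the
// prefix/postfix values here are inferred from the flattened keys seen in
// the PR's snapshot tests (e.g. "embedding.embeddings.0.embedding.vector").
const SemanticAttributePrefixes = { embedding: "embedding" } as const;
const EmbeddingAttributePostfixes = {
  embeddings: "embeddings",
  vector: "vector",
} as const;

const EMBEDDING_EMBEDDINGS =
  `${SemanticAttributePrefixes.embedding}.${EmbeddingAttributePostfixes.embeddings}` as const;
const EMBEDDING_VECTOR =
  `${SemanticAttributePrefixes.embedding}.${EmbeddingAttributePostfixes.vector}` as const;

// A concrete flattened key for the first embedding's vector:
const exampleKey = `${EMBEDDING_EMBEDDINGS}.0.${EMBEDDING_VECTOR}`;
```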

export const SemanticConventions = {
INPUT_VALUE,
INPUT_MIME_TYPE,
@@ -234,6 +240,7 @@ export const SemanticConventions = {
DOCUMENT_CONTENT,
DOCUMENT_SCORE,
DOCUMENT_METADATA,
EMBEDDING_EMBEDDINGS,
EMBEDDING_TEXT,
EMBEDDING_MODEL_NAME,
EMBEDDING_VECTOR,