From 5dfa9743e05ce001dd1238a677e2c0a3d5530d70 Mon Sep 17 00:00:00 2001 From: Chris Park Date: Thu, 8 Aug 2024 17:02:07 -0700 Subject: [PATCH] feat(js): llama-index-ts embeddings support (#761) Co-authored-by: Parker Stafford Co-authored-by: Mikyo King --- .../package.json | 4 +- .../src/instrumentation.ts | 47 +++- .../src/types.ts | 62 +---- .../src/utils.ts | 247 +++++++++++++----- .../test/llamaIndex.test.ts | 226 ++++++++++++---- .../openinference-vercel/tsconfig.json | 2 +- 6 files changed, 408 insertions(+), 180 deletions(-) diff --git a/js/packages/openinference-instrumentation-llama-index/package.json b/js/packages/openinference-instrumentation-llama-index/package.json index 3ef330a8b..abdefcf54 100644 --- a/js/packages/openinference-instrumentation-llama-index/package.json +++ b/js/packages/openinference-instrumentation-llama-index/package.json @@ -24,6 +24,8 @@ "@opentelemetry/instrumentation": "^0.46.0" }, "devDependencies": { - "llamaindex": "^0.3.14" + "jest": "^29.7.0", + "llamaindex": "^0.3.14", + "openai": "^4.24.1" } } diff --git a/js/packages/openinference-instrumentation-llama-index/src/instrumentation.ts b/js/packages/openinference-instrumentation-llama-index/src/instrumentation.ts index af73f1541..d2f62e452 100644 --- a/js/packages/openinference-instrumentation-llama-index/src/instrumentation.ts +++ b/js/packages/openinference-instrumentation-llama-index/src/instrumentation.ts @@ -7,7 +7,13 @@ import { InstrumentationNodeModuleDefinition, } from "@opentelemetry/instrumentation"; import { diag } from "@opentelemetry/api"; -import { patchQueryMethod, patchRetrieveMethod } from "./utils"; +import { + patchQueryEngineQueryMethod, + patchRetrieveMethod, + patchQueryEmbeddingMethod, + isRetrieverPrototype, + isEmbeddingPrototype, +} from "./utils"; import { VERSION } from "./version"; const MODULE_NAME = "llamaindex"; @@ -58,23 +64,32 @@ export class LlamaIndexInstrumentation extends InstrumentationBase< } // TODO: Support streaming + // TODO: Generalize to QueryEngine interface (RetrieverQueryEngine, RouterQueryEngine) this._wrap( moduleExports.RetrieverQueryEngine.prototype, "query", // eslint-disable-next-line @typescript-eslint/no-explicit-any (original): any => { - return patchQueryMethod(original, moduleExports, this.tracer); + return patchQueryEngineQueryMethod(original, this.tracer); }, ); - this._wrap( - moduleExports.VectorIndexRetriever.prototype, - "retrieve", - (original) => { - return patchRetrieveMethod(original, moduleExports, this.tracer); - }, - ); + for (const value of Object.values(moduleExports)) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const prototype = (value as any).prototype; + + if (isRetrieverPrototype(prototype)) { + this._wrap(prototype, "retrieve", (original) => { + return patchRetrieveMethod(original, this.tracer); + }); + } + if (isEmbeddingPrototype(prototype)) { + this._wrap(prototype, "getQueryEmbedding", (original) => { + return patchQueryEmbeddingMethod(original, this.tracer); + }); + } + } _isOpenInferencePatched = true; return moduleExports; } @@ -82,7 +97,19 @@ export class LlamaIndexInstrumentation extends InstrumentationBase< private unpatch(moduleExports: typeof llamaindex, moduleVersion?: string) { this._diag.debug(`Un-patching ${MODULE_NAME}@${moduleVersion}`); this._unwrap(moduleExports.RetrieverQueryEngine.prototype, "query"); - this._unwrap(moduleExports.VectorIndexRetriever.prototype, "retrieve"); + + for (const value of Object.values(moduleExports)) { + // eslint-disable-next-line 
@typescript-eslint/no-explicit-any
+      const prototype = (value as any).prototype;
+
+      if (isRetrieverPrototype(prototype)) {
+        this._unwrap(prototype, "retrieve");
+      }
+
+      if (isEmbeddingPrototype(prototype)) {
+        this._unwrap(prototype, "getQueryEmbedding");
+      }
+    }
 
     _isOpenInferencePatched = false;
   }
diff --git a/js/packages/openinference-instrumentation-llama-index/src/types.ts b/js/packages/openinference-instrumentation-llama-index/src/types.ts
index faafbff0a..3ba845ad0 100644
--- a/js/packages/openinference-instrumentation-llama-index/src/types.ts
+++ b/js/packages/openinference-instrumentation-llama-index/src/types.ts
@@ -1,39 +1,5 @@
-import { SemanticConventions } from "@arizeai/openinference-semantic-conventions";
-
-export type RetrievalDocument = {
-  [SemanticConventions.DOCUMENT_ID]?: string;
-  [SemanticConventions.DOCUMENT_CONTENT]?: string;
-  [SemanticConventions.DOCUMENT_SCORE]?: number | undefined;
-  [SemanticConventions.DOCUMENT_METADATA]?: string;
-};
-
-type LLMMessageToolCall = {
-  [SemanticConventions.TOOL_CALL_FUNCTION_NAME]?: string;
-  [SemanticConventions.TOOL_CALL_FUNCTION_ARGUMENTS_JSON]?: string;
-};
-
-export type LLMMessageToolCalls = {
-  [SemanticConventions.MESSAGE_TOOL_CALLS]?: LLMMessageToolCall[];
-};
-
-export type LLMMessageFunctionCall = {
-  [SemanticConventions.MESSAGE_FUNCTION_CALL_NAME]?: string;
-  [SemanticConventions.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON]?: string;
-};
-
-export type LLMMessage = LLMMessageToolCalls &
-  LLMMessageFunctionCall & {
-    [SemanticConventions.MESSAGE_ROLE]?: string;
-    [SemanticConventions.MESSAGE_CONTENT]?: string;
-  };
-
-export type LLMMessagesAttributes =
-  | {
-      [SemanticConventions.LLM_INPUT_MESSAGES]: LLMMessage[];
-    }
-  | {
-      [SemanticConventions.LLM_OUTPUT_MESSAGES]: LLMMessage[];
-    };
+import * as llamaindex from "llamaindex";
+import { BaseRetriever } from "llamaindex";
 
 // eslint-disable-next-line @typescript-eslint/no-explicit-any
 export type GenericFunction = (...args: any[]) => any;
@@ -42,22 +8,12 @@ export type GenericFunction = (...args: any[]) => any;
 export type SafeFunction<T extends GenericFunction> = (
   ...args: Parameters<T>
 ) => ReturnType<T> | null;
 
-export type LLMParameterAttributes = {
-  [SemanticConventions.LLM_MODEL_NAME]?: string;
-  [SemanticConventions.LLM_INVOCATION_PARAMETERS]?: string;
-};
+export type ObjectWithModel = { model: string };
+
+export type RetrieverQueryEngineQueryMethodType =
+  typeof llamaindex.RetrieverQueryEngine.prototype.query;
 
-export type PromptTemplateAttributes = {
-  [SemanticConventions.PROMPT_TEMPLATE_TEMPLATE]?: string;
-  [SemanticConventions.PROMPT_TEMPLATE_VARIABLES]?: string;
-};
+export type RetrieverRetrieveMethodType = BaseRetriever["retrieve"];
 
-export type TokenCountAttributes = {
-  [SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]?: number;
-  [SemanticConventions.LLM_TOKEN_COUNT_PROMPT]?: number;
-  [SemanticConventions.LLM_TOKEN_COUNT_TOTAL]?: number;
-};
-
-export type ToolAttributes = {
-  [SemanticConventions.TOOL_NAME]?: string;
-  [SemanticConventions.TOOL_DESCRIPTION]?: string;
-};
+export type QueryEmbeddingMethodType =
+  typeof llamaindex.BaseEmbedding.prototype.getQueryEmbedding;
diff --git a/js/packages/openinference-instrumentation-llama-index/src/utils.ts b/js/packages/openinference-instrumentation-llama-index/src/utils.ts
index b35a79e1a..50c1d2544 100644
--- a/js/packages/openinference-instrumentation-llama-index/src/utils.ts
+++ b/js/packages/openinference-instrumentation-llama-index/src/utils.ts
@@ -1,6 +1,5 @@
-import type * as llamaindex from "llamaindex";
+import * as llamaindex from "llamaindex";
 
-import { TextNode } from 
"llamaindex"; import { safeExecuteInTheMiddle } from "@opentelemetry/instrumentation"; import { Attributes, @@ -18,12 +17,20 @@ import { OpenInferenceSpanKind, SemanticConventions, } from "@arizeai/openinference-semantic-conventions"; -import { GenericFunction, SafeFunction } from "./types"; +import { + GenericFunction, + SafeFunction, + ObjectWithModel, + RetrieverQueryEngineQueryMethodType, + RetrieverRetrieveMethodType, + QueryEmbeddingMethodType, +} from "./types"; +import { BaseEmbedding, BaseRetriever } from "llamaindex"; /** * Wraps a function with a try-catch block to catch and log any errors. - * @param fn - A function to wrap with a try-catch block. - * @returns A function that returns null if an error is thrown. + * @param {T} fn - A function to wrap with a try-catch block. + * @returns {SafeFunction} A function that returns null if an error is thrown. */ export function withSafety(fn: T): SafeFunction { return (...args) => { @@ -35,13 +42,13 @@ export function withSafety(fn: T): SafeFunction { } }; } - const safelyJSONStringify = withSafety(JSON.stringify); /** - * Resolves the execution context for the current span - * If tracing is suppressed, the span is dropped and the current context is returned - * @param span + * Resolves the execution context for the current span. + * If tracing is suppressed, the span is dropped and the current context is returned. + * @param {Span} span - Tracer span + * @returns {Context} An execution context. */ function getExecContext(span: Span) { const activeContext = context.active(); @@ -56,14 +63,129 @@ function getExecContext(span: Span) { return execContext; } -export function patchQueryMethod( - original: typeof module.RetrieverQueryEngine.prototype.query, - module: typeof llamaindex, +/** + * If execution results in an error, push the error to the span. + * @param span + * @param error + */ +function handleError(span: Span, error: Error | undefined) { + if (error) { + span.recordException(error); + span.setStatus({ + code: SpanStatusCode.ERROR, + message: error.message, + }); + span.end(); + } +} + +/** + * Checks whether the provided prototype is an instance of a `BaseRetriever`. + * + * @param {unknown} proto - The prototype to check. + * @returns {boolean} Whether the prototype is a `BaseRetriever`. + */ +export function isRetrieverPrototype(proto: unknown): proto is BaseRetriever { + return ( + proto != null && + typeof proto === "object" && + "retrieve" in proto && + typeof proto.retrieve === "function" + ); +} + +/** + * Checks whether the provided prototype is an instance of a `BaseEmbedding`. + * + * @param {unknown} proto - The prototype to check. + * @returns {boolean} Whether the prototype is a `BaseEmbedding`. + */ +export function isEmbeddingPrototype(proto: unknown): proto is BaseEmbedding { + return proto != null && proto instanceof BaseEmbedding; +} + +/** + * Extracts document attributes from an array of nodes with scores and returns extracted + * attributes in an Attributes object. + * + * @param {llamaindex.NodeWithScore[]} output - Array of nodes. + * @returns {Attributes} The extracted document attributes. 
+ */ +function getDocumentAttributes( + output: llamaindex.NodeWithScore[], +) { + const docs: Attributes = {}; + output.forEach(({ node, score }, idx) => { + if (node instanceof llamaindex.TextNode) { + const prefix = `${SemanticConventions.RETRIEVAL_DOCUMENTS}.${idx}`; + docs[`${prefix}.${SemanticConventions.DOCUMENT_ID}`] = node.id_; + docs[`${prefix}.${SemanticConventions.DOCUMENT_SCORE}`] = score; + docs[`${prefix}.${SemanticConventions.DOCUMENT_CONTENT}`] = + node.getContent(); + docs[`${prefix}.${SemanticConventions.DOCUMENT_METADATA}`] = + safelyJSONStringify(node.metadata) ?? undefined; + } + }); + return docs; +} + +/** + * Extracts embedding information (input text and the output embedding vector), + * and constructs an Attributes object with the relevant semantic conventions + * for embeddings. + * + * @param {Object} embedInfo - The embedding information. + * @param {string} embedInfo.input - The input text for the embedding. + * @param {number[]} embedInfo.output - The output embedding vector. + * @returns {Attributes} The constructed embedding attributes. + */ +function getQueryEmbeddingAttributes(embedInfo: { + input: string; + output: number[]; +}): Attributes { + return { + [`${SemanticConventions.EMBEDDING_EMBEDDINGS}.0.${SemanticConventions.EMBEDDING_TEXT}`]: + embedInfo.input, + [`${SemanticConventions.EMBEDDING_EMBEDDINGS}.0.${SemanticConventions.EMBEDDING_VECTOR}`]: + embedInfo.output, + }; +} + +/** + * Checks if the provided class has a `model` property of type string + * as a class property. + * + * @param {unknown} cls - The class to check. + * @returns {boolean} Whether the object has a `model` property. + */ +function hasModelProperty(cls: unknown): cls is ObjectWithModel { + const objectWithModelMaybe = cls as ObjectWithModel; + return ( + "model" in objectWithModelMaybe && + typeof objectWithModelMaybe.model === "string" + ); +} + +/** + * Retrieves the value of the `model` property if the provided class + * implements it; otherwise, returns undefined. + * + * @param {unknown} cls - The class to retrieve the model name from. + * @returns {string | undefined} The model name or undefined. 
+ */
+function getModelName(cls: unknown) {
+  if (hasModelProperty(cls)) {
+    return cls.model;
+  }
+}
+
+export function patchQueryEngineQueryMethod(
+  original: RetrieverQueryEngineQueryMethodType,
   tracer: Tracer,
 ) {
-  return function patchedQuery(
+  return function (
     this: unknown,
-    ...args: Parameters<typeof module.RetrieverQueryEngine.prototype.query>
+    ...args: Parameters<RetrieverQueryEngineQueryMethodType>
   ) {
     const span = tracer.startSpan(`query`, {
       kind: SpanKind.INTERNAL,
@@ -78,24 +200,14 @@ export function patchQueryMethod(
     const execContext = getExecContext(span);
 
     const execPromise = safeExecuteInTheMiddle<
-      ReturnType<typeof module.RetrieverQueryEngine.prototype.query>
+      ReturnType<RetrieverQueryEngineQueryMethodType>
     >(
       () => {
         return context.with(execContext, () => {
           return original.apply(this, args);
         });
       },
-      (error) => {
-        // Push the error to the span
-        if (error) {
-          span.recordException(error);
-          span.setStatus({
-            code: SpanStatusCode.ERROR,
-            message: error.message,
-          });
-          span.end();
-        }
-      },
+      (error) => handleError(span, error),
     );
 
     const wrappedPromise = execPromise.then((result) => {
@@ -111,13 +223,12 @@ export function patchQueryMethod(
 }
 
 export function patchRetrieveMethod(
-  original: typeof module.VectorIndexRetriever.prototype.retrieve,
-  module: typeof llamaindex,
+  original: RetrieverRetrieveMethodType,
   tracer: Tracer,
 ) {
-  return function patchedRetrieve(
+  return function (
     this: unknown,
-    ...args: Parameters<typeof module.VectorIndexRetriever.prototype.retrieve>
+    ...args: Parameters<RetrieverRetrieveMethodType>
   ) {
     const span = tracer.startSpan(`retrieve`, {
       kind: SpanKind.INTERNAL,
@@ -132,28 +243,18 @@ export function patchRetrieveMethod(
     const execContext = getExecContext(span);
 
     const execPromise = safeExecuteInTheMiddle<
-      ReturnType<typeof module.VectorIndexRetriever.prototype.retrieve>
+      ReturnType<RetrieverRetrieveMethodType>
     >(
       () => {
        return context.with(execContext, () => {
          return original.apply(this, args);
        });
      },
-      (error) => {
-        // Push the error to the span
-        if (error) {
-          span.recordException(error);
-          span.setStatus({
-            code: SpanStatusCode.ERROR,
-            message: error.message,
-          });
-          span.end();
-        }
-      },
+      (error) => handleError(span, error),
     );
 
     const wrappedPromise = execPromise.then((result) => {
-      span.setAttributes(documentAttributes(result));
+      span.setAttributes(getDocumentAttributes(result));
       span.end();
       return result;
     });
@@ -161,25 +262,49 @@ export function patchRetrieveMethod(
   };
 }
 
-function documentAttributes(
-  output: llamaindex.NodeWithScore[],
+export function patchQueryEmbeddingMethod(
+  original: QueryEmbeddingMethodType,
+  tracer: Tracer,
 ) {
-  const docs: Attributes = {};
-  output.forEach((document, index) => {
-    const { node, score } = document;
+  return function (
+    this: unknown,
+    ...args: Parameters<QueryEmbeddingMethodType>
+  ) {
+    const span = tracer.startSpan(`embedding`, {
+      kind: SpanKind.INTERNAL,
+      attributes: {
+        [SemanticConventions.OPENINFERENCE_SPAN_KIND]:
+          OpenInferenceSpanKind.EMBEDDING,
+      },
+    });
 
-    if (node instanceof TextNode) {
-      const nodeId = node.id_;
-      const nodeText = node.getContent();
-      const nodeMetadata = node.metadata;
+    const execContext = getExecContext(span);
 
-      const prefix = `${SemanticConventions.RETRIEVAL_DOCUMENTS}.${index}`;
-      docs[`${prefix}.${SemanticConventions.DOCUMENT_ID}`] = nodeId;
-      docs[`${prefix}.${SemanticConventions.DOCUMENT_SCORE}`] = score;
-      docs[`${prefix}.${SemanticConventions.DOCUMENT_CONTENT}`] = nodeText;
-      docs[`${prefix}.${SemanticConventions.DOCUMENT_METADATA}`] =
-        safelyJSONStringify(nodeMetadata) ?? 
undefined;
-    }
-  });
-  return docs;
+    const execPromise = safeExecuteInTheMiddle<
+      ReturnType<QueryEmbeddingMethodType>
+    >(
+      () => {
+        return context.with(execContext, () => {
+          return original.apply(this, args);
+        });
+      },
+      (error) => handleError(span, error),
+    );
+
+    // Model ID/name is a property found on the class and not in args
+    // Extract it from the class and set it as a span attribute
+    span.setAttributes({
+      [SemanticConventions.EMBEDDING_MODEL_NAME]: getModelName(this),
+    });
+
+    const wrappedPromise = execPromise.then((result) => {
+      const [query] = args;
+      span.setAttributes(
+        getQueryEmbeddingAttributes({ input: query, output: result }),
+      );
+      span.end();
+      return result;
+    });
+    return context.bind(execContext, wrappedPromise);
+  };
 }
diff --git a/js/packages/openinference-instrumentation-llama-index/test/llamaIndex.test.ts b/js/packages/openinference-instrumentation-llama-index/test/llamaIndex.test.ts
index e4c3ff165..e07e96ed2 100644
--- a/js/packages/openinference-instrumentation-llama-index/test/llamaIndex.test.ts
+++ b/js/packages/openinference-instrumentation-llama-index/test/llamaIndex.test.ts
@@ -1,28 +1,175 @@
+import * as llamaindex from "llamaindex";
+
 import {
   InMemorySpanExporter,
   SimpleSpanProcessor,
 } from "@opentelemetry/sdk-trace-base";
 import { LlamaIndexInstrumentation, isPatched } from "../src/index";
 import { NodeTracerProvider } from "@opentelemetry/sdk-trace-node";
-import * as llamaindex from "llamaindex";
 import {
   SemanticConventions,
   OpenInferenceSpanKind,
   RETRIEVAL_DOCUMENTS,
 } from "@arizeai/openinference-semantic-conventions";
+import {
+  Document,
+  VectorStoreIndex,
+  GeminiEmbedding,
+  HuggingFaceEmbedding,
+  MistralAIEmbedding,
+  OllamaEmbedding,
+  OpenAIEmbedding,
+  RetrieverQueryEngine,
+} from "llamaindex";
+import { isEmbeddingPrototype, isRetrieverPrototype } from "../src/utils";
+import { OpenAI } from "openai";
 
-const { Document, VectorStoreIndex } = llamaindex;
-
+// Mocked return values
 const DUMMY_RESPONSE = "lorem ipsum";
+const FAKE_EMBEDDING = Array(1536).fill(0); // Mock out the embeddings response to size 1536 (ada-2)
+const RESPONSE_MESSAGE = {
+  content: DUMMY_RESPONSE,
+  role: "assistant",
+};
+
+const EMBEDDING_RESPONSE = {
+  object: "list",
+  data: [
+    { object: "embedding", index: 0, embedding: FAKE_EMBEDDING },
+    { object: "embedding", index: 1, embedding: FAKE_EMBEDDING },
+    { object: "embedding", index: 2, embedding: FAKE_EMBEDDING },
+  ],
+};
+
+// Mock out the chat completions endpoint
+const CHAT_RESPONSE = {
+  id: "chatcmpl-8adq9JloOzNZ9TyuzrKyLpGXexh6p",
+  object: "chat.completion",
+  created: 1703743645,
+  model: "gpt-3.5-turbo-0613",
+  choices: [
+    {
+      index: 0,
+      message: RESPONSE_MESSAGE,
+      logprobs: null,
+      finish_reason: "stop",
+    },
+  ],
+  usage: {
+    prompt_tokens: 12,
+    completion_tokens: 5,
+    total_tokens: 17,
+  },
+};
 
 const tracerProvider = new NodeTracerProvider();
 tracerProvider.register();
 
 const instrumentation = new LlamaIndexInstrumentation();
 instrumentation.disable();
-describe("llamaIndex", () => {
-  it("should pass", () => {
-    expect(true).toBe(true);
+
+/*
+ * Tests for embeddings patching
+ */
+describe("LlamaIndexInstrumentation - Embeddings", () => {
+  const memoryExporter = new InMemorySpanExporter();
+  const spanProcessor = new SimpleSpanProcessor(memoryExporter);
+  instrumentation.setTracerProvider(tracerProvider);
+
+  tracerProvider.addSpanProcessor(spanProcessor);
+  // @ts-expect-error the moduleExports property is private. 
This is needed to make the test work with auto-mocking
+  instrumentation._modules[0].moduleExports = llamaindex;
+
+  beforeAll(() => {
+    instrumentation.enable();
+    process.env["OPENAI_API_KEY"] = "fake-api-key";
+    process.env["MISTRAL_API_KEY"] = "fake-api-key";
+    process.env["GOOGLE_API_KEY"] = "fake-api-key";
+  });
+
+  afterAll(() => {
+    instrumentation.disable();
+    delete process.env["OPENAI_API_KEY"];
+    delete process.env["MISTRAL_API_KEY"];
+    delete process.env["GOOGLE_API_KEY"];
+  });
+  beforeEach(() => {
+    jest
+      .spyOn(OpenAI.Embeddings.prototype, "create")
+      // @ts-expect-error the response type is not correct - this is just for testing
+      .mockImplementation(async (): Promise<unknown> => {
+        return EMBEDDING_RESPONSE;
+      });
+
+    memoryExporter.reset();
+  });
+  afterEach(() => {
+    jest.clearAllMocks();
+    jest.restoreAllMocks();
+    jest.resetModules();
+  });
+
+  it("is patched", () => {
+    expect(isPatched()).toBe(true);
+  });
+
+  it("isRetrieverPrototype should identify retriever prototypes correctly", async () => {
+    // Expect all retriever prototypes to be identified as a retriever
+    expect(isRetrieverPrototype(RetrieverQueryEngine.prototype)).toEqual(true);
+
+    // Expect a non-retriever object to be identified as such
+    expect(isRetrieverPrototype(HuggingFaceEmbedding.prototype)).toEqual(false);
+  });
+
+  it("isEmbeddingPrototype should identify embedder prototypes correctly", async () => {
+    // Expect all embedders to be identified as embeddings
+    expect(isEmbeddingPrototype(HuggingFaceEmbedding.prototype)).toEqual(true);
+    expect(isEmbeddingPrototype(GeminiEmbedding.prototype)).toEqual(true);
+    expect(isEmbeddingPrototype(MistralAIEmbedding.prototype)).toEqual(true);
+    expect(isEmbeddingPrototype(OpenAIEmbedding.prototype)).toEqual(true);
+    expect(isEmbeddingPrototype(OllamaEmbedding.prototype)).toEqual(true);
+
+    // Expect a non-embedding object to be identified as such
+    expect(isEmbeddingPrototype({})).toEqual(false);
+    expect(isEmbeddingPrototype(null)).toEqual(false);
+    expect(isEmbeddingPrototype(undefined)).toEqual(false);
+  });
+
+  it("should create a span for embeddings (query)", async () => {
+    // Get embeddings
+    const embedder = new OpenAIEmbedding();
+    const embeddedVector = await embedder.getQueryEmbedding(
+      "What did the author do in college?",
+    );
+
+    const spans = memoryExporter.getFinishedSpans();
+    expect(spans.length).toBeGreaterThan(0);
+
+    // Expect a span for the embedding
+    const queryEmbeddingSpan = spans.find((span) =>
+      span.name.includes("embedding"),
+    );
+    expect(queryEmbeddingSpan).toBeDefined();
+
+    // Verify span attributes
+    expect(
+      queryEmbeddingSpan?.attributes[
+        SemanticConventions.OPENINFERENCE_SPAN_KIND
+      ],
+    ).toEqual(OpenInferenceSpanKind.EMBEDDING);
+    expect(
+      queryEmbeddingSpan?.attributes[
+        `${SemanticConventions.EMBEDDING_EMBEDDINGS}.0.${SemanticConventions.EMBEDDING_TEXT}`
+      ],
+    ).toEqual("What did the author do in college?");
+    expect(
+      queryEmbeddingSpan?.attributes[
+        `${SemanticConventions.EMBEDDING_EMBEDDINGS}.0.${SemanticConventions.EMBEDDING_VECTOR}`
+      ],
+    ).toEqual(embeddedVector);
+    expect(
+      queryEmbeddingSpan?.attributes[SemanticConventions.EMBEDDING_MODEL_NAME],
+    ).toEqual("text-embedding-ada-002");
   });
 });
@@ -35,9 +182,6 @@ describe("LlamaIndexInstrumentation", () => {
   // @ts-expect-error the moduleExports property is private. 
This is needed to make the test work with auto-mocking
   instrumentation._modules[0].moduleExports = llamaindex;
 
-  let openAISpy: jest.SpyInstance;
-  let openAITextEmbedSpy: jest.SpyInstance;
-  let openAIQueryEmbedSpy: jest.SpyInstance;
 
   beforeAll(() => {
     instrumentation.enable();
@@ -52,46 +196,29 @@ describe("LlamaIndexInstrumentation", () => {
   beforeEach(() => {
     memoryExporter.reset();
 
-    // Use OpenAI and mock out the calls
-    const response: llamaindex.ChatResponse =
-      {
-        message: {
-          content: DUMMY_RESPONSE,
-          role: "assistant",
-        },
-        raw: null,
-      };
-    // Mock out the chat completions endpoint
-    openAISpy = jest
-      .spyOn(llamaindex.OpenAI.prototype, "chat")
-      .mockImplementation(() => {
-        return Promise.resolve(response);
-      });
-
-    // Mock out the embeddings response to size 1536 (ada-2)
-    const fakeEmbedding = Array(1536).fill(0);
-    // Mock out the embeddings endpoint
-    openAITextEmbedSpy = jest
-      .spyOn(llamaindex.OpenAIEmbedding.prototype, "getTextEmbeddings")
-      .mockImplementation(() => {
-        return Promise.resolve([fakeEmbedding, fakeEmbedding, fakeEmbedding]);
-      });
-
-    openAIQueryEmbedSpy = jest
-      .spyOn(llamaindex.OpenAIEmbedding.prototype, "getQueryEmbedding")
-      .mockImplementation(() => {
-        return Promise.resolve(fakeEmbedding);
-      });
+    jest.spyOn(OpenAI.Chat.Completions.prototype, "create").mockImplementation(
+      // @ts-expect-error the response type is not correct - this is just for testing
+      async (): Promise<unknown> => {
+        return CHAT_RESPONSE;
+      },
+    );
+    jest.spyOn(OpenAI.Embeddings.prototype, "create").mockImplementation(
+      // @ts-expect-error the response type is not correct - this is just for testing
+      async (): Promise<unknown> => {
+        return EMBEDDING_RESPONSE;
+      },
+    );
   });
   afterEach(() => {
     jest.clearAllMocks();
-    openAISpy.mockRestore();
-    openAIQueryEmbedSpy.mockRestore();
-    openAITextEmbedSpy.mockRestore();
+    jest.restoreAllMocks();
+    jest.resetModules();
   });
+
   it("is patched", () => {
     expect(isPatched()).toBe(true);
   });
+
   it("should create a span for query engines", async () => {
     // Create Document object with essay
     const document = new Document({ text: "lorem ipsum" });
@@ -105,13 +232,8 @@ describe("LlamaIndexInstrumentation", () => {
       query: "What did the author do in college?",
     });
 
-    // Verify that the OpenAI chat method was called once during synthesis
-    expect(openAISpy).toHaveBeenCalledTimes(1);
-    expect(openAIQueryEmbedSpy).toHaveBeenCalledTimes(1);
-    expect(openAITextEmbedSpy).toHaveBeenCalledTimes(1);
-
     // Output response
-    expect(response.response).toEqual(DUMMY_RESPONSE);
+    expect(response.response).toEqual(RESPONSE_MESSAGE.content);
 
     const spans = memoryExporter.getFinishedSpans();
     expect(spans.length).toBeGreaterThan(0);
@@ -129,8 +251,9 @@ describe("LlamaIndexInstrumentation", () => {
     ).toEqual("What did the author do in college?");
     expect(
       queryEngineSpan?.attributes[SemanticConventions.OUTPUT_VALUE],
-    ).toEqual(DUMMY_RESPONSE);
+    ).toEqual(RESPONSE_MESSAGE.content);
   });
+
   it("should create a span for retrieve method", async () => {
     // Create Document objects with essays
     const documents = [
@@ -149,11 +272,6 @@ describe("LlamaIndexInstrumentation", () => {
       query: "What did the author do in college?",
    });
 
-    // OpenAI Chat method should not be called
-    expect(openAISpy).toHaveBeenCalledTimes(0);
-    expect(openAIQueryEmbedSpy).toHaveBeenCalledTimes(1);
-    expect(openAITextEmbedSpy).toHaveBeenCalledTimes(1);
-
     const spans = memoryExporter.getFinishedSpans();
     expect(spans.length).toBeGreaterThan(0);
diff --git a/js/packages/openinference-vercel/tsconfig.json 
b/js/packages/openinference-vercel/tsconfig.json index 71640020d..05992796c 100644 --- a/js/packages/openinference-vercel/tsconfig.json +++ b/js/packages/openinference-vercel/tsconfig.json @@ -7,6 +7,6 @@ "lib": ["ESNext"] }, "files": [], - "include": ["src/**/*.ts", "test/**/*.ts"], + "include": ["src/**/*.ts", "test/**/*.ts", "examples/**/*.ts"], "references": [] }
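Reviewer note (not part of the diff): below is a minimal usage sketch showing how the embedding spans added by this patch surface for a consumer. It assumes the published package name "@arizeai/openinference-instrumentation-llama-index" and a plain console exporter; the span name ("embedding") and its attributes mirror the test assertions above. Treat it as illustrative rather than as the package's documented API.

import { registerInstrumentations } from "@opentelemetry/instrumentation";
import { NodeTracerProvider } from "@opentelemetry/sdk-trace-node";
import {
  ConsoleSpanExporter,
  SimpleSpanProcessor,
} from "@opentelemetry/sdk-trace-base";
import { LlamaIndexInstrumentation } from "@arizeai/openinference-instrumentation-llama-index";

// Export spans to stdout so the embedding span is easy to inspect.
const provider = new NodeTracerProvider();
provider.addSpanProcessor(new SimpleSpanProcessor(new ConsoleSpanExporter()));
provider.register();

// Register the instrumentation before llamaindex is loaded so its
// query-engine, retriever, and embedding prototypes get patched.
registerInstrumentations({
  instrumentations: [new LlamaIndexInstrumentation()],
});

async function main() {
  // Requires OPENAI_API_KEY in the environment.
  const { OpenAIEmbedding } = await import("llamaindex");
  const embedder = new OpenAIEmbedding();
  await embedder.getQueryEmbedding("What did the author do in college?");
  // A span named "embedding" should be exported, carrying the OpenInference
  // EMBEDDING span kind plus the input text, output vector, and model name
  // attributes asserted in the tests above.
}

main();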