From 8632b297580801896f70feefdcfd37bb6e9eb044 Mon Sep 17 00:00:00 2001 From: Chris Park Date: Thu, 1 Aug 2024 14:41:00 -0700 Subject: [PATCH] feat(js): llama-index-ts initial retrieval span support (#643) --- .../src/instrumentation.ts | 101 ++-------- .../src/types.ts | 63 ++++++ .../src/utils.ts | 185 ++++++++++++++++++ .../test/llamaIndex.test.ts | 90 ++++++++- 4 files changed, 352 insertions(+), 87 deletions(-) create mode 100644 js/packages/openinference-instrumentation-llama-index/src/types.ts create mode 100644 js/packages/openinference-instrumentation-llama-index/src/utils.ts diff --git a/js/packages/openinference-instrumentation-llama-index/src/instrumentation.ts b/js/packages/openinference-instrumentation-llama-index/src/instrumentation.ts index ccee4ce07..af73f1541 100644 --- a/js/packages/openinference-instrumentation-llama-index/src/instrumentation.ts +++ b/js/packages/openinference-instrumentation-llama-index/src/instrumentation.ts @@ -1,4 +1,3 @@ -/* eslint-disable @typescript-eslint/no-explicit-any */ import type * as llamaindex from "llamaindex"; import { @@ -6,18 +5,10 @@ import { InstrumentationConfig, InstrumentationModuleDefinition, InstrumentationNodeModuleDefinition, - safeExecuteInTheMiddle, } from "@opentelemetry/instrumentation"; +import { diag } from "@opentelemetry/api"; +import { patchQueryMethod, patchRetrieveMethod } from "./utils"; import { VERSION } from "./version"; -import { - Span, - SpanKind, - SpanStatusCode, - context, - diag, - trace, -} from "@opentelemetry/api"; -import { isTracingSuppressed } from "@opentelemetry/core"; const MODULE_NAME = "llamaindex"; @@ -34,28 +25,6 @@ export function isPatched() { return _isOpenInferencePatched; } -import { - OpenInferenceSpanKind, - SemanticConventions, -} from "@arizeai/openinference-semantic-conventions"; - -/** - * Resolves the execution context for the current span - * If tracing is suppressed, the span is dropped and the current context is returned - * @param span - */ -function 
getExecContext(span: Span) { - const activeContext = context.active(); - const suppressTracing = isTracingSuppressed(activeContext); - const execContext = suppressTracing - ? trace.setSpan(context.active(), span) - : activeContext; - // Drop the span from the context - if (suppressTracing) { - trace.deleteSpan(activeContext); - } - return execContext; -} export class LlamaIndexInstrumentation extends InstrumentationBase< typeof llamaindex > { @@ -66,7 +35,8 @@ export class LlamaIndexInstrumentation extends InstrumentationBase< Object.assign({}, config), ); } - manuallyInstrument(module: typeof llamaindex) { + + public manuallyInstrument(module: typeof llamaindex) { diag.debug(`Manually instrumenting ${MODULE_NAME}`); this.patch(module); } @@ -87,71 +57,32 @@ export class LlamaIndexInstrumentation extends InstrumentationBase< return moduleExports; } - // eslint-disable-next-line @typescript-eslint/no-this-alias - const instrumentation: LlamaIndexInstrumentation = this; - - type RetrieverQueryEngineQueryType = - typeof moduleExports.RetrieverQueryEngine.prototype.query; - + // TODO: Support streaming this._wrap( moduleExports.RetrieverQueryEngine.prototype, "query", - (original: RetrieverQueryEngineQueryType): any => { - return function patchedQuery( - this: unknown, - ...args: Parameters<RetrieverQueryEngineQueryType> - ) { - const span = instrumentation.tracer.startSpan(`Query`, { - kind: SpanKind.INTERNAL, - attributes: { - [SemanticConventions.OPENINFERENCE_SPAN_KIND]: - OpenInferenceSpanKind.CHAIN, - [SemanticConventions.INPUT_VALUE]: args[0].query, - }, - }); - - const execContext = getExecContext(span); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (original): any => { + return patchQueryMethod(original, moduleExports, this.tracer); + }, + ); - const execPromise = safeExecuteInTheMiddle< - ReturnType<RetrieverQueryEngineQueryType> - >( - () => { - return context.with(execContext, () => { - return original.apply(this, args); - }); - }, - (error) => { - // Push the error to the span - if (error) { - 
span.recordException(error); - span.setStatus({ - code: SpanStatusCode.ERROR, - message: error.message, - }); - span.end(); - } - }, - ); - const wrappedPromise = execPromise.then((result) => { - span.setAttributes({ - [SemanticConventions.OUTPUT_VALUE]: result.response, - }); - span.end(); - return result; - }); - return context.bind(execContext, wrappedPromise); - }; + this._wrap( + moduleExports.VectorIndexRetriever.prototype, + "retrieve", + (original) => { + return patchRetrieveMethod(original, moduleExports, this.tracer); }, ); _isOpenInferencePatched = true; - return moduleExports; } private unpatch(moduleExports: typeof llamaindex, moduleVersion?: string) { this._diag.debug(`Un-patching ${MODULE_NAME}@${moduleVersion}`); this._unwrap(moduleExports.RetrieverQueryEngine.prototype, "query"); + this._unwrap(moduleExports.VectorIndexRetriever.prototype, "retrieve"); _isOpenInferencePatched = false; } diff --git a/js/packages/openinference-instrumentation-llama-index/src/types.ts b/js/packages/openinference-instrumentation-llama-index/src/types.ts new file mode 100644 index 000000000..faafbff0a --- /dev/null +++ b/js/packages/openinference-instrumentation-llama-index/src/types.ts @@ -0,0 +1,63 @@ +import { SemanticConventions } from "@arizeai/openinference-semantic-conventions"; + +export type RetrievalDocument = { + [SemanticConventions.DOCUMENT_ID]?: string; + [SemanticConventions.DOCUMENT_CONTENT]?: string; + [SemanticConventions.DOCUMENT_SCORE]?: number | undefined; + [SemanticConventions.DOCUMENT_METADATA]?: string; +}; + +type LLMMessageToolCall = { + [SemanticConventions.TOOL_CALL_FUNCTION_NAME]?: string; + [SemanticConventions.TOOL_CALL_FUNCTION_ARGUMENTS_JSON]?: string; +}; + +export type LLMMessageToolCalls = { + [SemanticConventions.MESSAGE_TOOL_CALLS]?: LLMMessageToolCall[]; +}; + +export type LLMMessageFunctionCall = { + [SemanticConventions.MESSAGE_FUNCTION_CALL_NAME]?: string; + [SemanticConventions.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON]?: string; 
+}; + +export type LLMMessage = LLMMessageToolCalls & + LLMMessageFunctionCall & { + [SemanticConventions.MESSAGE_ROLE]?: string; + [SemanticConventions.MESSAGE_CONTENT]?: string; + }; + +export type LLMMessagesAttributes = + | { + [SemanticConventions.LLM_INPUT_MESSAGES]: LLMMessage[]; + } + | { + [SemanticConventions.LLM_OUTPUT_MESSAGES]: LLMMessage[]; + }; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type GenericFunction = (...args: any[]) => any; + +export type SafeFunction<T extends GenericFunction> = ( + ...args: Parameters<T> +) => ReturnType<T> | null; + +export type LLMParameterAttributes = { + [SemanticConventions.LLM_MODEL_NAME]?: string; + [SemanticConventions.LLM_INVOCATION_PARAMETERS]?: string; +}; + +export type PromptTemplateAttributes = { + [SemanticConventions.PROMPT_TEMPLATE_TEMPLATE]?: string; + [SemanticConventions.PROMPT_TEMPLATE_VARIABLES]?: string; +}; +export type TokenCountAttributes = { + [SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]?: number; + [SemanticConventions.LLM_TOKEN_COUNT_PROMPT]?: number; + [SemanticConventions.LLM_TOKEN_COUNT_TOTAL]?: number; +}; + +export type ToolAttributes = { + [SemanticConventions.TOOL_NAME]?: string; + [SemanticConventions.TOOL_DESCRIPTION]?: string; +}; diff --git a/js/packages/openinference-instrumentation-llama-index/src/utils.ts b/js/packages/openinference-instrumentation-llama-index/src/utils.ts new file mode 100644 index 000000000..b35a79e1a --- /dev/null +++ b/js/packages/openinference-instrumentation-llama-index/src/utils.ts @@ -0,0 +1,185 @@ +import type * as llamaindex from "llamaindex"; + +import { TextNode } from "llamaindex"; +import { safeExecuteInTheMiddle } from "@opentelemetry/instrumentation"; +import { + Attributes, + Span, + SpanKind, + SpanStatusCode, + context, + trace, + Tracer, + diag, +} from "@opentelemetry/api"; +import { isTracingSuppressed } from "@opentelemetry/core"; +import { + MimeType, + OpenInferenceSpanKind, + SemanticConventions, +} from 
"@arizeai/openinference-semantic-conventions"; +import { GenericFunction, SafeFunction } from "./types"; + +/** + * Wraps a function with a try-catch block to catch and log any errors. + * @param fn - A function to wrap with a try-catch block. + * @returns A function that returns null if an error is thrown. + */ +export function withSafety<T extends GenericFunction>(fn: T): SafeFunction<T> { + return (...args) => { + try { + return fn(...args); + } catch (error) { + diag.error(`Failed to get attributes for span: ${error}`); + return null; + } + }; +} + +const safelyJSONStringify = withSafety(JSON.stringify); + +/** + * Resolves the execution context for the current span + * If tracing is suppressed, the span is dropped and the current context is returned + * @param span + */ +function getExecContext(span: Span) { + const activeContext = context.active(); + const suppressTracing = isTracingSuppressed(activeContext); + const execContext = suppressTracing + ? trace.setSpan(context.active(), span) + : activeContext; + // Drop the span from the context + if (suppressTracing) { + trace.deleteSpan(activeContext); + } + return execContext; +} + +export function patchQueryMethod( + original: typeof module.RetrieverQueryEngine.prototype.query, + module: typeof llamaindex, + tracer: Tracer, +) { + return function patchedQuery( + this: unknown, + ...args: Parameters<typeof original> + ) { + const span = tracer.startSpan(`query`, { + kind: SpanKind.INTERNAL, + attributes: { + [SemanticConventions.OPENINFERENCE_SPAN_KIND]: + OpenInferenceSpanKind.CHAIN, + [SemanticConventions.INPUT_VALUE]: args[0].query, + [SemanticConventions.INPUT_MIME_TYPE]: MimeType.TEXT, + }, + }); + + const execContext = getExecContext(span); + + const execPromise = safeExecuteInTheMiddle< + ReturnType<typeof original> + >( + () => { + return context.with(execContext, () => { + return original.apply(this, args); + }); + }, + (error) => { + // Push the error to the span + if (error) { + span.recordException(error); + span.setStatus({ + code: SpanStatusCode.ERROR, + 
message: error.message, + }); + span.end(); + } + }, + ); + + const wrappedPromise = execPromise.then((result) => { + span.setAttributes({ + [SemanticConventions.OUTPUT_VALUE]: result.response, + [SemanticConventions.OUTPUT_MIME_TYPE]: MimeType.TEXT, + }); + span.end(); + return result; + }); + return context.bind(execContext, wrappedPromise); + }; +} + +export function patchRetrieveMethod( + original: typeof module.VectorIndexRetriever.prototype.retrieve, + module: typeof llamaindex, + tracer: Tracer, +) { + return function patchedRetrieve( + this: unknown, + ...args: Parameters<typeof original> + ) { + const span = tracer.startSpan(`retrieve`, { + kind: SpanKind.INTERNAL, + attributes: { + [SemanticConventions.OPENINFERENCE_SPAN_KIND]: + OpenInferenceSpanKind.RETRIEVER, + [SemanticConventions.INPUT_VALUE]: args[0].query, + [SemanticConventions.INPUT_MIME_TYPE]: MimeType.TEXT, + }, + }); + + const execContext = getExecContext(span); + + const execPromise = safeExecuteInTheMiddle< + ReturnType<typeof original> + >( + () => { + return context.with(execContext, () => { + return original.apply(this, args); + }); + }, + (error) => { + // Push the error to the span + if (error) { + span.recordException(error); + span.setStatus({ + code: SpanStatusCode.ERROR, + message: error.message, + }); + span.end(); + } + }, + ); + + const wrappedPromise = execPromise.then((result) => { + span.setAttributes(documentAttributes(result)); + span.end(); + return result; + }); + return context.bind(execContext, wrappedPromise); + }; +} + +function documentAttributes( + output: llamaindex.NodeWithScore<llamaindex.Metadata>[], +) { + const docs: Attributes = {}; + output.forEach((document, index) => { + const { node, score } = document; + + if (node instanceof TextNode) { + const nodeId = node.id_; + const nodeText = node.getContent(); + const nodeMetadata = node.metadata; + + const prefix = `${SemanticConventions.RETRIEVAL_DOCUMENTS}.${index}`; + docs[`${prefix}.${SemanticConventions.DOCUMENT_ID}`] = nodeId; + 
docs[`${prefix}.${SemanticConventions.DOCUMENT_SCORE}`] = score; + docs[`${prefix}.${SemanticConventions.DOCUMENT_CONTENT}`] = nodeText; + docs[`${prefix}.${SemanticConventions.DOCUMENT_METADATA}`] = + safelyJSONStringify(nodeMetadata) ?? undefined; + } + }); + return docs; +} diff --git a/js/packages/openinference-instrumentation-llama-index/test/llamaIndex.test.ts b/js/packages/openinference-instrumentation-llama-index/test/llamaIndex.test.ts index 4719e10c9..e4c3ff165 100644 --- a/js/packages/openinference-instrumentation-llama-index/test/llamaIndex.test.ts +++ b/js/packages/openinference-instrumentation-llama-index/test/llamaIndex.test.ts @@ -5,6 +5,11 @@ import { import { LlamaIndexInstrumentation, isPatched } from "../src/index"; import { NodeTracerProvider } from "@opentelemetry/sdk-trace-node"; import * as llamaindex from "llamaindex"; +import { + SemanticConventions, + OpenInferenceSpanKind, + RETRIEVAL_DOCUMENTS, +} from "@arizeai/openinference-semantic-conventions"; const { Document, VectorStoreIndex } = llamaindex; @@ -69,7 +74,7 @@ describe("LlamaIndexInstrumentation", () => { openAITextEmbedSpy = jest .spyOn(llamaindex.OpenAIEmbedding.prototype, "getTextEmbeddings") .mockImplementation(() => { - return Promise.resolve([fakeEmbedding]); + return Promise.resolve([fakeEmbedding, fakeEmbedding, fakeEmbedding]); }); openAIQueryEmbedSpy = jest @@ -112,7 +117,88 @@ describe("LlamaIndexInstrumentation", () => { expect(spans.length).toBeGreaterThan(0); // Expect a span for the query engine - const queryEngineSpan = spans.find((span) => span.name.includes("Query")); + const queryEngineSpan = spans.find((span) => span.name.includes("query")); expect(queryEngineSpan).toBeDefined(); + + // Verify query span attributes + expect( + queryEngineSpan?.attributes[SemanticConventions.OPENINFERENCE_SPAN_KIND], + ).toEqual(OpenInferenceSpanKind.CHAIN); + expect( + queryEngineSpan?.attributes[SemanticConventions.INPUT_VALUE], + ).toEqual("What did the author do in 
college?"); + expect( + queryEngineSpan?.attributes[SemanticConventions.OUTPUT_VALUE], + ).toEqual(DUMMY_RESPONSE); + }); + it("should create a span for retrieve method", async () => { + // Create Document objects with essays + const documents = [ + new Document({ text: "lorem ipsum 1" }), + new Document({ text: "lorem ipsum 2" }), + new Document({ text: "lorem ipsum 3" }), + ]; + + // Split text and create embeddings. Store them in a VectorStoreIndex + const index = await VectorStoreIndex.fromDocuments(documents); + + // Retrieve documents from the index + const retriever = index.asRetriever(); + + const response = await retriever.retrieve({ + query: "What did the author do in college?", + }); + + // OpenAI Chat method should not be called + expect(openAISpy).toHaveBeenCalledTimes(0); + expect(openAIQueryEmbedSpy).toHaveBeenCalledTimes(1); + expect(openAITextEmbedSpy).toHaveBeenCalledTimes(1); + + const spans = memoryExporter.getFinishedSpans(); + expect(spans.length).toBeGreaterThan(0); + + // Expect a span for the retrieve method + const retrievalSpan = spans.find((span) => span.name.includes("retrieve")); + expect(retrievalSpan).toBeDefined(); + + // Verify query span attributes + expect( + retrievalSpan?.attributes[SemanticConventions.OPENINFERENCE_SPAN_KIND], + ).toEqual(OpenInferenceSpanKind.RETRIEVER); + expect(retrievalSpan?.attributes[SemanticConventions.INPUT_VALUE]).toEqual( + "What did the author do in college?", + ); + + // Check document attributes + response.forEach((document, index) => { + const { node, score } = document; + + if (node instanceof llamaindex.TextNode) { + const nodeId = node.id_; + const nodeText = node.getContent(); + const nodeMetadata = node.metadata; + + expect( + retrievalSpan?.attributes[ + `${RETRIEVAL_DOCUMENTS}.${index}.document.id` + ], + ).toEqual(nodeId); + expect( + retrievalSpan?.attributes[ + `${RETRIEVAL_DOCUMENTS}.${index}.document.score` + ], + ).toEqual(score); + expect( + retrievalSpan?.attributes[ + 
`${RETRIEVAL_DOCUMENTS}.${index}.document.content` + ], + ).toEqual(nodeText); + expect( + retrievalSpan?.attributes[ + `${RETRIEVAL_DOCUMENTS}.${index}.document.metadata` + ], + ).toEqual(JSON.stringify(nodeMetadata)); + } + }); }); });