From 0cef233aa3ad60a17bbc28e4c80d1d1bb859f360 Mon Sep 17 00:00:00 2001
From: Roger Yang <80478925+RogerHYang@users.noreply.github.com>
Date: Thu, 12 Dec 2024 08:06:42 -0800
Subject: [PATCH] fix: handle multiple embedding events for llama-index (#1166)

---
 .../instrumentation/llama_index/_handler.py | 8 ++-
 .../instrumentation/llama_index/conftest.py | 18 +++++++
 .../test_multiple_embedding_events.py | 49 +++++++++++++++++++
 3 files changed, 74 insertions(+), 1 deletion(-)
 create mode 100644 python/instrumentation/openinference-instrumentation-llama-index/tests/openinference/instrumentation/llama_index/conftest.py
 create mode 100644 python/instrumentation/openinference-instrumentation-llama-index/tests/openinference/instrumentation/llama_index/test_multiple_embedding_events.py

diff --git a/python/instrumentation/openinference-instrumentation-llama-index/src/openinference/instrumentation/llama_index/_handler.py b/python/instrumentation/openinference-instrumentation-llama-index/src/openinference/instrumentation/llama_index/_handler.py
index aadbbe7fb..c2a031122 100644
--- a/python/instrumentation/openinference-instrumentation-llama-index/src/openinference/instrumentation/llama_index/_handler.py
+++ b/python/instrumentation/openinference-instrumentation-llama-index/src/openinference/instrumentation/llama_index/_handler.py
@@ -4,6 +4,7 @@
 import json
 import logging
 import weakref
+from collections import defaultdict
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import singledispatch, singledispatchmethod
@@ -15,6 +16,7 @@
     TYPE_CHECKING,
     Any,
     AsyncGenerator,
+    DefaultDict,
     Dict,
     Generator,
     Iterable,
@@ -183,6 +185,7 @@ def __init__(
         self._attributes = {}
         self._end_time = None
         self._last_updated_at = time()
+        self._list_attr_len: DefaultDict[str, int] = defaultdict(int)
 
     def __setitem__(self, key: str, value: AttributeValue) -> None:
         self._attributes[key] = value
@@ -362,9 +365,12 @@ def _(self, event: EmbeddingStartEvent) -> None:
 
     @_process_event.register
     def _(self, event: EmbeddingEndEvent) -> None:
-        for i, (text, vector) in enumerate(zip(event.chunks, event.embeddings)):
+        i = self._list_attr_len[EMBEDDING_EMBEDDINGS]
+        for text, vector in zip(event.chunks, event.embeddings):
             self[f"{EMBEDDING_EMBEDDINGS}.{i}.{EMBEDDING_TEXT}"] = text
             self[f"{EMBEDDING_EMBEDDINGS}.{i}.{EMBEDDING_VECTOR}"] = vector
+            i += 1
+        self._list_attr_len[EMBEDDING_EMBEDDINGS] = i
 
     @_process_event.register
     def _(self, event: StreamChatStartEvent) -> None:
diff --git a/python/instrumentation/openinference-instrumentation-llama-index/tests/openinference/instrumentation/llama_index/conftest.py b/python/instrumentation/openinference-instrumentation-llama-index/tests/openinference/instrumentation/llama_index/conftest.py
new file mode 100644
index 000000000..b5fc7b480
--- /dev/null
+++ b/python/instrumentation/openinference-instrumentation-llama-index/tests/openinference/instrumentation/llama_index/conftest.py
@@ -0,0 +1,18 @@
+import pytest
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+
+
+@pytest.fixture
+def in_memory_span_exporter() -> InMemorySpanExporter:
+    return InMemorySpanExporter()
+
+
+@pytest.fixture
+def tracer_provider(
+    in_memory_span_exporter: InMemorySpanExporter,
+) -> TracerProvider:
+    tracer_provider = TracerProvider()
+    tracer_provider.add_span_processor(SimpleSpanProcessor(in_memory_span_exporter))
+    return tracer_provider
diff --git a/python/instrumentation/openinference-instrumentation-llama-index/tests/openinference/instrumentation/llama_index/test_multiple_embedding_events.py b/python/instrumentation/openinference-instrumentation-llama-index/tests/openinference/instrumentation/llama_index/test_multiple_embedding_events.py
new file mode 100644
index 000000000..31f8260b3
--- /dev/null
+++ b/python/instrumentation/openinference-instrumentation-llama-index/tests/openinference/instrumentation/llama_index/test_multiple_embedding_events.py
@@ -0,0 +1,49 @@
+from itertools import product
+from typing import Iterator
+
+import pytest
+from llama_index.core.instrumentation import get_dispatcher
+from llama_index.core.instrumentation.events.embedding import EmbeddingEndEvent
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+
+from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
+from openinference.semconv.trace import EmbeddingAttributes, SpanAttributes
+
+dispatcher = get_dispatcher(__name__)
+
+
+@dispatcher.span  # type: ignore[misc,unused-ignore]
+def foo(m: int, n: int) -> None:
+    for i in range(m):
+        chunks = [f"{i}-{j}" for j in range(n)]
+        embeddings = [list(map(float, [i, j])) for j in range(n)]
+        dispatcher.event(EmbeddingEndEvent(chunks=chunks, embeddings=embeddings))
+
+
+async def test_multiple_embedding_events(
+    in_memory_span_exporter: InMemorySpanExporter,
+) -> None:
+    m, n = 3, 2
+    foo(m, n)
+    span = in_memory_span_exporter.get_finished_spans()[0]
+    assert span.attributes
+    for k, (i, j) in enumerate(product(range(m), range(n))):
+        text, vector = f"{i}-{j}", tuple(map(float, [i, j]))
+        assert span.attributes[f"{EMBEDDING_EMBEDDINGS}.{k}.{EMBEDDING_TEXT}"] == text
+        assert span.attributes[f"{EMBEDDING_EMBEDDINGS}.{k}.{EMBEDDING_VECTOR}"] == vector
+
+
+@pytest.fixture(autouse=True)
+def instrument(
+    tracer_provider: TracerProvider,
+    in_memory_span_exporter: InMemorySpanExporter,
+) -> Iterator[None]:
+    LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)
+    yield
+    LlamaIndexInstrumentor().uninstrument()
+
+
+EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS
+EMBEDDING_TEXT = EmbeddingAttributes.EMBEDDING_TEXT
+EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR