diff --git a/python/.gitignore b/python/.gitignore
new file mode 100644
index 000000000..9e80b2408
--- /dev/null
+++ b/python/.gitignore
@@ -0,0 +1,2 @@
+# vendored virtual environments
+.venv
diff --git a/python/dev-requirements.txt b/python/dev-requirements.txt
index 4abb5b680..9ea8aaff2 100644
--- a/python/dev-requirements.txt
+++ b/python/dev-requirements.txt
@@ -1,3 +1,3 @@
+mypy == 1.8.0
 pytest == 7.4.4
 ruff == 0.1.11
-mypy == 1.8.0
diff --git a/python/instrumentation/openinference-instrumentation-dspy/examples/rag_module.py b/python/instrumentation/openinference-instrumentation-dspy/examples/rag_module.py
new file mode 100644
index 000000000..55a2af744
--- /dev/null
+++ b/python/instrumentation/openinference-instrumentation-dspy/examples/rag_module.py
@@ -0,0 +1,59 @@
+import logging
+import os
+import sys
+
+import dspy
+from openinference.instrumentation.dspy import DSPyInstrumentor
+from opentelemetry import trace as trace_api
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk import trace as trace_sdk
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
+
+logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
+
+resource = Resource(attributes={})
+tracer_provider = trace_sdk.TracerProvider(resource=resource)
+span_console_exporter = ConsoleSpanExporter()
+tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter=span_console_exporter))
+
+# Logs to the Phoenix Collector if running locally
+if phoenix_collector_endpoint := os.environ.get("PHOENIX_COLLECTOR_ENDPOINT"):
+    endpoint = phoenix_collector_endpoint + "/v1/traces"
+    span_otlp_exporter = OTLPSpanExporter(endpoint=endpoint)
+    tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter=span_otlp_exporter))
+
+
+trace_api.set_tracer_provider(tracer_provider=tracer_provider)
+DSPyInstrumentor().instrument()
+
+
+class BasicQA(dspy.Signature):
+    """Answer questions with short factoid answers."""
+
+    question = dspy.InputField()
+    answer = dspy.OutputField(desc="often between 1 and 5 words")
+
+
+class RAG(dspy.Module):
+    def __init__(self, num_passages=3):
+        super().__init__()
+        self.retrieve = dspy.Retrieve(k=num_passages)
+        self.generate_answer = dspy.ChainOfThought(BasicQA)
+
+    def forward(self, question):
+        context = self.retrieve(question).passages
+        prediction = self.generate_answer(context=context, question=question)
+        return dspy.Prediction(context=context, answer=prediction.answer)
+
+
+if __name__ == "__main__":
+    turbo = dspy.OpenAI(model="gpt-3.5-turbo")
+    colbertv2_wiki17_abstracts = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")
+    dspy.settings.configure(
+        lm=turbo,
+        rm=colbertv2_wiki17_abstracts,
+    )
+    rag = RAG()
+    output = rag("What's the capital of the United States?")
+    print(output)
diff --git a/python/instrumentation/openinference-instrumentation-dspy/pyproject.toml b/python/instrumentation/openinference-instrumentation-dspy/pyproject.toml
index 5d9a687ba..5ef99615e 100644
--- a/python/instrumentation/openinference-instrumentation-dspy/pyproject.toml
+++ b/python/instrumentation/openinference-instrumentation-dspy/pyproject.toml
@@ -37,7 +37,7 @@ instruments = [
 test = [
   "dspy-ai==2.1.0",
   "opentelemetry-sdk",
-  "requests-mock",
+  "responses",
 ]

 [project.urls]
diff --git a/python/instrumentation/openinference-instrumentation-dspy/src/openinference/instrumentation/dspy/__init__.py b/python/instrumentation/openinference-instrumentation-dspy/src/openinference/instrumentation/dspy/__init__.py
index 765651368..e64d734df 100644
--- a/python/instrumentation/openinference-instrumentation-dspy/src/openinference/instrumentation/dspy/__init__.py
+++ b/python/instrumentation/openinference-instrumentation-dspy/src/openinference/instrumentation/dspy/__init__.py
@@ -1,16 +1,20 @@
 import json
 from abc import ABC
-from typing import Any, Callable, Collection, Dict, Mapping, Tuple
+from enum import Enum
+from inspect import signature
+from typing import Any, Callable, Collection, Dict, Iterator, List, Mapping, Tuple

 from openinference.instrumentation.dspy.package import _instruments
 from openinference.instrumentation.dspy.version import __version__
 from openinference.semconv.trace import (
+    DocumentAttributes,
     OpenInferenceMimeTypeValues,
     OpenInferenceSpanKindValues,
     SpanAttributes,
 )
 from opentelemetry import trace as trace_api
 from opentelemetry.instrumentation.instrumentor import BaseInstrumentor  # type: ignore
+from opentelemetry.util.types import AttributeValue
 from wrapt import wrap_function_wrapper

 _DSPY_MODULE = "dspy"
@@ -35,6 +39,8 @@ def _instrument(self, **kwargs: Any) -> None:
         # Instrument LM (language model) calls
         from dsp.modules.lm import LM

+        from dspy import Predict
+
         language_model_classes = LM.__subclasses__()
         for lm in language_model_classes:
             wrap_function_wrapper(
@@ -43,13 +49,48 @@
                 wrapper=_LMBasicRequestWrapper(tracer),
             )

-        # Instrument DSPy constructs
+        # Predict is a concrete (non-abstract) class that may be invoked
+        # directly, but DSPy also has subclasses of Predict that override the
+        # forward method. We instrument both the forward methods of the base
+        # class and all subclasses.
         wrap_function_wrapper(
             module=_DSPY_MODULE,
             name="Predict.forward",
             wrapper=_PredictForwardWrapper(tracer),
         )
+        predict_subclasses = Predict.__subclasses__()
+        for predict_subclass in predict_subclasses:
+            wrap_function_wrapper(
+                module=_DSPY_MODULE,
+                name=predict_subclass.__name__ + ".forward",
+                wrapper=_PredictForwardWrapper(tracer),
+            )
+
+        wrap_function_wrapper(
+            module=_DSPY_MODULE,
+            name="Retrieve.forward",
+            wrapper=_RetrieverForwardWrapper(tracer),
+        )
+
+        wrap_function_wrapper(
+            module=_DSPY_MODULE,
+            # At this time, dspy.Module does not have an abstract forward
+            # method, but assumes that user-defined subclasses implement the
+            # forward method and invokes that method using __call__.
+            name="Module.__call__",
+            wrapper=_ModuleForwardWrapper(tracer),
+        )
+
+        # At this time, there is no common parent class for retriever models as
+        # there is for language models. We instrument the retriever models on a
+        # case-by-case basis.
+        wrap_function_wrapper(
+            module=_DSPY_MODULE,
+            name="ColBERTv2.__call__",
+            wrapper=_RetrieverModelCallWrapper(tracer),
+        )
+
     def _uninstrument(self, **kwargs: Any) -> None:
         from dsp.modules.lm import LM
@@ -93,13 +134,17 @@ def __call__(
         span_name = instance.__class__.__name__ + ".request"
         with self._tracer.start_as_current_span(
             span_name,
-            attributes={
-                SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
-                SpanAttributes.LLM_MODEL_NAME: instance.kwargs.get("model"),
-                SpanAttributes.LLM_INVOCATION_PARAMETERS: json.dumps(invocation_parameters),
-                SpanAttributes.INPUT_VALUE: str(prompt),
-                SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
-            },
+            attributes=dict(
+                _flatten(
+                    {
+                        OPENINFERENCE_SPAN_KIND: LLM.value,
+                        LLM_MODEL_NAME: instance.kwargs.get("model"),
+                        LLM_INVOCATION_PARAMETERS: json.dumps(invocation_parameters),
+                        INPUT_VALUE: str(prompt),
+                        INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
+                    }
+                )
+            ),
         ) as span:
             try:
                 response = wrapped(*args, **kwargs)
@@ -110,10 +155,14 @@
             # TODO: parse usage. Need to decide if this
             # instrumentation should be used in conjunction with model instrumentation
             span.set_attributes(
-                {
-                    SpanAttributes.OUTPUT_VALUE: json.dumps(response),
-                    SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
-                }
+                dict(
+                    _flatten(
+                        {
+                            OUTPUT_VALUE: json.dumps(response),
+                            OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                        }
+                    )
+                )
             )
             span.set_status(trace_api.StatusCode.OK)
             return response
@@ -131,15 +180,48 @@ def __call__(
         args: Tuple[type, Any],
         kwargs: Mapping[str, Any],
     ) -> Any:
+        from dspy import Predict
+
+        # At this time, subclasses of Predict override the base class' forward
+        # method and invoke the parent class' forward method from within the
+        # overridden method. The forward methods of both Predict and its
+        # subclasses have been instrumented. To avoid creating duplicate spans
+        # for a single invocation, we don't create a span for the base class'
+        # forward method if the instance belongs to a proper subclass of Predict
+        # with an overridden forward method.
+        is_instance_of_predict_subclass = (
+            isinstance(instance, Predict) and (cls := instance.__class__) is not Predict
+        )
+        has_overridden_forward_method = getattr(cls, "forward", None) is not getattr(
+            Predict, "forward", None
+        )
+        wrapped_method_is_base_class_forward_method = (
+            wrapped.__qualname__ == Predict.forward.__qualname__
+        )
+        if (
+            is_instance_of_predict_subclass
+            and has_overridden_forward_method
+            and wrapped_method_is_base_class_forward_method
+        ):
+            return wrapped(*args, **kwargs)
+
         signature = kwargs.get("signature", instance.signature)
-        span_name = signature.__name__ + ".forward"
+        span_name = _get_predict_span_name(instance)
         with self._tracer.start_as_current_span(
             span_name,
-            attributes={
-                SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
-                SpanAttributes.INPUT_VALUE: json.dumps(kwargs),
-                SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
-            },
+            attributes=dict(
+                _flatten(
+                    {
+                        OPENINFERENCE_SPAN_KIND: CHAIN.value,
+                        INPUT_VALUE: _get_input_value(
+                            wrapped,
+                            *args,
+                            **kwargs,
+                        ),
+                        INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                    }
+                )
+            ),
         ) as span:
             try:
                 prediction = wrapped(*args, **kwargs)
@@ -148,12 +230,16 @@
                 span.record_exception(exception)
                 raise
             span.set_attributes(
-                {
-                    SpanAttributes.OUTPUT_VALUE: json.dumps(
-                        self._prediction_to_output_dict(prediction, signature)
-                    ),
-                    SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
-                }
+                dict(
+                    _flatten(
+                        {
+                            OUTPUT_VALUE: json.dumps(
+                                self._prediction_to_output_dict(prediction, signature)
+                            ),
+                            OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                        }
+                    )
+                )
             )
             span.set_status(trace_api.StatusCode.OK)
             return prediction
@@ -167,3 +253,269 @@ def _prediction_to_output_dict(self, prediction: Any, signature: Any) -> Dict[st
         if field.output_variable and field.output_variable in prediction:
             output[field.output_variable] = prediction.get(field.output_variable)
         return output
+
+
+class _ModuleForwardWrapper(_WithTracer):
+    """
+    Instruments the __call__ method of dspy.Module. DSPy end users define custom
+    subclasses of Module implementing a forward method, loosely resembling the
+    ergonomics of torch.nn.Module. The __call__ method of dspy.Module invokes
+    the forward method of the user-defined subclass.
+    """
+
+    def __call__(
+        self,
+        wrapped: Callable[..., Any],
+        instance: Any,
+        args: Tuple[type, Any],
+        kwargs: Mapping[str, Any],
+    ) -> Any:
+        span_name = instance.__class__.__name__ + ".forward"
+        with self._tracer.start_as_current_span(
+            span_name,
+            attributes=dict(
+                _flatten(
+                    {
+                        OPENINFERENCE_SPAN_KIND: CHAIN.value,
+                        # At this time, dspy.Module does not have an abstract forward
+                        # method, but assumes that user-defined subclasses implement the
+                        # forward method.
+                        **(
+                            {INPUT_VALUE: _get_input_value(forward_method, *args, **kwargs)}
+                            if (forward_method := getattr(instance.__class__, "forward", None))
+                            else {}
+                        ),
+                        INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                    }
+                )
+            ),
+        ) as span:
+            try:
+                prediction = wrapped(*args, **kwargs)
+            except Exception as exception:
+                span.set_status(trace_api.Status(trace_api.StatusCode.ERROR, str(exception)))
+                span.record_exception(exception)
+                raise
+            span.set_attributes(
+                dict(
+                    _flatten(
+                        {
+                            OUTPUT_VALUE: json.dumps(prediction, cls=DSPyJSONEncoder),
+                            OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                        }
+                    )
+                )
+            )
+            span.set_status(trace_api.StatusCode.OK)
+            return prediction
+
+
+class _RetrieverForwardWrapper(_WithTracer):
+    """
+    Instruments the forward method of dspy.Retrieve, which is a wrapper around
+    retriever models such as ColBERTv2. At this time, Retrieve does not add any
+    information beyond what can be gleaned from the underlying retriever model
+    sub-span. It is, however, a user-facing concept, so we have decided to
+    instrument it.
+    """
+
+    def __call__(
+        self,
+        wrapped: Callable[..., Any],
+        instance: Any,
+        args: Tuple[type, Any],
+        kwargs: Mapping[str, Any],
+    ) -> Any:
+        span_name = instance.__class__.__name__ + ".forward"
+        with self._tracer.start_as_current_span(
+            span_name,
+            attributes=dict(
+                _flatten(
+                    {
+                        OPENINFERENCE_SPAN_KIND: RETRIEVER.value,
+                        INPUT_VALUE: _get_input_value(wrapped, *args, **kwargs),
+                        INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                    }
+                )
+            ),
+        ) as span:
+            try:
+                prediction = wrapped(*args, **kwargs)
+            except Exception as exception:
+                span.set_status(trace_api.Status(trace_api.StatusCode.ERROR, str(exception)))
+                span.record_exception(exception)
+                raise
+            span.set_attributes(
+                dict(
+                    _flatten(
+                        {
+                            RETRIEVAL_DOCUMENTS: [
+                                {
+                                    DocumentAttributes.DOCUMENT_CONTENT: document_text,
+                                }
+                                for document_text in prediction.get("passages", [])
+                            ],
+                        }
+                    )
+                )
+            )
+            span.set_status(trace_api.StatusCode.OK)
+            return prediction
+
+
+class _RetrieverModelCallWrapper(_WithTracer):
+    """
+    Instruments the __call__ method of retriever models such as ColBERTv2.
+    """
+
+    def __call__(
+        self,
+        wrapped: Callable[..., Any],
+        instance: Any,
+        args: Tuple[type, Any],
+        kwargs: Mapping[str, Any],
+    ) -> Any:
+        class_name = instance.__class__.__name__
+        span_name = class_name + ".__call__"
+        with self._tracer.start_as_current_span(
+            span_name,
+            attributes=dict(
+                _flatten(
+                    {
+                        OPENINFERENCE_SPAN_KIND: RETRIEVER.value,
+                        INPUT_VALUE: (_get_input_value(wrapped, *args, **kwargs)),
+                        INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                    }
+                )
+            ),
+        ) as span:
+            try:
+                retrieved_documents = wrapped(*args, **kwargs)
+            except Exception as exception:
+                span.set_status(trace_api.Status(trace_api.StatusCode.ERROR, str(exception)))
+                span.record_exception(exception)
+                raise
+            span.set_attributes(
+                dict(
+                    _flatten(
+                        {
+                            RETRIEVAL_DOCUMENTS: [
+                                {
+                                    DocumentAttributes.DOCUMENT_ID: doc["pid"],
+                                    DocumentAttributes.DOCUMENT_CONTENT: doc["text"],
+                                    DocumentAttributes.DOCUMENT_SCORE: doc["score"],
+                                }
+                                for doc in retrieved_documents
+                            ],
+                        }
+                    )
+                )
+            )
+            span.set_status(trace_api.StatusCode.OK)
+            return retrieved_documents
+
+
+class DSPyJSONEncoder(json.JSONEncoder):
+    """
+    Provides support for non-JSON-serializable objects in DSPy.
+ """ + + def default(self, o: Any) -> Any: + try: + return super().default(o) + except TypeError: + from dsp.templates.template_v3 import Template + + from dspy.primitives.example import Example + + if hasattr(o, "_asdict"): + # convert namedtuples to dictionaries + return o._asdict() + if isinstance(o, Example): + # handles Prediction objects and other sub-classes of Example + return getattr(o, "_store", {}) + if isinstance(o, Template): + return { + "fields": [self.default(field) for field in o.fields], + "instructions": o.instructions, + } + return repr(o) + + +def _get_input_value(method: Callable[..., Any], *args: Any, **kwargs: Any) -> str: + """ + Parses a method call's inputs into a JSON string. Ensures a consistent + output regardless of whether the those inputs are passed as positional or + keyword arguments. + """ + + # For typical class methods, the corresponding instance of inspect.Signature + # does not include the self parameter. However, the inspect.Signature + # instance for __call__ does include the self parameter. + method_signature = signature(method) + first_parameter_name = next(iter(method_signature.parameters), None) + signature_contains_self_parameter = first_parameter_name in ["self"] + bound_arguments = method_signature.bind( + *( + [None] # the value bound to the method's self argument is discarded below, so pass None + if signature_contains_self_parameter + else [] # no self parameter, so no need to pass a value + ), + *args, + **kwargs, + ) + return json.dumps( + { + **{ + argument_name: argument_value + for argument_name, argument_value in bound_arguments.arguments.items() + if argument_name not in ["self", "kwargs"] + }, + **bound_arguments.arguments.get("kwargs", {}), + }, + cls=DSPyJSONEncoder, + ) + + +def _get_predict_span_name(instance: Any) -> str: + """ + Gets the name for the Predict span, which are the composition of a Predict + class or subclass and a user-defined signature. An example name would be + "Predict(UserDefinedSignature).forward". 
+ """ + class_name = str(instance.__class__.__name__) + if (signature := getattr(instance, "signature", None)) and ( + signature_name := getattr(signature, "__name__", None) + ): + return f"{class_name}({signature_name}).forward" + return f"{class_name}.forward" + + +def _flatten(mapping: Mapping[str, Any]) -> Iterator[Tuple[str, AttributeValue]]: + for key, value in mapping.items(): + if value is None: + continue + if isinstance(value, Mapping): + for sub_key, sub_value in _flatten(value): + yield f"{key}.{sub_key}", sub_value + elif isinstance(value, List) and any(isinstance(item, Mapping) for item in value): + for index, sub_mapping in enumerate(value): + for sub_key, sub_value in _flatten(sub_mapping): + yield f"{key}.{index}.{sub_key}", sub_value + else: + if isinstance(value, Enum): + value = value.value + yield key, value + + +OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND +RETRIEVER = OpenInferenceSpanKindValues.RETRIEVER +CHAIN = OpenInferenceSpanKindValues.CHAIN +LLM = OpenInferenceSpanKindValues.LLM +INPUT_VALUE = SpanAttributes.INPUT_VALUE +INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE +OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE +OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE +LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS +LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME +RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS diff --git a/python/instrumentation/openinference-instrumentation-dspy/tests/openinference/instrumentation/dspy/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-dspy/tests/openinference/instrumentation/dspy/test_instrumentor.py index 49d9328bf..79be117d4 100644 --- a/python/instrumentation/openinference-instrumentation-dspy/tests/openinference/instrumentation/dspy/test_instrumentor.py +++ b/python/instrumentation/openinference-instrumentation-dspy/tests/openinference/instrumentation/dspy/test_instrumentor.py @@ -1,10 +1,15 @@ +import json from typing import Generator import dspy import pytest -import requests_mock +import responses +from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory from openinference.instrumentation.dspy import DSPyInstrumentor from openinference.semconv.trace import ( + DocumentAttributes, + OpenInferenceMimeTypeValues, + OpenInferenceSpanKindValues, SpanAttributes, ) from opentelemetry import trace as trace_api @@ -14,12 +19,12 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -@pytest.fixture(scope="module") +@pytest.fixture() def in_memory_span_exporter() -> InMemorySpanExporter: return InMemorySpanExporter() -@pytest.fixture(scope="module") +@pytest.fixture() def tracer_provider(in_memory_span_exporter: InMemorySpanExporter) -> trace_api.TracerProvider: resource = Resource(attributes={}) tracer_provider = trace_sdk.TracerProvider(resource=resource) @@ -39,6 +44,18 @@ def instrument( in_memory_span_exporter.clear() +@pytest.fixture(autouse=True) +def clear_cache() -> None: + """ + DSPy caches responses from retrieval and language models to disk. This + fixture clears the cache before each test case to ensure that our mocked + responses are used. + """ + CacheMemory.clear() + NotebookCacheMemory.clear() + + +@responses.activate def test_openai_lm( in_memory_span_exporter: InMemorySpanExporter, ) -> None: @@ -52,41 +69,262 @@ class BasicQA(dspy.Signature): # type: ignore dspy.settings.configure(lm=turbo) # Mock out the OpenAI API. 
- url = "https://api.openai.com/v1/chat/completions" - response = { - "id": "chatcmpl-8kKarJQUyeuFeRsj18o6TWrxoP2zs", - "object": "chat.completion", - "created": 1706052941, - "model": "gpt-3.5-turbo-0613", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Washington DC", - }, - "logprobs": None, - "finish_reason": "stop", - } - ], - "usage": {"prompt_tokens": 39, "completion_tokens": 396, "total_tokens": 435}, - "system_fingerprint": None, - } + responses.add( + method=responses.POST, + url="https://api.openai.com/v1/chat/completions", + json={ + "id": "chatcmpl-8kKarJQUyeuFeRsj18o6TWrxoP2zs", + "object": "chat.completion", + "created": 1706052941, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Washington DC", + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 39, "completion_tokens": 396, "total_tokens": 435}, + "system_fingerprint": None, + }, + status=200, + ) - with requests_mock.Mocker() as m: - m.post(url, json=response) - # Define the predictor. - generate_answer = dspy.Predict(BasicQA) + # Define the predictor. + generate_answer = dspy.Predict(BasicQA) - # Call the predictor on a particular input. - question = "What's the capital of the United States?" # noqa: E501 - pred = generate_answer(question=question) + # Call the predictor on a particular input. + question = "What's the capital of the United States?" # noqa: E501 + pred = generate_answer(question=question) assert pred.answer == "Washington DC" spans = in_memory_span_exporter.get_finished_spans() assert len(spans) == 2 # 1 for the wrapping Signature, 1 for the OpenAI call lm_span = spans[0] chain_span = spans[1] - assert chain_span.name == "BasicQA.forward" + assert chain_span.name == "Predict(BasicQA).forward" assert lm_span.name == "GPT3.request" - assert question in lm_span.attributes[SpanAttributes.INPUT_VALUE] # type: ignore + assert question in lm_span.attributes[INPUT_VALUE] # type: ignore + + +@responses.activate +def test_rag_module( + in_memory_span_exporter: InMemorySpanExporter, +) -> None: + class BasicQA(dspy.Signature): # type: ignore + """Answer questions with short factoid answers.""" + + question = dspy.InputField() + answer = dspy.OutputField(desc="often between 1 and 5 words") + + class RAG(dspy.Module): # type: ignore + """ + Performs RAG on a corpus of data. + """ + + def __init__(self, num_passages: int = 3) -> None: + super().__init__() + self.retrieve = dspy.Retrieve(k=num_passages) + self.generate_answer = dspy.ChainOfThought(BasicQA) + + def forward(self, question: str) -> dspy.Prediction: + context = self.retrieve(question).passages + prediction = self.generate_answer(context=context, question=question) + return dspy.Prediction(context=context, answer=prediction.answer) + + turbo = dspy.OpenAI(api_key="jk-fake-key", model_type="chat") + colbertv2_url = "https://www.examplecolbertv2service.com/wiki17_abstracts" + colbertv2 = dspy.ColBERTv2(url=colbertv2_url) + dspy.settings.configure(lm=turbo, rm=colbertv2) + + # Mock the request to the remote ColBERTv2 service. 
+    responses.add(
+        method=responses.GET,
+        url=colbertv2_url,
+        json={
+            "topk": [
+                {
+                    "text": "first retrieved document text",
+                    "pid": 1918771,
+                    "rank": 1,
+                    "score": 26.81817626953125,
+                    "prob": 0.7290767171685155,
+                    "long_text": "first retrieved document long text",
+                },
+                {
+                    "text": "second retrieved document text",
+                    "pid": 3377468,
+                    "rank": 2,
+                    "score": 25.304840087890625,
+                    "prob": 0.16052389034616518,
+                    "long_text": "second retrieved document long text",
+                },
+                {
+                    "text": "third retrieved document text",
+                    "pid": 953799,
+                    "rank": 3,
+                    "score": 24.93050193786621,
+                    "prob": 0.11039939248531924,
+                    "long_text": "third retrieved document long text",
+                },
+            ],
+            "latency": 84.43140983581543,
+        },
+        status=200,
+    )
+
+    # Mock out the OpenAI API.
+    responses.add(
+        method=responses.POST,
+        url="https://api.openai.com/v1/chat/completions",
+        json={
+            "id": "chatcmpl-8kKarJQUyeuFeRsj18o6TWrxoP2zs",
+            "object": "chat.completion",
+            "created": 1706052941,
+            "model": "gpt-3.5-turbo-0613",
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": "Washington, D.C.",
+                    },
+                    "logprobs": None,
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {"prompt_tokens": 39, "completion_tokens": 396, "total_tokens": 435},
+            "system_fingerprint": None,
+        },
+        status=200,
+    )
+
+    rag = RAG()
+    question = "What's the capital of the United States?"
+    prediction = rag(question=question)
+
+    assert prediction.answer == "Washington, D.C."
+    spans = in_memory_span_exporter.get_finished_spans()
+    assert len(spans) == 6
+
+    span = spans[0]
+    assert (attributes := span.attributes) is not None
+    assert span.name == "ColBERTv2.__call__"
+    assert attributes[OPENINFERENCE_SPAN_KIND] == RETRIEVER.value
+    assert isinstance(input_value := attributes[INPUT_VALUE], str)
+    assert json.loads(input_value) == {
+        "k": 3,
+        "query": "What's the capital of the United States?",
+    }
+    assert attributes[INPUT_MIME_TYPE] == JSON.value
+    assert isinstance(
+        attributes[f"{RETRIEVAL_DOCUMENTS}.0.{DOCUMENT_ID}"],
+        int,
+    )
+    assert isinstance(
+        attributes[f"{RETRIEVAL_DOCUMENTS}.1.{DOCUMENT_ID}"],
+        int,
+    )
+    assert isinstance(
+        attributes[f"{RETRIEVAL_DOCUMENTS}.2.{DOCUMENT_ID}"],
+        int,
+    )
+    assert (
+        attributes[f"{RETRIEVAL_DOCUMENTS}.0.{DOCUMENT_CONTENT}"] == "first retrieved document text"
+    )
+    assert (
+        attributes[f"{RETRIEVAL_DOCUMENTS}.1.{DOCUMENT_CONTENT}"]
+        == "second retrieved document text"
+    )
+    assert (
+        attributes[f"{RETRIEVAL_DOCUMENTS}.2.{DOCUMENT_CONTENT}"] == "third retrieved document text"
+    )
+    assert isinstance(
+        attributes[f"{RETRIEVAL_DOCUMENTS}.0.{DOCUMENT_SCORE}"],
+        float,
+    )
+    assert isinstance(
+        attributes[f"{RETRIEVAL_DOCUMENTS}.1.{DOCUMENT_SCORE}"],
+        float,
+    )
+    assert isinstance(
+        attributes[f"{RETRIEVAL_DOCUMENTS}.2.{DOCUMENT_SCORE}"],
+        float,
+    )
+
+    span = spans[1]
+    assert (attributes := span.attributes) is not None
+    assert span.name == "Retrieve.forward"
+    assert attributes[OPENINFERENCE_SPAN_KIND] == RETRIEVER.value
+    assert isinstance(input_value := attributes[INPUT_VALUE], str) and json.loads(input_value) == {
+        "query_or_queries": "What's the capital of the United States?"
+    }
+    assert attributes[INPUT_MIME_TYPE] == JSON.value
+    assert (
+        attributes[f"{RETRIEVAL_DOCUMENTS}.0.{DOCUMENT_CONTENT}"] == "first retrieved document text"
+    )
+    assert (
+        attributes[f"{RETRIEVAL_DOCUMENTS}.1.{DOCUMENT_CONTENT}"]
+        == "second retrieved document text"
+    )
+    assert (
+        attributes[f"{RETRIEVAL_DOCUMENTS}.2.{DOCUMENT_CONTENT}"] == "third retrieved document text"
+    )
+
+    span = spans[2]
+    assert (attributes := span.attributes) is not None
+    assert span.name == "GPT3.request"
+    assert attributes[OPENINFERENCE_SPAN_KIND] == LLM.value
+
+    span = spans[3]
+    assert (attributes := span.attributes) is not None
+    assert span.name == "GPT3.request"
+    assert attributes[OPENINFERENCE_SPAN_KIND] == LLM.value
+
+    span = spans[4]
+    assert (attributes := span.attributes) is not None
+    assert span.name == "ChainOfThought(BasicQA).forward"
+    assert attributes[OPENINFERENCE_SPAN_KIND] == CHAIN.value
+    assert isinstance(input_value := attributes[INPUT_VALUE], str)
+    input_value_data = json.loads(input_value)
+    assert set(input_value_data.keys()) == {"context", "question"}
+    assert question == input_value_data["question"]
+    assert isinstance(output_value := attributes[OUTPUT_VALUE], str)
+    output_value_data = json.loads(output_value)
+    assert set(output_value_data.keys()) == {"answer"}
+    assert output_value_data["answer"] == "Washington, D.C."
+
+    span = spans[5]
+    assert (attributes := span.attributes) is not None
+    assert span.name == "RAG.forward"
+    assert attributes[OPENINFERENCE_SPAN_KIND] == CHAIN.value
+    assert isinstance(input_value := attributes[INPUT_VALUE], str)
+    assert json.loads(input_value) == {
+        "question": question,
+    }
+    assert attributes[INPUT_MIME_TYPE] == JSON.value
+    assert isinstance(output_value := attributes[OUTPUT_VALUE], str)
+    assert "Washington, D.C." in output_value
+    assert attributes[OUTPUT_MIME_TYPE] == JSON.value
+
+
+DOCUMENT_CONTENT = DocumentAttributes.DOCUMENT_CONTENT
+DOCUMENT_ID = DocumentAttributes.DOCUMENT_ID
+DOCUMENT_SCORE = DocumentAttributes.DOCUMENT_SCORE
+INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
+INPUT_VALUE = SpanAttributes.INPUT_VALUE
+OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
+OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
+OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
+RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS
+
+CHAIN = OpenInferenceSpanKindValues.CHAIN
+LLM = OpenInferenceSpanKindValues.LLM
+RETRIEVER = OpenInferenceSpanKindValues.RETRIEVER
+EMBEDDING = OpenInferenceSpanKindValues.EMBEDDING
+
+JSON = OpenInferenceMimeTypeValues.JSON
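
Note on the dotted attribute keys asserted above (e.g. f"{RETRIEVAL_DOCUMENTS}.0.{DOCUMENT_CONTENT}"): they are produced by the _flatten helper added in __init__.py, which turns nested mappings and lists of mappings into the flat, dot-delimited keys that OpenTelemetry span attributes require. Below is a minimal standalone sketch of that behavior; the helper is copied from the diff above, while the `docs` payload is invented purely for illustration:

    # sketch: how _flatten turns nested structures into flat OTel attribute keys
    from enum import Enum
    from typing import Any, Iterator, List, Mapping, Tuple

    def _flatten(mapping: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]:
        for key, value in mapping.items():
            if value is None:
                continue
            if isinstance(value, Mapping):
                # nested mappings contribute a dotted key segment
                for sub_key, sub_value in _flatten(value):
                    yield f"{key}.{sub_key}", sub_value
            elif isinstance(value, List) and any(isinstance(item, Mapping) for item in value):
                # lists of mappings additionally contribute an index segment
                for index, sub_mapping in enumerate(value):
                    for sub_key, sub_value in _flatten(sub_mapping):
                        yield f"{key}.{index}.{sub_key}", sub_value
            else:
                if isinstance(value, Enum):
                    value = value.value
                yield key, value

    # invented payload mirroring the retriever wrappers above
    docs = {
        "retrieval.documents": [
            {"document.id": 1918771, "document.content": "first retrieved document text"},
            {"document.id": 3377468, "document.content": "second retrieved document text"},
        ]
    }
    print(dict(_flatten(docs)))
    # {'retrieval.documents.0.document.id': 1918771,
    #  'retrieval.documents.0.document.content': 'first retrieved document text',
    #  'retrieval.documents.1.document.id': 3377468,
    #  'retrieval.documents.1.document.content': 'second retrieved document text'}

This is why the tests can assert directly on keys like "retrieval.documents.0.document.content" rather than on a nested structure.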