diff --git a/docs/docs/examples/embeddings/oci_data_science.ipynb b/docs/docs/examples/embeddings/oci_data_science.ipynb new file mode 100644 index 0000000000000..631d474fb58f7 --- /dev/null +++ b/docs/docs/examples/embeddings/oci_data_science.ipynb @@ -0,0 +1,174 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "6d1ca9ac", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "id": "9e3a8796-edc8-43f2-94ad-fe4fb20d70ed", + "metadata": {}, + "source": [ + "# Oracle Cloud Infrastructure (OCI) Data Science Service\n", + "\n", + "Oracle Cloud Infrastructure (OCI) [Data Science](https://www.oracle.com/artificial-intelligence/data-science) is a fully managed, serverless platform for data science teams to build, train, and manage machine learning models in Oracle Cloud Infrastructure.\n", + "\n", + "It offers [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm), which can be used to deploy embedding models in OCI Data Science. AI Quick Actions target users who want to quickly leverage the capabilities of AI. They aim to expand the reach of foundation models to a broader set of users by providing a streamlined, code-free, and efficient environment for working with foundation models. AI Quick Actions can be accessed from the Data Science Notebook.\n", + "\n", + "Detailed documentation on how to deploy embedding models in OCI Data Science using AI Quick Actions is available [here](https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/model-deployment-tips.md) and [here](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions-model-deploy.htm).\n", + "\n", + "This notebook explains how to use OCI's Data Science embedding models with LlamaIndex." 
+ ] + }, + { + "cell_type": "markdown", + "id": "3802e8c4", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb0dd8c9", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install llama-index-embeddings-oci-data-science" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "544d49f9", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install llama-index" + ] + }, + { + "cell_type": "markdown", + "id": "c2921307", + "metadata": {}, + "source": [ + "You will also need to install the [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/index.html) SDK." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "378d5179", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -U oracle-ads" + ] + }, + { + "cell_type": "markdown", + "id": "423605c6", + "metadata": {}, + "source": [ + "## Authentication\n", + "\n", + "The authentication methods supported for LlamaIndex are equivalent to those used with other OCI services and follow the standard SDK authentication methods, specifically API Key, session token, instance principal, and resource principal. More details can be found [here](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html). Make sure to have the required [policies](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm) to access the OCI Data Science Model Deployment endpoint. The [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/index.html) helps to simplify the authentication within OCI Data Science." 
+ ] + }, + { + "cell_type": "markdown", + "id": "03d4024a", + "metadata": {}, + "source": [ + "## Basic Usage\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60be18ae-c957-4ac2-a58a-0652e18ee6d6", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.embeddings.oci_data_science import OCIDataScienceEmbedding\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "embeddings = OCIDataScienceEmbedding(\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "\n", + "\n", + "e1 = embeddings.get_text_embedding(\"This is a test document\")\n", + "print(e1)\n", + "\n", + "e2 = embeddings.get_text_embedding_batch(\n", + " [\"This is a test document\", \"This is another test document\"]\n", + ")\n", + "print(e2)" + ] + }, + { + "cell_type": "markdown", + "id": "170e1ad7", + "metadata": {}, + "source": [ + "## Async" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbc3cba9", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.embeddings.oci_data_science import OCIDataScienceEmbedding\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "embeddings = OCIDataScienceEmbedding(\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "\n", + "e1 = await embeddings.aget_text_embedding(\"This is a test document\")\n", + "print(e1)\n", + "\n", + "e2 = await embeddings.aget_text_embedding_batch(\n", + " [\"This is a test document\", \"This is another test document\"]\n", + ")\n", + "print(e2)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git 
a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/.gitignore b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/.gitignore new file mode 100644 index 0000000000000..990c18de22908 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/.gitignore @@ -0,0 +1,153 @@ +llama_index/_static +.DS_Store +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +bin/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +etc/ +include/ +lib/ +lib64/ +parts/ +sdist/ +share/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +.ruff_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pyvenv.cfg + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Jetbrains +.idea +modules/ +*.swp + +# VsCode +.vscode + +# pipenv +Pipfile +Pipfile.lock + +# pyright +pyrightconfig.json diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/BUILD new file mode 100644 index 0000000000000..0896ca890d8bf --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/BUILD @@ -0,0 +1,3 @@ +poetry_requirements( + name="poetry", +) diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/Makefile b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/Makefile new file mode 100644 index 0000000000000..b9eab05aa3706 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' + +format: ## Run code autoformatters (black). + pre-commit install + git ls-files | xargs pre-commit run black --files + +lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy + pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files + +test: ## Run tests via pytest. + pytest tests + +watch-docs: ## Build and watch documentation. 
+ sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/ diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/README.md b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/README.md new file mode 100644 index 0000000000000..7457116f15881 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/README.md @@ -0,0 +1,70 @@ +# LlamaIndex Embeddings Integration: Oracle Cloud Infrastructure (OCI) Data Science Service + +Oracle Cloud Infrastructure (OCI) [Data Science](https://www.oracle.com/artificial-intelligence/data-science) is a fully managed, serverless platform for data science teams to build, train, and manage machine learning models in Oracle Cloud Infrastructure. + +It offers [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm), which can be used to deploy embedding models in OCI Data Science. AI Quick Actions target users who want to quickly leverage the capabilities of AI. They aim to expand the reach of foundation models to a broader set of users by providing a streamlined, code-free, and efficient environment for working with foundation models. AI Quick Actions can be accessed from the Data Science Notebook. + +Detailed documentation on how to deploy embedding models in OCI Data Science using AI Quick Actions is available [here](https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/model-deployment-tips.md) and [here](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions-model-deploy.htm). + +## Installation + +Install the required packages: + +```bash +pip install oracle-ads llama-index-core llama-index-embeddings-oci-data-science + +``` + +The [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/index.html) is required to simplify the authentication within OCI Data Science. 
+ +## Authentication + +The authentication methods supported for LlamaIndex are equivalent to those used with other OCI services and follow the standard SDK authentication methods, specifically API Key, session token, instance principal, and resource principal. More details can be found [here](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html). Make sure to have the required [policies](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm) to access the OCI Data Science Model Deployment endpoint. + +## Usage + +```python +import ads +from llama_index.embeddings.oci_data_science import OCIDataScienceEmbedding + +ads.set_auth(auth="security_token", profile="") + +embeddings = OCIDataScienceEmbedding( + endpoint="https:///predict", +) + +e1 = embeddings.get_text_embedding("This is a test document") +print(e1) + +e2 = embeddings.get_text_embedding_batch([ + "This is a test document", + "This is another test document" + ]) +print(e2) +``` + +## Async + +```python +import ads +from llama_index.embeddings.oci_data_science import OCIDataScienceEmbedding + +ads.set_auth(auth="security_token", profile="") + +embeddings = OCIDataScienceEmbedding( + endpoint="https:///predict", +) + +e1 = await embeddings.aget_text_embedding("This is a test document") +print(e1) + +e2 = await embeddings.aget_text_embedding_batch([ + "This is a test document", + "This is another test document" + ]) +print(e2) +``` + +## More examples + +https://docs.llamaindex.ai/en/stable/examples/embeddings/oci_data_science/ diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/llama_index/embeddings/oci_data_science/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/llama_index/embeddings/oci_data_science/BUILD new file mode 100644 index 0000000000000..db46e8d6c978c --- /dev/null +++ 
b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/llama_index/embeddings/oci_data_science/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/llama_index/embeddings/oci_data_science/__init__.py b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/llama_index/embeddings/oci_data_science/__init__.py new file mode 100644 index 0000000000000..6710096f3ab6f --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/llama_index/embeddings/oci_data_science/__init__.py @@ -0,0 +1,4 @@ +from llama_index.embeddings.oci_data_science.base import OCIDataScienceEmbedding + + +__all__ = ["OCIDataScienceEmbedding"] diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/llama_index/embeddings/oci_data_science/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/llama_index/embeddings/oci_data_science/base.py new file mode 100644 index 0000000000000..1a81462aabd49 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/llama_index/embeddings/oci_data_science/base.py @@ -0,0 +1,322 @@ +from typing import Any, Dict, List, Optional, Union + +from llama_index.core.base.embeddings.base import ( + DEFAULT_EMBED_BATCH_SIZE, + BaseEmbedding, +) +from llama_index.core.bridge.pydantic import Field, PrivateAttr, model_validator +from llama_index.core.callbacks.base import CallbackManager +from llama_index.embeddings.oci_data_science.client import AsyncClient, Client + +# from llama_index.embeddings.oci_data_science.utils import _validate_dependency + +DEFAULT_MODEL = "odsc-embeddings" +DEFAULT_TIMEOUT = 120 +DEFAULT_MAX_RETRIES = 5 + + +class OCIDataScienceEmbedding(BaseEmbedding): + """Embedding class for OCI Data Science models. 
+ + This class provides methods to generate embeddings using models deployed on + Oracle Cloud Infrastructure (OCI) Data Science. It supports both synchronous + and asynchronous requests and handles authentication, batching, and other + configurations. + + Setup: + Install the required packages: + ```bash + pip install -U oracle-ads llama-index-embeddings-oci-data-science + ``` + + Configure authentication using `ads.set_auth()`. For example, to use OCI + Resource Principal for authentication: + ```python + import ads + ads.set_auth("resource_principal") + ``` + + For more details on authentication, see: + https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html + + Ensure you have the required policies to access the OCI Data Science Model + Deployment endpoint: + https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm + + To learn more about deploying LLM models in OCI Data Science, see: + https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions-model-deploy.htm + + Examples: + Basic Usage: + ```python + import ads + from llama_index.embeddings.oci_data_science import OCIDataScienceEmbedding + + ads.set_auth(auth="security_token", profile="OC1") + + embeddings = OCIDataScienceEmbedding( + endpoint="https:///predict", + ) + + e1 = embeddings.get_text_embedding("This is a test document") + print(e1) + + e2 = embeddings.get_text_embedding_batch([ + "This is a test document", + "This is another test document" + ]) + print(e2) + ``` + + Asynchronous Usage: + ```python + import ads + import asyncio + from llama_index.embeddings.oci_data_science import OCIDataScienceEmbedding + + ads.set_auth(auth="security_token", profile="OC1") + + embeddings = OCIDataScienceEmbedding( + endpoint="https:///predict", + ) + + async def async_embedding(): + e1 = await embeddings.aget_query_embedding("This is a test document") + print(e1) + + asyncio.run(async_embedding()) + ``` + + Attributes: + endpoint (str): 
The URI of the endpoint from the deployed model. + auth (Dict[str, Any]): The authentication dictionary used for OCI API requests. + model_name (str): The name of the OCI Data Science embedding model. + embed_batch_size (int): The batch size for embedding calls. + additional_kwargs (Dict[str, Any]): Additional keyword arguments for the OCI Data Science AI request. + default_headers (Dict[str, str]): The default headers for API requests. + """ + + endpoint: str = Field( + default=None, description="The URI of the endpoint from the deployed model." + ) + + auth: Union[Dict[str, Any], None] = Field( + default_factory=dict, + exclude=True, + description=( + "The authentication dictionary used for OCI API requests. " + "If not provided, it will be autogenerated based on environment variables." + ), + ) + model_name: Optional[str] = Field( + default=DEFAULT_MODEL, + description="The name of the OCI Data Science embedding model to use.", + ) + + embed_batch_size: int = Field( + default=DEFAULT_EMBED_BATCH_SIZE, + description="The batch size for embedding calls.", + gt=0, + le=2048, + ) + + max_retries: int = Field( + default=DEFAULT_MAX_RETRIES, + description="The maximum number of API retries.", + ge=0, + ) + + timeout: float = Field( + default=DEFAULT_TIMEOUT, description="The timeout to use in seconds.", ge=0 + ) + + additional_kwargs: Optional[Dict[str, Any]] = Field( + default_factory=dict, + description="Additional keyword arguments for the OCI Data Science AI request.", + ) + + default_headers: Optional[Dict[str, str]] = Field( + default_factory=dict, description="The default headers for API requests." 
+ ) + + _client: Client = PrivateAttr() + _async_client: AsyncClient = PrivateAttr() + + def __init__( + self, + endpoint: str, + model_name: Optional[str] = DEFAULT_MODEL, + auth: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = DEFAULT_TIMEOUT, + max_retries: Optional[int] = DEFAULT_MAX_RETRIES, + embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, + additional_kwargs: Optional[Dict[str, Any]] = None, + default_headers: Optional[Dict[str, str]] = None, + callback_manager: Optional[CallbackManager] = None, + **kwargs: Any + ) -> None: + """Initialize the OCIDataScienceEmbedding instance. + + Args: + endpoint (str): The URI of the endpoint from the deployed model. + model_name (Optional[str]): The name of the OCI Data Science embedding model to use. Defaults to "odsc-embeddings". + auth (Optional[Dict[str, Any]]): The authentication dictionary for OCI API requests. Defaults to None. + timeout (Optional[float]): The timeout setting for the HTTP request in seconds. Defaults to 120. + max_retries (Optional[int]): The maximum number of retry attempts for the request. Defaults to 5. + embed_batch_size (int): The batch size for embedding calls. Defaults to DEFAULT_EMBED_BATCH_SIZE. + additional_kwargs (Optional[Dict[str, Any]]): Additional arguments for the OCI Data Science AI request. None is replaced by an empty dict. + default_headers (Optional[Dict[str, str]]): The default headers for API requests. None is replaced by an empty dict. + callback_manager (Optional[CallbackManager]): A callback manager for handling events during embedding operations. Defaults to None. + **kwargs: Additional keyword arguments forwarded to the base class initializer. 
+ """ + super().__init__( + model_name=model_name, + endpoint=endpoint, + auth=auth, + embed_batch_size=embed_batch_size, + timeout=timeout, + max_retries=max_retries, + additional_kwargs=additional_kwargs or {}, + default_headers=default_headers or {}, + callback_manager=callback_manager, + **kwargs + ) + + @model_validator(mode="before") + # @_validate_dependency + def validate_env(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Pre-initialization validation hook; currently a pass-through. + + The `_validate_dependency` decorator is commented out, so no dependency + check runs here: the incoming values are returned unchanged and nothing + is raised. + + Args: + values (Dict[str, Any]): The values passed to the model. + + Returns: + Dict[str, Any]: The same values, unmodified. + """ + return values + + @property + def client(self) -> Client: + """Return the synchronous client, creating it on first access. + + The client is built from this instance's `endpoint`, `auth`, `max_retries` + and `timeout`, stored on `_client`, and reused on later accesses. + + Returns: + Client: The synchronous client for interacting with the OCI Data Science Model Deployment endpoint. + """ + if not hasattr(self, "_client") or self._client is None: + self._client = Client( + endpoint=self.endpoint, + auth=self.auth, + retries=self.max_retries, + timeout=self.timeout, + ) + return self._client + + @property + def async_client(self) -> AsyncClient: + """Return the asynchronous client, creating it on first access. + + The client is built from this instance's `endpoint`, `auth`, `max_retries` + and `timeout`, stored on `_async_client`, and reused on later accesses. + + Returns: + AsyncClient: The asynchronous client for interacting with the OCI Data Science Model Deployment endpoint. 
+ """ + if not hasattr(self, "_async_client") or self._async_client is None: + self._async_client = AsyncClient( + endpoint=self.endpoint, + auth=self.auth, + retries=self.max_retries, + timeout=self.timeout, + ) + return self._async_client + + @classmethod + def class_name(cls) -> str: + """Get the class name. + + Returns: + str: The name of the class. + """ + return "OCIDataScienceEmbedding" + + def _get_query_embedding(self, query: str) -> List[float]: + """Generate an embedding for a query string. + + Args: + query (str): The query string for which to generate an embedding. 
+ + Returns: + List[float]: The embedding vector for the query. + """ + return self.client.embeddings( + input=query, payload=self.additional_kwargs, headers=self.default_headers + )["data"][0]["embedding"] + + def _get_text_embedding(self, text: str) -> List[float]: + """Generate an embedding for a text string. + + Args: + text (str): The text string for which to generate an embedding. + + Returns: + List[float]: The embedding vector for the text. + """ + return self.client.embeddings( + input=text, payload=self.additional_kwargs, headers=self.default_headers + )["data"][0]["embedding"] + + async def _aget_text_embedding(self, text: str) -> List[float]: + """Asynchronously generate an embedding for a text string. + + Args: + text (str): The text string for which to generate an embedding. + + Returns: + List[float]: The embedding vector for the text. + """ + response = await self.async_client.embeddings( + input=text, payload=self.additional_kwargs, headers=self.default_headers + ) + return response["data"][0]["embedding"] + + def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """Generate embeddings for a list of text strings. + + Args: + texts (List[str]): A list of text strings for which to generate embeddings. + + Returns: + List[List[float]]: A list of embedding vectors corresponding to the input texts. + """ + response = self.client.embeddings( + input=texts, payload=self.additional_kwargs, headers=self.default_headers + ) + return [raw["embedding"] for raw in response["data"]] + + async def _aget_query_embedding(self, query: str) -> List[float]: + """Asynchronously generate an embedding for a query string. + + Args: + query (str): The query string for which to generate an embedding. + + Returns: + List[float]: The embedding vector for the query. 
+ """ + response = await self.async_client.embeddings( + input=query, payload=self.additional_kwargs, headers=self.default_headers + ) + return response["data"][0]["embedding"] + + async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """Asynchronously generate embeddings for a list of text strings. + + Args: + texts (List[str]): A list of text strings for which to generate embeddings. + + Returns: + List[List[float]]: A list of embedding vectors corresponding to the input texts. + """ + response = await self.async_client.embeddings( + input=texts, payload=self.additional_kwargs, headers=self.default_headers + ) + return [raw["embedding"] for raw in response["data"]] diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/llama_index/embeddings/oci_data_science/client.py b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/llama_index/embeddings/oci_data_science/client.py new file mode 100644 index 0000000000000..5d1322b0e5cbe --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/llama_index/embeddings/oci_data_science/client.py @@ -0,0 +1,550 @@ +import asyncio +import functools +import logging +from abc import ABC +from types import TracebackType +from typing import ( + Any, + AnyStr, + Callable, + Dict, + Iterator, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) + +import httpx +import oci +import requests +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception, + stop_after_attempt, + stop_after_delay, + wait_exponential, + wait_random_exponential, +) + +DEFAULT_RETRIES = 3 +DEFAULT_BACKOFF_FACTOR = 3 +TIMEOUT = 600 # Timeout in seconds +STATUS_FORCE_LIST = [429, 500, 502, 503, 504] +DEFAULT_ENCODING = "utf-8" + +_T = TypeVar("_T", bound="BaseClient") + +logger = logging.getLogger(__name__) + + +class OCIAuth(httpx.Auth): + """Custom HTTPX authentication class that uses the OCI Signer for request 
signing. + + This class implements the HTTPX authentication interface, enabling it to sign outgoing HTTP requests + using an Oracle Cloud Infrastructure (OCI) Signer. + + Attributes: + signer (oci.signer.Signer): The OCI signer used to sign requests. + """ + + def __init__(self, signer: oci.signer.Signer): + """Initialize the OCIAuth instance. + + Args: + signer (oci.signer.Signer): The OCI signer to use for signing requests. + """ + self.signer = signer + + def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]: + """The authentication flow that signs the HTTPX request using the OCI signer. + + This method is called by HTTPX to sign each request before it is sent. + + Args: + request (httpx.Request): The outgoing HTTPX request to be signed. + + Yields: + httpx.Request: The signed HTTPX request. + """ + # Create a requests.Request object from the HTTPX request + req = requests.Request( + method=request.method, + url=str(request.url), + headers=dict(request.headers), + data=request.content, + ) + prepared_request = req.prepare() + + # Sign the request using the OCI Signer + self.signer.do_request_sign(prepared_request) + + # Update the original HTTPX request with the signed headers + request.headers.update(prepared_request.headers) + + # Proceed with the request + yield request + + +class ExtendedRequestException(Exception): + """Custom exception for handling request errors with additional context. + + Attributes: + original_exception (Exception): The original exception that caused the error. + response_text (str): The text of the response received from the request, if available. + """ + + def __init__(self, message: str, original_exception: Exception, response_text: str): + """Initialize the ExtendedRequestException. + + Args: + message (str): The error message associated with the exception. + original_exception (Exception): The original exception that caused the error. 
+ response_text (str): The text of the response received from the request, if available. + """ + super().__init__(message) + self.original_exception = original_exception + self.response_text = response_text + + +def _should_retry_exception(e: ExtendedRequestException) -> bool: + """Determine whether the exception should trigger a retry. + + This function checks if the exception is of a type that should cause the request to be retried, + based on the status code or the type of exception. + + Args: + e (ExtendedRequestException): The exception raised during the request. + + Returns: + bool: True if the exception should trigger a retry, False otherwise. + """ + original_exception = e.original_exception if hasattr(e, "original_exception") else e + if isinstance(original_exception, httpx.HTTPStatusError): + return original_exception.response.status_code in STATUS_FORCE_LIST + elif isinstance(original_exception, httpx.RequestError): + return True + return False + + +def _create_retry_decorator( + max_retries: int, + backoff_factor: float, + random_exponential: bool = False, + stop_after_delay_seconds: Optional[float] = None, + min_seconds: float = 0, + max_seconds: float = 60, +) -> Callable[[Any], Any]: + """Create a tenacity retry decorator with the specified configuration. + + This function sets up a retry strategy using the tenacity library, which can be applied to functions + to automatically retry on failure. + + Args: + max_retries (int): The maximum number of retry attempts. + backoff_factor (float): The backoff factor for calculating retry delays. + random_exponential (bool, optional): Whether to use random exponential backoff. Defaults to False. + stop_after_delay_seconds (Optional[float], optional): Maximum total time in seconds to retry. + If None, there is no time limit. Defaults to None. + min_seconds (float, optional): Minimum wait time between retries in seconds. Defaults to 0. + max_seconds (float, optional): Maximum wait time between retries in seconds. 
Defaults to 60. + + Returns: + Callable[[Any], Any]: A tenacity retry decorator configured with the specified strategy. + """ + wait_strategy = ( + wait_random_exponential(min=min_seconds, max=max_seconds) + if random_exponential + else wait_exponential( + multiplier=backoff_factor, min=min_seconds, max=max_seconds + ) + ) + + stop_strategy = stop_after_attempt(max_retries) + if stop_after_delay_seconds is not None: + stop_strategy = stop_strategy | stop_after_delay(stop_after_delay_seconds) + + retry_strategy = retry_if_exception(_should_retry_exception) + return retry( + wait=wait_strategy, + stop=stop_strategy, + retry=retry_strategy, + reraise=True, + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def _retry_decorator(f: Callable) -> Callable: + """Decorator to apply retry logic to a function using tenacity. + + This decorator applies a retry strategy to the decorated function, retrying it according + to the configured backoff and retry settings. + + Args: + f (Callable): The function to be decorated. + + Returns: + Callable: The decorated function with retry logic applied. + """ + + @functools.wraps(f) + def wrapper(self, *args: Any, **kwargs: Any): + retries = getattr(self, "retries", DEFAULT_RETRIES) + if retries <= 0: + return f(self, *args, **kwargs) + backoff_factor = getattr(self, "backoff_factor", DEFAULT_BACKOFF_FACTOR) + retry_func = _create_retry_decorator( + max_retries=retries, + backoff_factor=backoff_factor, + random_exponential=False, + stop_after_delay_seconds=getattr(self, "timeout", TIMEOUT), + min_seconds=0, + max_seconds=60, + ) + + return retry_func(f)(self, *args, **kwargs) + + return wrapper + + +class BaseClient(ABC): + """Abstract base class for HTTP clients invoking models with retry logic. + + This class provides common functionality for synchronous and asynchronous clients, + including request preparation, authentication, and retry handling. + + Attributes: + endpoint (str): The URL endpoint to send the request. 
def __init__(
    self,
    endpoint: str,
    auth: Optional[Any] = None,
    retries: Optional[int] = DEFAULT_RETRIES,
    backoff_factor: Optional[float] = DEFAULT_BACKOFF_FACTOR,
    timeout: Optional[Union[float, Tuple[float, float]]] = None,
    **kwargs: Any,
) -> None:
    """Initialize the BaseClient.

    Args:
        endpoint (str): The URL endpoint to send the request.
        auth (Optional[Any]): Mapping holding a callable under the ``"signer"`` key.
            If not provided, the default signer from `oracle-ads` is used.
        retries (Optional[int]): The number of retry attempts for the request.
            ``None`` selects DEFAULT_RETRIES; ``0`` is kept as given so that
            retrying can be switched off.
        backoff_factor (Optional[float]): The factor to determine the delay between retries.
            ``None`` selects DEFAULT_BACKOFF_FACTOR; ``0`` is kept as given.
        timeout (Optional[Union[float, Tuple[float, float]]]): The timeout setting for the HTTP request in seconds.
            Can be a single float for total timeout, or a tuple (connect_timeout, read_timeout). Defaults to TIMEOUT.
        **kwargs: Additional keyword arguments.

    Raises:
        ImportError: If ``auth`` is not provided and `oracle-ads` cannot be imported.
        ValueError: If the auth object has no callable ``"signer"`` entry.
    """
    self.endpoint = endpoint
    # Only None means "use the default": `x or DEFAULT` would also discard an
    # explicit 0, making it impossible to disable retries or the backoff wait.
    self.retries = DEFAULT_RETRIES if retries is None else retries
    self.backoff_factor = (
        DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor
    )
    self.timeout = timeout or TIMEOUT
    self.kwargs = kwargs

    # Use the default signer from ADS if `auth` is not provided
    if not auth:
        try:
            from ads.common import auth as authutil

            auth = authutil.default_signer()
        except ImportError as ex:
            raise ImportError(
                "The authentication signer for the requests was not provided. "
                "Use `auth` attribute to provide the signer. "
                "The authentication methods supported for LlamaIndex are equivalent to those "
                "used with other OCI services and follow the standard SDK authentication methods, "
                "specifically API Key, session token, instance principal, and resource principal. "
                "For more details, refer to the documentation: "
                "`https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html`. "
                "Alternatively you can use the `oracle-ads` package. "
                "Please install it with `pip install oracle-ads` and follow the example provided here: "
                "`https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html#authentication`."
            ) from ex

    # Validate auth object
    if not callable(auth.get("signer")):
        raise ValueError("Auth object must have a 'signer' callable attribute.")
    self.auth = OCIAuth(auth["signer"])

    logger.debug(
        f"Initialized {self.__class__.__name__} with endpoint={self.endpoint}, "
        f"retries={self.retries}, backoff_factor={self.backoff_factor}, timeout={self.timeout}"
    )
@_retry_decorator
def _request(
    self, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    """POST ``payload`` as JSON to the configured endpoint and decode the reply.

    Retrying is handled by the ``_retry_decorator`` wrapped around this method;
    every failure raised in here is converted to an ``ExtendedRequestException``
    that keeps the original exception and the best available error text.

    Args:
        payload (Dict[str, Any]): Parameters for the request payload.
        headers (Optional[Dict[str, str]]): HTTP headers to include in the request.

    Returns:
        Dict[str, Any]: The decoded JSON response from the server.

    Raises:
        ExtendedRequestException: Raised when the request fails after retries.
    """
    logger.debug(f"Starting synchronous request with payload: {payload}")
    try:
        request_headers = self._prepare_headers(headers=headers)
        reply = self._client.post(
            self.endpoint, headers=request_headers, auth=self.auth, json=payload
        )
        logger.debug(f"Received response with status code: {reply.status_code}")
        # Raises for 4xx/5xx so that status errors take the failure path below.
        reply.raise_for_status()
        decoded = reply.json()
        logger.debug(f"Response JSON: {decoded}")
        return decoded
    except Exception as e:
        # Prefer the response body as the error detail when the exception has one.
        if hasattr(e, "response") and e.response:
            details = e.response.text
        else:
            details = str(e)
        logger.error(f"Request failed. Error: {e!s}. Details: {details}")
        raise ExtendedRequestException(
            f"Request failed: {e!s}. Details: {details}", e, details
        ) from e
def __del__(self) -> None:
    """Best-effort close of the underlying HTTPX async client on finalization.

    If an event loop is running in the current thread, closing is scheduled on it
    as a task; otherwise the close coroutine is run to completion on a temporary
    loop. Any failure is ignored on purpose: a finalizer must not raise.
    """
    try:
        if self._client.is_closed:
            return
        try:
            # get_running_loop() replaces the deprecated get_event_loop(), which
            # warns (or raises on newer Python versions) when no loop is current.
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: drive the close ourselves.
            asyncio.run(self.close())
        else:
            loop.create_task(self.close())
    except Exception:
        # Deliberately swallowed - cleanup here is best effort only.
        pass
+ + Raises: + ExtendedRequestException: Raised when the request fails after retries. + """ + logger.debug(f"Starting asynchronous request with payload: {payload}") + try: + response = await self._client.post( + self.endpoint, + headers=self._prepare_headers(headers=headers), + auth=self.auth, + json=payload, + ) + logger.debug(f"Received response with status code: {response.status_code}") + response.raise_for_status() + json_response = response.json() + logger.debug(f"Response JSON: {json_response}") + return json_response + except Exception as e: + last_exception_text = ( + e.response.text if hasattr(e, "response") and e.response else str(e) + ) + logger.error( + f"Request failed. Error: {e!s}. Details: {last_exception_text}" + ) + raise ExtendedRequestException( + f"Request failed: {e!s}. Details: {last_exception_text}", + e, + last_exception_text, + ) from e + + async def embeddings( + self, + input: Union[str, Sequence[AnyStr]] = "", + payload: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + ) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: + """Generate embeddings asynchronously by sending a request to the endpoint. + + Args: + input (Union[str, Sequence[AnyStr]], optional): The input text or sequence of texts for which to generate embeddings. + Defaults to "". + payload (Optional[Dict[str, Any]], optional): Additional parameters to include in the request payload. + Defaults to None. + headers (Optional[Dict[str, str]], optional): HTTP headers to include in the request. + Defaults to None. + + Returns: + Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: The server's response, typically including the generated embeddings. 
class UnsupportedOracleAdsVersionError(Exception):
    """Custom exception for unsupported `oracle-ads` versions.

    Attributes:
        current_version (str): The installed version of `oracle-ads`.
        required_version (str): The minimum required version of `oracle-ads`.
    """

    def __init__(self, current_version: str, required_version: str):
        """Initialize the UnsupportedOracleAdsVersionError.

        Args:
            current_version (str): The currently installed version of `oracle-ads`.
            required_version (str): The minimum required version of `oracle-ads`.
        """
        # Keep both versions on the instance, as the class docstring promises,
        # so callers can inspect them without parsing the message.
        self.current_version = current_version
        self.required_version = required_version
        super().__init__(
            f"The `oracle-ads` version {current_version} currently installed is incompatible with "
            "the `llama-index-embeddings-oci-data-science` version in use. To resolve this issue, "
            f"please upgrade to `oracle-ads:{required_version}` or later using the "
            "command: `pip install oracle-ads -U`"
        )
+ + This decorator checks whether `oracle-ads` is installed and ensures its version meets + the minimum requirement. If not, it raises an appropriate error. + + Args: + func (Callable[..., Any]): The function to wrap with the dependency validation. + + Returns: + Callable[..., Any]: The wrapped function. + + Raises: + ImportError: If `oracle-ads` is not installed. + UnsupportedOracleAdsVersionError: If the installed version is below the required version. + """ + + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + try: + from ads import __version__ as ads_version + + if version.parse(ads_version) < version.parse(MIN_ADS_VERSION): + raise UnsupportedOracleAdsVersionError(ads_version, MIN_ADS_VERSION) + except ImportError as ex: + raise ImportError( + "Could not import `oracle-ads` Python package. " + "Please install it with `pip install oracle-ads`." + ) from ex + + return func(*args, **kwargs) + + return wrapper diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/pyproject.toml new file mode 100644 index 0000000000000..3c919ba6b2016 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/pyproject.toml @@ -0,0 +1,65 @@ +[build-system] +build-backend = "poetry.core.masonry.api" +requires = ["poetry-core"] + +[tool.codespell] +check-filenames = true +check-hidden = true +# Feel free to un-skip examples, and experimental, you will just need to +# work through many typos (--write-changes and --interactive will help) +skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" + +[tool.llamahub] +contains_example = false +import_path = "llama_index.embeddings.oci_data_science" + +[tool.llamahub.class_authors] +OCIDataScienceEmbedding = "mrdzurb" + +[tool.mypy] +disallow_untyped_defs = true +exclude = ["_static", "build", "examples", "notebooks", "venv"] +ignore_missing_imports = true +python_version = 
"3.9" + +[tool.poetry] +authors = ["Dmitrii Cherkasov "] +description = "llama-index embeddings OCI Data Science integration" +exclude = ["**/BUILD"] +license = "MIT" +name = "llama-index-embeddings-oci-data-science" +readme = "README.md" +version = "0.1.0" + +[tool.poetry.dependencies] +python = ">=3.9,<4.0" +llama-index-core = "^0.12.0" + +[tool.poetry.group.dev.dependencies] +ipython = "8.10.0" +jupyter = "^1.0.0" +mypy = "0.991" +pre-commit = "3.2.0" +pylint = "2.15.10" +pytest = "7.2.1" +pytest-asyncio = ">=0.24.0" +pytest-mock = "3.11.1" +ruff = "0.0.292" +tree-sitter-languages = "^1.8.0" +types-Deprecated = ">=0.1.0" +types-PyYAML = "^6.0.12.12" +types-protobuf = "^4.24.0.4" +types-redis = "4.5.5.0" +types-requests = "2.28.11.8" +types-setuptools = "67.1.0.0" + +[tool.poetry.group.dev.dependencies.black] +extras = ["jupyter"] +version = "<=23.9.1,>=23.7.0" + +[tool.poetry.group.dev.dependencies.codespell] +extras = ["toml"] +version = ">=v2.2.6" + +[[tool.poetry.packages]] +include = "llama_index/" diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/tests/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/tests/BUILD new file mode 100644 index 0000000000000..dabf212d7e716 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/tests/BUILD @@ -0,0 +1 @@ +python_tests() diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/tests/__init__.py b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/tests/test_embeddings_oci_data_science.py b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/tests/test_embeddings_oci_data_science.py new file mode 100644 index 0000000000000..711d11f4ea5cb --- /dev/null +++ 
@pytest.fixture()
def embeddings():
    """Build an OCIDataScienceEmbedding whose sync and async clients are mocks.

    The instance is created with a mocked signer, then its `_client` and
    `_async_client` are replaced by spec'd mocks so that tests can set return
    values and assert on calls without any network access.
    """
    # A single mocked signer, actually passed to the constructor (the fixture used
    # to build this dict, leave it unused, and pass a second identical one).
    auth = {"signer": Mock()}

    embeddings_instance = OCIDataScienceEmbedding(
        endpoint="https://example.com/api",
        model_name="odsc-embeddings",
        auth=auth,
        embed_batch_size=10,
        timeout=60,
        max_retries=3,
        additional_kwargs={"some_param": "value"},
        default_headers={"Custom-Header": "value"},
        callback_manager=CallbackManager([]),
    )
    # Mock the client
    embeddings_instance._client = Mock(spec=Client)
    embeddings_instance._async_client = AsyncMock(spec=AsyncClient)
    return embeddings_instance
query" + embedding_vector = embeddings.get_query_embedding(query) + + embeddings.client.embeddings.assert_called_once_with( + input=query, + payload=embeddings.additional_kwargs, + headers=embeddings.default_headers, + ) + + assert embedding_vector == [0.1, 0.2, 0.3] + + +def test_get_text_embedding(embeddings): + embeddings.client.embeddings.return_value = response_data["data"] + + text = "This is a test text" + embedding_vector = embeddings.get_text_embedding(text) + + embeddings.client.embeddings.assert_called_once_with( + input=text, + payload=embeddings.additional_kwargs, + headers=embeddings.default_headers, + ) + + assert embedding_vector == [0.1, 0.2, 0.3] + + +def test_get_text_embedding_batch(embeddings): + embeddings.client.embeddings.return_value = response_data["data"] + + texts = ["Text one", "Text two"] + embedding_vectors = embeddings.get_text_embedding_batch(texts) + + embeddings.client.embeddings.assert_called_once_with( + input=texts, + payload=embeddings.additional_kwargs, + headers=embeddings.default_headers, + ) + + assert embedding_vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + + +@pytest.mark.asyncio() +async def test_aget_query_embedding(embeddings): + embeddings.async_client.embeddings.return_value = response_data["data"] + + query = "Async test query" + embedding_vector = await embeddings.aget_query_embedding(query) + + embeddings.async_client.embeddings.assert_called_once_with( + input=query, + payload=embeddings.additional_kwargs, + headers=embeddings.default_headers, + ) + + assert embedding_vector == [0.1, 0.2, 0.3] + + +@pytest.mark.asyncio() +async def test_aget_text_embedding(embeddings): + embeddings.async_client.embeddings.return_value = response_data["data"] + + text = "Async test text" + embedding_vector = await embeddings.aget_text_embedding(text) + + embeddings.async_client.embeddings.assert_called_once_with( + input=text, + payload=embeddings.additional_kwargs, + headers=embeddings.default_headers, + ) + + assert 
embedding_vector == [0.1, 0.2, 0.3] + + +@pytest.mark.asyncio() +async def test_aget_text_embedding_batch(embeddings): + embeddings.async_client.embeddings.return_value = response_data["data"] + + texts = ["Async text one", "Async text two"] + embedding_vectors = await embeddings.aget_text_embedding_batch(texts) + + embeddings.async_client.embeddings.assert_called_once_with( + input=texts, + payload=embeddings.additional_kwargs, + headers=embeddings.default_headers, + ) + + assert embedding_vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/tests/test_oci_data_science_client.py b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/tests/test_oci_data_science_client.py new file mode 100644 index 0000000000000..ca665690b6b32 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-oci-data-science/tests/test_oci_data_science_client.py @@ -0,0 +1,516 @@ +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import httpx +import pytest +from llama_index.embeddings.oci_data_science.client import ( + AsyncClient, + BaseClient, + Client, + ExtendedRequestException, + OCIAuth, + _create_retry_decorator, + _retry_decorator, + _should_retry_exception, +) + + +class TestOCIAuth: + """Unit tests for OCIAuth class.""" + + def setup_method(self): + self.signer_mock = Mock() + self.oci_auth = OCIAuth(self.signer_mock) + + def test_auth_flow(self): + """Ensures that the auth_flow signs the request correctly.""" + request = httpx.Request("POST", "https://example.com") + prepared_request_mock = Mock() + prepared_request_mock.headers = {"Authorization": "Signed"} + with patch("requests.Request") as mock_requests_request: + mock_requests_request.return_value = Mock() + mock_requests_request.return_value.prepare.return_value = ( + prepared_request_mock + ) + self.signer_mock.do_request_sign = Mock() + + list(self.oci_auth.auth_flow(request)) + + 
self.signer_mock.do_request_sign.assert_called() + assert request.headers.get("Authorization") == "Signed" + + +class TestExtendedRequestException: + """Unit tests for ExtendedRequestException.""" + + def test_exception_attributes(self): + """Ensures the exception stores the correct attributes.""" + original_exception = Exception("Original error") + response_text = "Error response text" + message = "Extended error message" + + exception = ExtendedRequestException(message, original_exception, response_text) + + assert str(exception) == message + assert exception.original_exception == original_exception + assert exception.response_text == response_text + + +class TestShouldRetryException: + """Unit tests for _should_retry_exception function.""" + + def test_http_status_error_in_force_list(self): + """Ensures it returns True for HTTPStatusError with status in STATUS_FORCE_LIST.""" + response_mock = Mock() + response_mock.status_code = 500 + original_exception = httpx.HTTPStatusError( + "Error", request=None, response=response_mock + ) + exception = ExtendedRequestException( + "Message", original_exception, "Response text" + ) + + result = _should_retry_exception(exception) + assert result is True + + def test_http_status_error_not_in_force_list(self): + """Ensures it returns False for HTTPStatusError with status not in STATUS_FORCE_LIST.""" + response_mock = Mock() + response_mock.status_code = 404 + original_exception = httpx.HTTPStatusError( + "Error", request=None, response=response_mock + ) + exception = ExtendedRequestException( + "Message", original_exception, "Response text" + ) + + result = _should_retry_exception(exception) + assert result is False + + def test_http_request_error(self): + """Ensures it returns True for RequestError.""" + original_exception = httpx.RequestError("Error") + exception = ExtendedRequestException( + "Message", original_exception, "Response text" + ) + + result = _should_retry_exception(exception) + assert result is True + + def 
test_other_exception(self): + """Ensures it returns False for other exceptions.""" + original_exception = Exception("Some other error") + exception = ExtendedRequestException( + "Message", original_exception, "Response text" + ) + + result = _should_retry_exception(exception) + assert result is False + + +class TestCreateRetryDecorator: + """Unit tests for _create_retry_decorator function.""" + + def test_create_retry_decorator(self): + """Ensures the retry decorator is created with correct parameters.""" + max_retries = 5 + backoff_factor = 2 + random_exponential = False + stop_after_delay_seconds = 100 + min_seconds = 1 + max_seconds = 10 + + retry_decorator = _create_retry_decorator( + max_retries, + backoff_factor, + random_exponential, + stop_after_delay_seconds, + min_seconds, + max_seconds, + ) + + assert callable(retry_decorator) + + +class TestRetryDecorator: + """Unit tests for _retry_decorator function.""" + + def test_retry_decorator_no_retries(self): + """Ensures the function is called directly when retries is 0.""" + + class TestClass: + retries = 0 + backoff_factor = 1 + timeout = 10 + + @_retry_decorator + def test_method(self): + return "Success" + + test_instance = TestClass() + result = test_instance.test_method() + assert result == "Success" + + def test_retry_decorator_with_retries(self): + """Ensures the function retries upon exception.""" + + class TestClass: + retries = 3 + backoff_factor = 0.1 + timeout = 10 + + call_count = 0 + + @_retry_decorator + def test_method(self): + self.call_count += 1 + if self.call_count < 3: + raise ExtendedRequestException( + "Error", + original_exception=httpx.RequestError("Error"), + response_text="test", + ) + return "Success" + + test_instance = TestClass() + result = test_instance.test_method() + assert result == "Success" + assert test_instance.call_count == 3 + + def test_retry_decorator_exceeds_retries(self): + """Ensures the function raises exception after exceeding retries.""" + + class TestClass: + 
retries = 3 + backoff_factor = 0.1 + timeout = 10 + + call_count = 0 + + @_retry_decorator + def test_method(self): + self.call_count += 1 + raise ExtendedRequestException( + "Error", + original_exception=httpx.RequestError("Error"), + response_text="test", + ) + + test_instance = TestClass() + with pytest.raises(ExtendedRequestException): + test_instance.test_method() + assert test_instance.call_count == 3 # initial attempt + 2 retries + + +class TestBaseClient: + """Unit tests for BaseClient class.""" + + def setup_method(self): + self.endpoint = "https://example.com/api" + self.auth_mock = {"signer": Mock()} + self.retries = 3 + self.backoff_factor = 2 + self.timeout = 30 + + self.base_client = BaseClient( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + + def test_init(self): + """Ensures that the client is initialized correctly.""" + assert self.base_client.endpoint == self.endpoint + assert self.base_client.retries == self.retries + assert self.base_client.backoff_factor == self.backoff_factor + assert self.base_client.timeout == self.timeout + assert isinstance(self.base_client.auth, OCIAuth) + + # def test_init_default_auth(self): + # """Ensures that default auth is used when auth is None.""" + # with patch.object(authutil, "default_signer", return_value=self.auth_mock): + # client = BaseClient(endpoint=self.endpoint) + # assert client.auth is not None + + def test_init_invalid_auth(self): + """Ensures that ValueError is raised when auth signer is invalid.""" + with pytest.raises(ValueError): + BaseClient(endpoint=self.endpoint, auth={"signer": None}) + + def test_prepare_headers(self): + """Ensures that headers are prepared correctly.""" + headers = {"Custom-Header": "Value"} + result = self.base_client._prepare_headers(headers=headers) + expected_headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Custom-Header": "Value", + } + 
def test_auth_not_provided(self):
    """Ensures that ImportError is raised when no auth signer is provided and `oracle-ads` is unavailable."""
    # Force the `ads` import performed by the client to fail, so the outcome does
    # not depend on whether `oracle-ads` happens to be installed where the tests
    # run. A None entry in sys.modules makes the import raise ImportError.
    with patch.dict("sys.modules", {"ads": None, "ads.common": None}):
        with pytest.raises(ImportError):
            Client(
                endpoint=self.endpoint,
                retries=self.retries,
                backoff_factor=self.backoff_factor,
                timeout=self.timeout,
            )
"list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + -0.025584878, + 0.023328023, + -0.03014998, + ], + }, + { + "object": "embedding", + "index": 1, + "embedding": [ + -0.025584878, + 0.023328023, + -0.03014998, + ], + }, + ], + } + response_mock = Mock() + response_mock.json.return_value = response_json + response_mock.status_code = 200 + + self.client._client.post.return_value = response_mock + + result = self.client.embeddings( + input=["Hello", "World"], payload={"param1": "value1"} + ) + + assert result == response_json + + def test_close(self): + """Ensures that close method closes the client.""" + self.client._client.close = Mock() + self.client.close() + self.client._client.close.assert_called_once() + + def test_is_closed(self): + """Ensures that is_closed returns the client's is_closed status.""" + self.client._client.is_closed = False + assert not self.client.is_closed() + self.client._client.is_closed = True + assert self.client.is_closed() + + def test_context_manager(self): + """Ensures that the client can be used as a context manager.""" + self.client.close = Mock() + with self.client as client_instance: + assert client_instance == self.client + self.client.close.assert_called_once() + + def test_del(self): + """Ensures that __del__ method closes the client.""" + client = Client( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + client.close = Mock() + client.__del__() # Manually invoke __del__ + client.close.assert_called_once() + + +@pytest.mark.asyncio() +class TestAsyncClient: + """Unit tests for AsyncClient class.""" + + def setup_method(self): + self.endpoint = "https://example.com/api" + self.auth_mock = {"signer": Mock()} + self.retries = 2 + self.backoff_factor = 0.1 + self.timeout = 10 + + self.client = AsyncClient( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + 
backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + # Mock the internal HTTPX client + self.client._client = AsyncMock() + self.client._client.is_closed = False + + def async_iter(self, items): + """Helper function to create an async iterator from a list.""" + + async def generator(): + for item in items: + yield item + + return generator() + + async def test_request_success(self): + """Ensures that _request returns JSON response on success.""" + payload = {"prompt": "Hello"} + response_json = {"choices": [{"text": "Hi"}]} + response_mock = AsyncMock() + response_mock.status_code = 200 + response_mock.json = AsyncMock(return_value=response_json) + response_mock.raise_for_status = Mock() + self.client._client.post.return_value = response_mock + result = await self.client._request(payload) + assert await result == response_json + + async def test_request_http_error(self): + """Ensures that _request raises ExtendedRequestException on HTTP error.""" + payload = {"prompt": "Hello"} + response_mock = MagicMock() + response_mock.status_code = 500 + response_mock.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", request=None, response=response_mock + ) + response_mock.text = "Internal Server Error" + + self.client._client.post.return_value = response_mock + + with pytest.raises(ExtendedRequestException) as exc_info: + await self.client._request(payload) + + assert "Request failed" in str(exc_info.value) + assert exc_info.value.response_text == "Internal Server Error" + + async def test_generate_request(self): + """Ensures that generate method calls _request when stream=False.""" + response_json = { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + -0.025584878, + 0.023328023, + -0.03014998, + ], + }, + { + "object": "embedding", + "index": 1, + "embedding": [ + -0.025584878, + 0.023328023, + -0.03014998, + ], + }, + ], + } + response_mock = AsyncMock() + response_mock.status_code = 200 + 
response_mock.json = AsyncMock(return_value=response_json) + response_mock.raise_for_status = Mock() + + self.client._client.post.return_value = response_mock + + result = await self.client.embeddings( + input=["Hello", "World"], payload={"param1": "value1"} + ) + + assert await result == response_json + + async def test_close(self): + """Ensures that close method closes the client.""" + self.client._client.aclose = AsyncMock() + await self.client.close() + self.client._client.aclose.assert_called_once() + + async def test_is_closed(self): + """Ensures that is_closed returns the client's is_closed status.""" + self.client._client.is_closed = False + assert not self.client.is_closed() + self.client._client.is_closed = True + assert self.client.is_closed() + + async def test_context_manager(self): + """Ensures that the client can be used as a context manager.""" + self.client.close = AsyncMock() + async with self.client as client_instance: + assert client_instance == self.client + self.client.close.assert_called_once() + + async def test_del(self): + """Ensures that __del__ method closes the client.""" + client = AsyncClient( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + client.close = AsyncMock() + await client.__aexit__(None, None, None) # Manually invoke __aexit__ + client.close.assert_called_once()