diff --git a/backend/app/rag/embeddings/provider.py b/backend/app/rag/embeddings/provider.py index c7e67ce4..84e98da3 100644 --- a/backend/app/rag/embeddings/provider.py +++ b/backend/app/rag/embeddings/provider.py @@ -13,6 +13,7 @@ class EmbeddingProvider(str, enum.Enum): GITEEAI = "giteeai" LOCAL = "local" OPENAI_LIKE = "openai_like" + AZURE_OPENAI = "azure_openai" class EmbeddingProviderOption(BaseModel): @@ -123,6 +124,22 @@ class EmbeddingProviderOption(BaseModel): credentials_type="str", default_credentials="****", ), + EmbeddingProviderOption( + provider=EmbeddingProvider.AZURE_OPENAI, + provider_display_name="Azure OpenAI", + provider_description="Azure OpenAI is a cloud-based AI service that provides a suite of AI models and tools for developers to build intelligent applications.", + provider_url="https://azure.microsoft.com/en-us/products/ai-services/openai-service", + default_embedding_model="text-embedding-3-small", + embedding_model_description="Before using this option, you need to deploy an Azure OpenAI API and model, see https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource.", + default_config={ + "azure_endpoint": "https://.openai.azure.com/", + "api_version": "" + }, + credentials_display_name="Azure OpenAI API Key", + credentials_description="The API key of Azure OpenAI", + credentials_type="str", + default_credentials="****", + ), EmbeddingProviderOption( provider=EmbeddingProvider.LOCAL, provider_display_name="Local Embedding", diff --git a/backend/app/rag/embeddings/resolver.py b/backend/app/rag/embeddings/resolver.py index aad709e8..a5a36d20 100644 --- a/backend/app/rag/embeddings/resolver.py +++ b/backend/app/rag/embeddings/resolver.py @@ -1,4 +1,6 @@ from typing import Optional + +from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding from sqlmodel import Session from llama_index.core.base.embeddings.base import BaseEmbedding @@ -65,6 +67,12 @@ def get_embed_model( api_key=credentials, **config, ) + case EmbeddingProvider.AZURE_OPENAI: + return AzureOpenAIEmbedding( + model=model, + api_key=credentials, + **config, + ) case EmbeddingProvider.OPENAI_LIKE: return OpenAILikeEmbedding( model=model, diff --git a/backend/app/rag/llms/provider.py b/backend/app/rag/llms/provider.py index 840aadfa..a3450870 100644 --- a/backend/app/rag/llms/provider.py +++ b/backend/app/rag/llms/provider.py @@ -12,6 +12,7 @@ class LLMProvider(str, enum.Enum): BEDROCK = "bedrock" OLLAMA = "ollama" GITEEAI = "giteeai" + AZURE_OPENAI = "azure_openai" class LLMProviderOption(BaseModel): @@ -147,4 +148,22 @@ class LLMProviderOption(BaseModel): "aws_region_name": "us-west-2", }, ), + LLMProviderOption( + provider=LLMProvider.AZURE_OPENAI, + provider_display_name="Azure OpenAI", + provider_description="Azure OpenAI is a cloud-based AI service that provides access to OpenAI's advanced language models.", + provider_url="https://azure.microsoft.com/en-us/products/ai-services/openai-service", + default_llm_model="gpt-4o", + llm_model_description="", + config_description="Refer to this document https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart to have more information about the Azure OpenAI API.", + default_config={ + "azure_endpoint": "https://.openai.azure.com/", + "api_version": "", + "engine": "", + }, + credentials_display_name="Azure OpenAI API Key", + credentials_description="The API key of Azure OpenAI", + credentials_type="str", + default_credentials="****", + ), ] diff --git a/backend/app/rag/llms/resolver.py b/backend/app/rag/llms/resolver.py index fd3c9785..874cde76 100644 --- a/backend/app/rag/llms/resolver.py +++ b/backend/app/rag/llms/resolver.py @@ -1,6 +1,7 @@ import os from typing import Optional from llama_index.core.llms.llm import LLM +from llama_index.llms.azure_openai import AzureOpenAI from llama_index.llms.openai import OpenAI from llama_index.llms.openai_like import OpenAILike from llama_index.llms.gemini import Gemini @@ -86,6 +87,12 @@ def get_llm( api_key=credentials, **config, ) + case LLMProvider.AZURE_OPENAI: + return AzureOpenAI( + model=model, + api_key=credentials, + **config, + ) case _: raise ValueError(f"Got unknown LLM provider: {provider}") diff --git a/backend/app/utils/dspy.py b/backend/app/utils/dspy.py index 42a1233c..4946c92c 100644 --- a/backend/app/utils/dspy.py +++ b/backend/app/utils/dspy.py @@ -3,6 +3,8 @@ import hashlib from typing import Any, Literal +from llama_index.llms.azure_openai import AzureOpenAI + import dspy import requests from dsp.modules.lm import LM @@ -83,6 +85,15 @@ def get_dspy_lm_by_llama_llm(llama_llm: BaseLLM) -> dspy.LM: max_tokens=llama_llm.context_window, num_ctx=llama_llm.context_window, ) + elif type(llama_llm) is AzureOpenAI: + return dspy.AzureOpenAI( + model=llama_llm.model, + max_tokens=llama_llm.max_tokens or 4096, + api_key=llama_llm.api_key, + api_base=enforce_trailing_slash(llama_llm.azure_endpoint), + api_version=llama_llm.api_version, + deployment_id=llama_llm.engine, + ) else: raise ValueError(f"Got unknown LLM provider: {llama_llm.__class__.__name__}") diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 7ec00d03..560b776d 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -61,6 +61,9 @@ dependencies = [ "llama-index-postprocessor-xinference-rerank>=0.2.0", "llama-index-postprocessor-bedrock-rerank>=0.3.0", "llama-index-llms-vertex>=0.4.2", + "socksio>=1.0.0", + "llama-index-llms-azure-openai>=0.3.0", + "llama-index-embeddings-azure-openai>=0.3.0", ] readme = "README.md" requires-python = ">= 3.8" diff --git a/backend/requirements-dev.lock b/backend/requirements-dev.lock index 8570c5d9..fd4e29f1 100644 --- a/backend/requirements-dev.lock +++ b/backend/requirements-dev.lock @@ -41,6 +41,10 @@ argon2-cffi-bindings==21.2.0 asyncmy==0.2.9 attrs==23.2.0 # via aiohttp +azure-core==1.32.0 + # via azure-identity +azure-identity==1.19.0 + # via llama-index-llms-azure-openai backoff==2.2.1 # via dspy-ai # via langfuse @@ -103,6 +107,8 @@ colorama==0.4.6 colorlog==6.8.2 # via optuna cryptography==42.0.8 + # via azure-identity + # via msal # via pyjwt dataclasses-json==0.6.7 # via langchain-community @@ -249,6 +255,7 @@ httpx==0.27.0 # via langsmith # via llama-cloud # via llama-index-core + # via llama-index-llms-azure-openai # via ollama # via openai httpx-oauth==0.14.1 @@ -326,6 +333,7 @@ llama-index-core==0.12.10.post1 # via llama-index # via llama-index-agent-openai # via llama-index-cli + # via llama-index-embeddings-azure-openai # via llama-index-embeddings-bedrock # via llama-index-embeddings-cohere # via llama-index-embeddings-jinaai @@ -333,6 +341,7 @@ llama-index-core==0.12.10.post1 # via llama-index-embeddings-openai # via llama-index-indices-managed-llama-cloud # via llama-index-llms-anthropic + # via llama-index-llms-azure-openai # via llama-index-llms-bedrock # via llama-index-llms-gemini # via llama-index-llms-ollama @@ -349,6 +358,7 @@ llama-index-core==0.12.10.post1 # via llama-index-readers-file # via llama-index-readers-llama-parse # via llama-parse +llama-index-embeddings-azure-openai==0.3.0 llama-index-embeddings-bedrock==0.4.0 llama-index-embeddings-cohere==0.4.0 llama-index-embeddings-jinaai==0.4.0 @@ -356,10 +366,13 @@ llama-index-embeddings-ollama==0.5.0 llama-index-embeddings-openai==0.3.1 # via llama-index # via llama-index-cli + # via llama-index-embeddings-azure-openai llama-index-indices-managed-llama-cloud==0.6.3 # via llama-index llama-index-llms-anthropic==0.6.3 # via llama-index-llms-bedrock +llama-index-llms-azure-openai==0.3.0 + # via llama-index-embeddings-azure-openai llama-index-llms-bedrock==0.3.3 llama-index-llms-gemini==0.4.2 llama-index-llms-ollama==0.5.0 @@ -367,6 +380,7 @@ llama-index-llms-openai==0.3.13 # via llama-index # via llama-index-agent-openai # via llama-index-cli + # via llama-index-llms-azure-openai # via llama-index-llms-openai-like # via llama-index-multi-modal-llms-openai # via llama-index-program-openai @@ -407,6 +421,11 @@ marshmallow==3.21.3 # via dataclasses-json mdurl==0.1.2 # via markdown-it-py +msal==1.31.1 + # via azure-identity + # via msal-extensions +msal-extensions==1.2.0 + # via azure-identity multidict==6.0.5 # via aiohttp # via yarl @@ -498,6 +517,7 @@ pluggy==1.5.0 # via pytest portalocker==2.10.1 # via deepeval + # via msal-extensions pre-commit==4.0.1 prometheus-client==0.20.0 # via flower @@ -565,6 +585,7 @@ pygments==2.18.0 # via rich pyjwt==2.8.0 # via fastapi-users + # via msal pymysql==1.1.1 pyparsing==3.1.2 # via httplib2 @@ -617,6 +638,7 @@ regex==2024.5.15 # via tiktoken # via transformers requests==2.32.3 + # via azure-core # via cohere # via datasets # via deepeval @@ -630,6 +652,7 @@ requests==2.32.3 # via langchain-community # via langsmith # via llama-index-core + # via msal # via requests-toolbelt # via tiktoken # via transformers @@ -653,6 +676,7 @@ shapely==2.0.6 shellingham==1.5.4 # via typer six==1.16.0 + # via azure-core # via markdownify # via python-dateutil sniffio==1.3.1 @@ -660,6 +684,7 @@ sniffio==1.3.1 # via anyio # via httpx # via openai +socksio==1.0.0 soupsieve==2.5 # via beautifulsoup4 sqlalchemy==2.0.30 @@ -716,6 +741,8 @@ types-requests==2.32.0.20240712 typing-extensions==4.12.2 # via alembic # via anthropic + # via azure-core + # via azure-identity # via cohere # via fastapi # via fastapi-pagination diff --git a/backend/requirements.lock b/backend/requirements.lock index fba5f946..82d7aef7 100644 --- a/backend/requirements.lock +++ b/backend/requirements.lock @@ -41,6 +41,10 @@ argon2-cffi-bindings==21.2.0 asyncmy==0.2.9 attrs==23.2.0 # via aiohttp +azure-core==1.32.0 + # via azure-identity +azure-identity==1.19.0 + # via llama-index-llms-azure-openai backoff==2.2.1 # via dspy-ai # via langfuse @@ -101,6 +105,8 @@ colorama==0.4.6 colorlog==6.8.2 # via optuna cryptography==42.0.8 + # via azure-identity + # via msal # via pyjwt dataclasses-json==0.6.7 # via langchain-community @@ -244,6 +250,7 @@ httpx==0.27.0 # via langsmith # via llama-cloud # via llama-index-core + # via llama-index-llms-azure-openai # via ollama # via openai httpx-oauth==0.14.1 @@ -319,6 +326,7 @@ llama-index-core==0.12.10.post1 # via llama-index # via llama-index-agent-openai # via llama-index-cli + # via llama-index-embeddings-azure-openai # via llama-index-embeddings-bedrock # via llama-index-embeddings-cohere # via llama-index-embeddings-jinaai @@ -326,6 +334,7 @@ llama-index-core==0.12.10.post1 # via llama-index-embeddings-openai # via llama-index-indices-managed-llama-cloud # via llama-index-llms-anthropic + # via llama-index-llms-azure-openai # via llama-index-llms-bedrock # via llama-index-llms-gemini # via llama-index-llms-ollama @@ -342,6 +351,7 @@ llama-index-core==0.12.10.post1 # via llama-index-readers-file # via llama-index-readers-llama-parse # via llama-parse +llama-index-embeddings-azure-openai==0.3.0 llama-index-embeddings-bedrock==0.4.0 llama-index-embeddings-cohere==0.4.0 llama-index-embeddings-jinaai==0.4.0 @@ -349,10 +359,13 @@ llama-index-embeddings-ollama==0.5.0 llama-index-embeddings-openai==0.3.1 # via llama-index # via llama-index-cli + # via llama-index-embeddings-azure-openai llama-index-indices-managed-llama-cloud==0.6.3 # via llama-index llama-index-llms-anthropic==0.6.3 # via llama-index-llms-bedrock +llama-index-llms-azure-openai==0.3.0 + # via llama-index-embeddings-azure-openai llama-index-llms-bedrock==0.3.3 llama-index-llms-gemini==0.4.2 llama-index-llms-ollama==0.5.0 @@ -360,6 +373,7 @@ llama-index-llms-openai==0.3.13 # via llama-index # via llama-index-agent-openai # via llama-index-cli + # via llama-index-llms-azure-openai # via llama-index-llms-openai-like # via llama-index-multi-modal-llms-openai # via llama-index-program-openai @@ -400,6 +414,11 @@ marshmallow==3.21.3 # via dataclasses-json mdurl==0.1.2 # via markdown-it-py +msal==1.31.1 + # via azure-identity + # via msal-extensions +msal-extensions==1.2.0 + # via azure-identity multidict==6.0.5 # via aiohttp # via yarl @@ -487,6 +506,7 @@ pluggy==1.5.0 # via pytest portalocker==2.10.1 # via deepeval + # via msal-extensions prometheus-client==0.20.0 # via flower prompt-toolkit==3.0.47 @@ -553,6 +573,7 @@ pygments==2.18.0 # via rich pyjwt==2.8.0 # via fastapi-users + # via msal pymysql==1.1.1 pyparsing==3.1.2 # via httplib2 @@ -604,6 +625,7 @@ regex==2024.5.15 # via tiktoken # via transformers requests==2.32.3 + # via azure-core # via cohere # via datasets # via deepeval @@ -617,6 +639,7 @@ requests==2.32.3 # via langchain-community # via langsmith # via llama-index-core + # via msal # via requests-toolbelt # via tiktoken # via transformers @@ -639,6 +662,7 @@ shapely==2.0.6 shellingham==1.5.4 # via typer six==1.16.0 + # via azure-core # via markdownify # via python-dateutil sniffio==1.3.1 @@ -646,6 +670,7 @@ sniffio==1.3.1 # via anyio # via httpx # via openai +socksio==1.0.0 soupsieve==2.5 # via beautifulsoup4 sqlalchemy==2.0.30 @@ -702,6 +727,8 @@ types-requests==2.32.0.20240712 typing-extensions==4.12.2 # via alembic # via anthropic + # via azure-core + # via azure-identity # via cohere # via fastapi # via fastapi-pagination