Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: azure llm and embedding models #592

Merged
merged 2 commits
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions backend/app/rag/embeddings/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class EmbeddingProvider(str, enum.Enum):
GITEEAI = "giteeai"
LOCAL = "local"
OPENAI_LIKE = "openai_like"
AZURE_OPENAI = "azure_openai"


class EmbeddingProviderOption(BaseModel):
Expand Down Expand Up @@ -123,6 +124,22 @@ class EmbeddingProviderOption(BaseModel):
credentials_type="str",
default_credentials="****",
),
EmbeddingProviderOption(
provider=EmbeddingProvider.AZURE_OPENAI,
provider_display_name="Azure OpenAI",
provider_description="Azure OpenAI is a cloud-based AI service that provides a suite of AI models and tools for developers to build intelligent applications.",
provider_url="https://azure.microsoft.com/en-us/products/ai-services/openai-service",
default_embedding_model="text-embedding-3-small",
embedding_model_description="Before using this option, you need to deploy an Azure OpenAI API and model, see https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource.",
default_config={
"azure_endpoint": "https://<your-resource-name>.openai.azure.com/",
"api_version": "<your-api-version>"
},
credentials_display_name="Azure OpenAI API Key",
credentials_description="The API key of Azure OpenAI",
credentials_type="str",
default_credentials="****",
),
EmbeddingProviderOption(
provider=EmbeddingProvider.LOCAL,
provider_display_name="Local Embedding",
Expand Down
8 changes: 8 additions & 0 deletions backend/app/rag/embeddings/resolver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Optional

from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from sqlmodel import Session

from llama_index.core.base.embeddings.base import BaseEmbedding
Expand Down Expand Up @@ -65,6 +67,12 @@ def get_embed_model(
api_key=credentials,
**config,
)
case EmbeddingProvider.AZURE_OPENAI:
return AzureOpenAIEmbedding(
model=model,
api_key=credentials,
**config,
)
case EmbeddingProvider.OPENAI_LIKE:
return OpenAILikeEmbedding(
model=model,
Expand Down
19 changes: 19 additions & 0 deletions backend/app/rag/llms/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class LLMProvider(str, enum.Enum):
BEDROCK = "bedrock"
OLLAMA = "ollama"
GITEEAI = "giteeai"
AZURE_OPENAI = "azure_openai"


class LLMProviderOption(BaseModel):
Expand Down Expand Up @@ -147,4 +148,22 @@ class LLMProviderOption(BaseModel):
"aws_region_name": "us-west-2",
},
),
LLMProviderOption(
provider=LLMProvider.AZURE_OPENAI,
provider_display_name="Azure OpenAI",
provider_description="Azure OpenAI is a cloud-based AI service that provides access to OpenAI's advanced language models.",
provider_url="https://azure.microsoft.com/en-us/products/ai-services/openai-service",
default_llm_model="gpt-4o",
llm_model_description="",
config_description="Refer to this document https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart to have more information about the Azure OpenAI API.",
default_config={
"azure_endpoint": "https://<your-resource-name>.openai.azure.com/",
"api_version": "<your-api-version>",
"engine": "<your-deployment-name>",
},
credentials_display_name="Azure OpenAI API Key",
credentials_description="The API key of Azure OpenAI",
credentials_type="str",
default_credentials="****",
),
]
7 changes: 7 additions & 0 deletions backend/app/rag/llms/resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Optional
from llama_index.core.llms.llm import LLM
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_like import OpenAILike
from llama_index.llms.gemini import Gemini
Expand Down Expand Up @@ -86,6 +87,12 @@ def get_llm(
api_key=credentials,
**config,
)
case LLMProvider.AZURE_OPENAI:
return AzureOpenAI(
model=model,
api_key=credentials,
**config,
)
case _:
raise ValueError(f"Got unknown LLM provider: {provider}")

Expand Down
11 changes: 11 additions & 0 deletions backend/app/utils/dspy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import hashlib
from typing import Any, Literal

from llama_index.llms.azure_openai import AzureOpenAI

import dspy
import requests
from dsp.modules.lm import LM
Expand Down Expand Up @@ -83,6 +85,15 @@ def get_dspy_lm_by_llama_llm(llama_llm: BaseLLM) -> dspy.LM:
max_tokens=llama_llm.context_window,
num_ctx=llama_llm.context_window,
)
elif type(llama_llm) is AzureOpenAI:
return dspy.AzureOpenAI(
model=llama_llm.model,
max_tokens=llama_llm.max_tokens or 4096,
api_key=llama_llm.api_key,
api_base=enforce_trailing_slash(llama_llm.azure_endpoint),
api_version=llama_llm.api_version,
deployment_id=llama_llm.engine,
)
else:
raise ValueError(f"Got unknown LLM provider: {llama_llm.__class__.__name__}")

Expand Down
3 changes: 3 additions & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ dependencies = [
"llama-index-postprocessor-xinference-rerank>=0.2.0",
"llama-index-postprocessor-bedrock-rerank>=0.3.0",
"llama-index-llms-vertex>=0.4.2",
"socksio>=1.0.0",
"llama-index-llms-azure-openai>=0.3.0",
"llama-index-embeddings-azure-openai>=0.3.0",
]
readme = "README.md"
requires-python = ">= 3.8"
Expand Down
27 changes: 27 additions & 0 deletions backend/requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ argon2-cffi-bindings==21.2.0
asyncmy==0.2.9
attrs==23.2.0
# via aiohttp
azure-core==1.32.0
# via azure-identity
azure-identity==1.19.0
# via llama-index-llms-azure-openai
backoff==2.2.1
# via dspy-ai
# via langfuse
Expand Down Expand Up @@ -103,6 +107,8 @@ colorama==0.4.6
colorlog==6.8.2
# via optuna
cryptography==42.0.8
# via azure-identity
# via msal
# via pyjwt
dataclasses-json==0.6.7
# via langchain-community
Expand Down Expand Up @@ -249,6 +255,7 @@ httpx==0.27.0
# via langsmith
# via llama-cloud
# via llama-index-core
# via llama-index-llms-azure-openai
# via ollama
# via openai
httpx-oauth==0.14.1
Expand Down Expand Up @@ -326,13 +333,15 @@ llama-index-core==0.12.10.post1
# via llama-index
# via llama-index-agent-openai
# via llama-index-cli
# via llama-index-embeddings-azure-openai
# via llama-index-embeddings-bedrock
# via llama-index-embeddings-cohere
# via llama-index-embeddings-jinaai
# via llama-index-embeddings-ollama
# via llama-index-embeddings-openai
# via llama-index-indices-managed-llama-cloud
# via llama-index-llms-anthropic
# via llama-index-llms-azure-openai
# via llama-index-llms-bedrock
# via llama-index-llms-gemini
# via llama-index-llms-ollama
Expand All @@ -349,24 +358,29 @@ llama-index-core==0.12.10.post1
# via llama-index-readers-file
# via llama-index-readers-llama-parse
# via llama-parse
llama-index-embeddings-azure-openai==0.3.0
llama-index-embeddings-bedrock==0.4.0
llama-index-embeddings-cohere==0.4.0
llama-index-embeddings-jinaai==0.4.0
llama-index-embeddings-ollama==0.5.0
llama-index-embeddings-openai==0.3.1
# via llama-index
# via llama-index-cli
# via llama-index-embeddings-azure-openai
llama-index-indices-managed-llama-cloud==0.6.3
# via llama-index
llama-index-llms-anthropic==0.6.3
# via llama-index-llms-bedrock
llama-index-llms-azure-openai==0.3.0
# via llama-index-embeddings-azure-openai
llama-index-llms-bedrock==0.3.3
llama-index-llms-gemini==0.4.2
llama-index-llms-ollama==0.5.0
llama-index-llms-openai==0.3.13
# via llama-index
# via llama-index-agent-openai
# via llama-index-cli
# via llama-index-llms-azure-openai
# via llama-index-llms-openai-like
# via llama-index-multi-modal-llms-openai
# via llama-index-program-openai
Expand Down Expand Up @@ -407,6 +421,11 @@ marshmallow==3.21.3
# via dataclasses-json
mdurl==0.1.2
# via markdown-it-py
msal==1.31.1
# via azure-identity
# via msal-extensions
msal-extensions==1.2.0
# via azure-identity
multidict==6.0.5
# via aiohttp
# via yarl
Expand Down Expand Up @@ -498,6 +517,7 @@ pluggy==1.5.0
# via pytest
portalocker==2.10.1
# via deepeval
# via msal-extensions
pre-commit==4.0.1
prometheus-client==0.20.0
# via flower
Expand Down Expand Up @@ -565,6 +585,7 @@ pygments==2.18.0
# via rich
pyjwt==2.8.0
# via fastapi-users
# via msal
pymysql==1.1.1
pyparsing==3.1.2
# via httplib2
Expand Down Expand Up @@ -617,6 +638,7 @@ regex==2024.5.15
# via tiktoken
# via transformers
requests==2.32.3
# via azure-core
# via cohere
# via datasets
# via deepeval
Expand All @@ -630,6 +652,7 @@ requests==2.32.3
# via langchain-community
# via langsmith
# via llama-index-core
# via msal
# via requests-toolbelt
# via tiktoken
# via transformers
Expand All @@ -653,13 +676,15 @@ shapely==2.0.6
shellingham==1.5.4
# via typer
six==1.16.0
# via azure-core
# via markdownify
# via python-dateutil
sniffio==1.3.1
# via anthropic
# via anyio
# via httpx
# via openai
socksio==1.0.0
soupsieve==2.5
# via beautifulsoup4
sqlalchemy==2.0.30
Expand Down Expand Up @@ -716,6 +741,8 @@ types-requests==2.32.0.20240712
typing-extensions==4.12.2
# via alembic
# via anthropic
# via azure-core
# via azure-identity
# via cohere
# via fastapi
# via fastapi-pagination
Expand Down
Loading
Loading