Skip to content

Commit

Permalink
fix: fix get embed model bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Mini256 committed Jan 8, 2025
1 parent a0cf8ac commit 9fb92b2
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions backend/app/rag/chat_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import dspy
from llama_index.llms.bedrock.utils import BEDROCK_FOUNDATION_LLMS
from pydantic import BaseModel
from llama_index.llms.openai.utils import DEFAULT_OPENAI_API_BASE
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_like import OpenAILike
from llama_index.llms.gemini import Gemini
Expand Down Expand Up @@ -219,10 +218,8 @@ def get_llm(
) -> LLM:
match provider:
case LLMProvider.OPENAI:
api_base = config.pop("api_base", DEFAULT_OPENAI_API_BASE)
return OpenAI(
model=model,
api_base=api_base,
api_key=credentials,
**config,
)
Expand All @@ -247,6 +244,7 @@ def get_llm(
aws_secret_access_key=secret_access_key,
region_name=region_name,
context_size=context_size,
**config,
)
# Note: Because llama index Bedrock class doesn't set up these values to the corresponding
# attributes in its constructor function, we pass the values again via setter to pass them to
Expand Down Expand Up @@ -315,10 +313,8 @@ def get_embed_model(
) -> BaseEmbedding:
match provider:
case EmbeddingProvider.OPENAI:
api_base = config.pop("api_base", DEFAULT_OPENAI_API_BASE)
return OpenAIEmbedding(
model=model,
api_base=api_base,
api_key=credentials,
**config,
)
Expand All @@ -332,6 +328,7 @@ def get_embed_model(
return CohereEmbedding(
model_name=model,
cohere_api_key=credentials,
**config,
)
case EmbeddingProvider.BEDROCK:
return BedrockEmbedding(
Expand Down Expand Up @@ -359,10 +356,8 @@ def get_embed_model(
**config,
)
case EmbeddingProvider.OPENAI_LIKE:
api_base = config.pop("api_base", "https://open.bigmodel.cn/api/paas/v4")
return OpenAILikeEmbedding(
model=model,
api_base=api_base,
api_key=credentials,
**config,
)
Expand Down Expand Up @@ -408,12 +403,14 @@ def get_reranker_model(
model=model,
top_n=top_n,
api_key=credentials,
**config,
)
case RerankerProvider.COHERE:
return CohereRerank(
model=model,
top_n=top_n,
api_key=credentials,
**config,
)
case RerankerProvider.BAISHENG:
return BaishengRerank(
Expand Down

0 comments on commit 9fb92b2

Please sign in to comment.