From 9fb92b2afb2bbb617db206255787d438cb1d8fa5 Mon Sep 17 00:00:00 2001 From: Mini256 Date: Wed, 8 Jan 2025 13:27:02 +0800 Subject: [PATCH] fix: fix get embed model bug --- backend/app/rag/chat_config.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/backend/app/rag/chat_config.py b/backend/app/rag/chat_config.py index e5f858b8..ae3bf1c2 100644 --- a/backend/app/rag/chat_config.py +++ b/backend/app/rag/chat_config.py @@ -5,7 +5,6 @@ import dspy from llama_index.llms.bedrock.utils import BEDROCK_FOUNDATION_LLMS from pydantic import BaseModel -from llama_index.llms.openai.utils import DEFAULT_OPENAI_API_BASE from llama_index.llms.openai import OpenAI from llama_index.llms.openai_like import OpenAILike from llama_index.llms.gemini import Gemini @@ -219,10 +218,8 @@ def get_llm( ) -> LLM: match provider: case LLMProvider.OPENAI: - api_base = config.pop("api_base", DEFAULT_OPENAI_API_BASE) return OpenAI( model=model, - api_base=api_base, api_key=credentials, **config, ) @@ -247,6 +244,7 @@ def get_llm( aws_secret_access_key=secret_access_key, region_name=region_name, context_size=context_size, + **config, ) # Note: Because llama index Bedrock class doesn't set up these values to the corresponding # attributes in its constructor function, we pass the values again via setter to pass them to @@ -315,10 +313,8 @@ def get_embed_model( ) -> BaseEmbedding: match provider: case EmbeddingProvider.OPENAI: - api_base = config.pop("api_base", DEFAULT_OPENAI_API_BASE) return OpenAIEmbedding( model=model, - api_base=api_base, api_key=credentials, **config, ) @@ -332,6 +328,7 @@ def get_embed_model( return CohereEmbedding( model_name=model, cohere_api_key=credentials, + **config, ) case EmbeddingProvider.BEDROCK: return BedrockEmbedding( @@ -359,10 +356,8 @@ def get_embed_model( **config, ) case EmbeddingProvider.OPENAI_LIKE: - api_base = config.pop("api_base", "https://open.bigmodel.cn/api/paas/v4") return OpenAILikeEmbedding( model=model, - api_base=api_base, api_key=credentials, **config, ) @@ -408,12 +403,14 @@ def get_reranker_model( model=model, top_n=top_n, api_key=credentials, + **config, ) case RerankerProvider.COHERE: return CohereRerank( model=model, top_n=top_n, api_key=credentials, + **config, ) case RerankerProvider.BAISHENG: return BaishengRerank(