Skip to content

Commit

Permalink
refactor(models): enhance default model selection logic and add logging
Browse files Browse the repository at this point in the history
  • Loading branch information
yufeikang committed Jul 26, 2024
1 parent ade1df6 commit 302efb7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
18 changes: 15 additions & 3 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
logger = logging.getLogger(__name__)

MAX_TOKENS = os.environ.get("MAX_TOKENS", 1024)
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL")


class ChatBotAbc(abc.ABC):
Expand Down Expand Up @@ -574,30 +575,41 @@ def get_models(self):
MODELS_DICT = {}
MODELS_AVAILABLE = []
DEFAULT_MODELS = {}
AVAILABLE_DEFAULT_MODELS = []
if GeminiChatBot.is_start_available():
logger.info("Google API is available")
_bot = GeminiChatBot()
_models = _bot.get_models()
MODELS_AVAILABLE.extend(_models["models"])
DEFAULT_MODELS = _models["default_models"]
AVAILABLE_DEFAULT_MODELS.append(_models["default_models"])
MODELS_DICT.update({model["model"]: _bot for model in _models["models"]})
if OpenAIChatBot.is_start_available():
logger.info("OpenAI API is available")
_bot = OpenAIChatBot()
_models = _bot.get_models()
MODELS_AVAILABLE.extend(_models["models"])
DEFAULT_MODELS.update(_models["default_models"])
AVAILABLE_DEFAULT_MODELS.append(_models["default_models"])
MODELS_DICT.update({model["model"]: _bot for model in _models["models"]})
if AnthropicChatBot.is_start_available():
logger.info("Anthropic API is available")
_bot = AnthropicChatBot()
_models = _bot.get_models()
MODELS_AVAILABLE.extend(_models["models"])
DEFAULT_MODELS.update(_models["default_models"])
AVAILABLE_DEFAULT_MODELS.append(_models["default_models"])
MODELS_DICT.update({model["model"]: _bot for model in _models["models"]})


DEFAULT_MODELS = next(iter(AVAILABLE_DEFAULT_MODELS))
if DEFAULT_MODEL and DEFAULT_MODEL in MODELS_DICT:
DEFAULT_MODELS = MODELS_DICT[DEFAULT_MODEL]
logger.info(f"Using default model: {DEFAULT_MODEL}")


def get_bot(model_id):
if not model_id:
return next(iter(MODELS_DICT.values()))
logger.debug(f"Getting bot for model: {model_id}")
if model_id not in MODELS_DICT:
logger.error(f"Model not found: {model_id}")
return None
return MODELS_DICT.get(model_id)
2 changes: 1 addition & 1 deletion local_docker.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ docker run --rm -it \
-e LOG_LEVEL=DEBUG \
-e ALLOWED_USERS=$ALLOWED_USERS \
-v $PWD/app:/project/app \
raycast --entrypoint sh /entrypoint.sh --reload
raycast --entrypoint "sh /entrypoint.sh --reload"

0 comments on commit 302efb7

Please sign in to comment.