Skip to content

Commit

Permalink
feat: (models) add support for OpenAI API compatible providers and dy…
Browse files Browse the repository at this point in the history
…namic model fetching
  • Loading branch information
yufeikang committed Dec 5, 2024
1 parent ebe7155 commit cff9e85
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 87 deletions.
14 changes: 13 additions & 1 deletion README.ja.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,23 @@ epts-howmitmproxyworks/)をご参照ください。
| モデルプロバイダー | モデル | テスト状況 | 環境変数 |
| --- | --- | --- | --- |
| `openai` | gpt-3.5-turbo,gpt-4-turbo, gpt-4o | テスト済み | `OPENAI_API_KEY` |
| `openai` | from API | テスト済み | `OPENAI_API_KEY` |
| `azure openai` | 同上 | テスト済み | `AZURE_OPENAI_API_KEY`, `AZURE_DEPLOYMENT_ID`, `OPENAI_AZURE_ENDPOINT` |
| `google` | gemini-pro,gemini-1.5-pro | テスト済み | `GOOGLE_API_KEY` |
| `anthropic` | claude-3-sonnet, claude-3-opus, claude-3-5-opus | テスト済み | `ANTHROPIC_API_KEY` |

#### サポートされている OpenAI API 互換プロバイダー

##### [Ollama](https://ollama.com/) の例

環境変数を追加

- `OPENAI_PROVIDER=ollama`
- `OPENAI_BASE_URL=http://localhost:11434/v1`
- `OPENAI_API_KEY=ollama` # 必須ですが、使用されません

モデルは `http://localhost:11434/v1/models` から取得されます。

### Ai チャット

![AI チャット](./assert/img/chat.jpeg)
Expand Down
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,23 @@ certificate. For more details on man-in-the-middle proxies, you can refer to (<h
| Model Provider | Models | Test Status | Environment Variables | Image generation |
| --- | --- | --- | --- | --- |
| `openai` | gpt-3.5-turbo, gpt-4-turbo, gpt-4o | Tested | `OPENAI_API_KEY` | Supported |
| `openai` | **from api** | Tested | `OPENAI_API_KEY` | Supported |
| `azure openai` | Same as above | Tested | `AZURE_OPENAI_API_KEY`, `AZURE_DEPLOYMENT_ID`, `OPENAI_AZURE_ENDPOINT` | Supported |
| `google` | gemini-pro, gemini-1.5-pro | Tested | `GOOGLE_API_KEY` | x |
| `anthropic` | claude-3-sonnet, claude-3-opus, claude-3-5-opus | Tested | `ANTHROPIC_API_KEY` | x |

#### support openai api compatible providers

##### Example for [Ollama](https://ollama.com/)

add environment variables

- `OPENAI_PROVIDER=ollama`
- `OPENAI_BASE_URL=http://localhost:11434/v1`
- `OPENAI_API_KEY=ollama` # required, but unused

models will be fetched from `http://localhost:11434/v1/models`

### Ai chat

![ai chat](./assert/img/chat.jpeg)
Expand Down
14 changes: 13 additions & 1 deletion README.zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,23 @@
| 模型provider | 模型 | 测试状态 | 环境变量 | 图片生成 |
| --- | --- | --- | --- | --- |
| `openai` | gpt-3.5-turbo,gpt-4-turbo, gpt-4o | 已测试 | `OPENAI_API_KEY` | 已支持 |
| `openai` | **from api** | 已测试 | `OPENAI_API_KEY` | 已支持 |
| `azure openai` | 同上 | 已测试 | `AZURE_OPENAI_API_KEY`, `AZURE_DEPLOYMENT_ID`, `OPENAI_AZURE_ENDPOINT` | 已支持 |
| `google` | gemini-pro,gemini-1.5-pro | 已测试 | `GOOGLE_API_KEY` | x |
| `anthropic` | claude-3-sonnet, claude-3-opus, claude-3-5-opus | 已测试 | `ANTHROPIC_API_KEY` | x |

#### 支持openai api兼容的provider

##### [Ollama](https://ollama.com/) 示例

添加环境变量

- `OPENAI_PROVIDER=ollama`
- `OPENAI_BASE_URL=http://localhost:11434/v1`
- `OPENAI_API_KEY=ollama` # required, but unused

模型将从`http://localhost:11434/v1/models`获取

### Ai chat

![ai chat](./assert/img/chat.jpeg)
Expand Down
7 changes: 6 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fastapi.responses import StreamingResponse

from app.middleware import AuthMiddleware
from app.models import DEFAULT_MODELS, MODELS_AVAILABLE, get_bot
from app.models import DEFAULT_MODELS, MODELS_AVAILABLE, get_bot, init_models
from app.sync import router as sync_router
from app.utils import (
ProxyRequest,
Expand All @@ -36,6 +36,11 @@ async def shutdown_event():
await http_client.aclose()


@app.on_event("startup")
async def on_startup():
    # Populate the model registry once at startup. init_models() is async
    # because some providers (OpenAI-compatible endpoints) fetch their model
    # list from the remote /models API.
    await init_models()


app.include_router(sync_router, prefix="/api/v1/me")


Expand Down
172 changes: 89 additions & 83 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import os
from functools import cache

import anthropic
import google.generativeai as genai
Expand Down Expand Up @@ -30,7 +31,7 @@ async def chat_completions(self, raycast_data: dict):
async def translate_completions(self, raycast_data: dict):
pass

def get_models(self):
async def get_models(self):
pass


Expand All @@ -57,7 +58,14 @@ def _get_model_extra_info(name=""):
"image_generation": {
"model": "dall-e-3"
},
"system_message": {
"supported": true
},
"temperature": {
"supported": true
}
},
"""
ext = {
"description": "model description",
Expand All @@ -83,6 +91,22 @@ def _get_model_extra_info(name=""):
"image_generation": {
"model": "dall-e-3",
},
"system_message": {
"supported": True,
},
"temperature": {
"supported": True,
},
}
# o1 models don't support system_message and temperature
if "o1" in name:
ext["abilities"] = {
"system_message": {
"supported": False,
},
"temperature": {
"supported": False,
},
}
return ext

Expand All @@ -97,6 +121,9 @@ def is_start_available(cls):

def __init__(self) -> None:
super().__init__()
self.provider = os.environ.get(
"OPENAI_PROVIDER", "openai"
) # for openai api compatible provider
openai.api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get(
"AZURE_OPENAI_API_KEY"
)
Expand Down Expand Up @@ -306,60 +333,36 @@ async def __chat(self, messages, model, temperature, **kwargs):
return
yield chunk.choices[0], None

async def get_models(self):
    """Return the default model plus chat models fetched from the provider.

    Queries the OpenAI-compatible ``/models`` endpoint (e.g. Ollama's
    ``http://localhost:11434/v1/models``) and exposes only GPT-4* / o1*
    models.

    The result is memoized per instance. NOTE: the previous
    ``@functools.cache`` decorator is wrong on an ``async`` method — it
    caches the *coroutine object*, which can only be awaited once, so every
    call after the first raises ``RuntimeError: cannot reuse already awaited
    coroutine``; it also keys on ``self`` and keeps the instance alive for
    the process lifetime (ruff B019).

    Returns:
        dict with keys ``default_models`` (dict) and ``models`` (list of
        model-description dicts).
    """
    cached = getattr(self, "_models_cache", None)
    if cached is not None:
        return cached
    default_models = _get_default_model_dict("openai-gpt-4o-mini")
    # Each entry returned by the API looks like:
    # {"id": "gpt-4o", "created": 1715367049, "object": "model", "owned_by": "system"}
    openai_models = (await self.openai_client.models.list()).data
    models = []
    for model in openai_models:
        if not model.id.startswith("gpt-4") and not model.id.startswith("o1"):
            # Skip non-chat models (embeddings, tts, dall-e, ...).
            continue
        model_id = f"{self.provider}-{model.id}"
        models.append(
            {
                "id": model_id,
                "model": model.id,
                "name": f"{self.provider} {model.id}",
                "provider": "openai",
                "provider_name": self.provider,
                "provider_brand": self.provider,
                # Fixed context size for all fetched models — presumably a
                # placeholder; TODO confirm per-model context windows.
                "context": 16,
                **_get_model_extra_info(model.id),
            }
        )
    self._models_cache = {"default_models": default_models, "models": models}
    return self._models_cache


Expand Down Expand Up @@ -439,7 +442,7 @@ def __generate_content(
),
)

def get_models(self):
async def get_models(self):
default_models = _get_default_model_dict("gemini-pro")
models = [
{
Expand Down Expand Up @@ -544,7 +547,7 @@ async def translate_completions(self, raycast_data: dict):
logger.error(f"Anthropic translation error: {e}")
yield f"Error: {str(e)}"

def get_models(self):
async def get_models(self):
default_models = _get_default_model_dict("claude-3-5-sonnet-20240620")
models = [
{
Expand Down Expand Up @@ -585,33 +588,36 @@ def get_models(self):
MODELS_AVAILABLE = []
DEFAULT_MODELS = {}
AVAILABLE_DEFAULT_MODELS = []
if GeminiChatBot.is_start_available():
logger.info("Google API is available")
_bot = GeminiChatBot()
_models = _bot.get_models()
MODELS_AVAILABLE.extend(_models["models"])
AVAILABLE_DEFAULT_MODELS.append(_models["default_models"])
MODELS_DICT.update({model["model"]: _bot for model in _models["models"]})
if OpenAIChatBot.is_start_available():
logger.info("OpenAI API is available")
_bot = OpenAIChatBot()
_models = _bot.get_models()
MODELS_AVAILABLE.extend(_models["models"])
AVAILABLE_DEFAULT_MODELS.append(_models["default_models"])
MODELS_DICT.update({model["model"]: _bot for model in _models["models"]})
if AnthropicChatBot.is_start_available():
logger.info("Anthropic API is available")
_bot = AnthropicChatBot()
_models = _bot.get_models()
MODELS_AVAILABLE.extend(_models["models"])
AVAILABLE_DEFAULT_MODELS.append(_models["default_models"])
MODELS_DICT.update({model["model"]: _bot for model in _models["models"]})


DEFAULT_MODELS = next(iter(AVAILABLE_DEFAULT_MODELS))
if DEFAULT_MODEL and DEFAULT_MODEL in MODELS_DICT:
DEFAULT_MODELS = MODELS_DICT[DEFAULT_MODEL]
logger.info(f"Using default model: {DEFAULT_MODEL}")


async def init_models():
    """Discover configured providers and populate the module-level registry.

    For each provider whose credentials are present (``is_start_available``),
    instantiate its bot, fetch its model list, and record:
      - MODELS_AVAILABLE: all model-description dicts,
      - AVAILABLE_DEFAULT_MODELS: each provider's default-model dict,
      - MODELS_DICT: model name -> bot instance.
    """
    global MODELS_DICT, MODELS_AVAILABLE, DEFAULT_MODELS, AVAILABLE_DEFAULT_MODELS
    # One loop instead of three copy-pasted per-provider blocks.
    for label, bot_cls in (
        ("Google", GeminiChatBot),
        ("OpenAI", OpenAIChatBot),
        ("Anthropic", AnthropicChatBot),
    ):
        if not bot_cls.is_start_available():
            continue
        logger.info(f"{label} API is available")
        _bot = bot_cls()
        _models = await _bot.get_models()
        MODELS_AVAILABLE.extend(_models["models"])
        AVAILABLE_DEFAULT_MODELS.append(_models["default_models"])
        MODELS_DICT.update({model["model"]: _bot for model in _models["models"]})

    # Guard: next(iter(...)) raised StopIteration when no provider env vars
    # were configured; fail with a clear log message instead.
    if AVAILABLE_DEFAULT_MODELS:
        DEFAULT_MODELS = AVAILABLE_DEFAULT_MODELS[0]
    else:
        logger.warning("No model providers are configured; model list is empty")
    if DEFAULT_MODEL and DEFAULT_MODEL in MODELS_DICT:
        # NOTE(review): MODELS_DICT maps model name -> bot *instance*, while
        # DEFAULT_MODELS otherwise holds a default-model dict — this looks
        # like it stores the wrong type; confirm against consumers.
        DEFAULT_MODELS = MODELS_DICT[DEFAULT_MODEL]
        logger.info(f"Using default model: {DEFAULT_MODEL}")


def get_bot(model_id):
Expand Down

0 comments on commit cff9e85

Please sign in to comment.