-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
100 lines (79 loc) · 2.98 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import ujson
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, Response
import httpx
from loguru import logger
from urllib.parse import urlparse
from fastapi.middleware.cors import CORSMiddleware
from utilities.openai_tool import openai_stream
app = FastAPI()

# Wide-open CORS policy: every origin, method and request header, credentials allowed.
# NOTE(review): proxy_middleware is registered later via @app.middleware and returns
# without calling call_next. In Starlette the most recently added middleware is the
# outermost one, so this CORS layer may never see a request — which would explain the
# manual Access-Control-Allow-Origin headers set inside proxy_middleware. Confirm.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # any origin
    allow_methods=["*"],  # any HTTP method
    allow_headers=["*"],  # any request header
    allow_credentials=True,
)

# Keep only scheme://netloc of OPENAI_BASE_URL; any path in it is discarded and the
# incoming request path is appended per request instead. A missing OPENAI_BASE_URL
# raises KeyError at import time.
parsed_url = urlparse(os.environ["OPENAI_BASE_URL"])
TARGET_URL = parsed_url.scheme + "://" + parsed_url.netloc

# Single shared httpx.AsyncClient for the whole process; created by the startup hook and
# closed by the shutdown hook.
client: None | httpx.AsyncClient = None
@app.on_event("startup")
async def startup_event():
    """Create the process-wide httpx.AsyncClient used by proxy_middleware.

    The client is stored in the module-level ``client`` global and is closed
    again by shutdown_event.
    """
    global client
    # Without an explicit timeout httpx limits every call to 5 s, which would
    # abort long-running upstream requests (e.g. image generation or audio
    # transcription) forwarded by the pass-through path. Use the same default
    # as the OpenAI Python SDK instead: 600 s overall, 5 s to connect.
    client = httpx.AsyncClient(timeout=httpx.Timeout(600.0, connect=5.0))
    logger.info("HTTP client initialized")
@app.on_event("shutdown")
async def shutdown_event():
    """Close the shared httpx.AsyncClient created by startup_event.

    If no client exists (startup never ran or failed before assigning it),
    this is a no-op instead of failing with AttributeError on ``None``.
    The global is reset afterwards, so a repeated shutdown is also harmless.
    """
    global client
    if client is None:
        return
    await client.aclose()
    client = None
    logger.info("HTTP client closed")
# Upstream response headers that are not copied onto the proxied reply. httpx
# decodes compressed bodies when the response is read, so the upstream
# content-encoding / content-length no longer describe the bytes returned here.
# The remaining entries are hop-by-hop headers that only apply to the upstream
# connection. When content-length is absent, Starlette's Response derives it
# from the body.
_STRIPPED_RESPONSE_HEADERS = frozenset(
    {
        "connection",
        "content-encoding",
        "content-length",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


async def _read_json_object(request: Request) -> dict | None:
    """Return the request body decoded as a JSON object, or None.

    None means the body is empty, is not valid JSON (e.g. a multipart upload),
    or decodes to a non-object value such as a list. Callers can branch on the
    result instead of crashing on a decode error.
    """
    if not await request.body():
        return None
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        return None
    return data if isinstance(data, dict) else None


@app.middleware("http")
async def proxy_middleware(request: Request, call_next):
    """Forward every incoming request to TARGET_URL and relay the reply.

    Three paths:
      * ``POST /v1/chat/completions`` is handed to ``openai_stream`` with
        channel "openai". The result is streamed as SSE when the JSON body
        sets ``stream``; otherwise it is serialized as JSON.
      * Any other request whose body is a JSON object, and whose body or query
        string sets ``stream``, is streamed through ``openai_stream`` with
        channel "httpx". Requests without a JSON-object body never take this
        path.
      * Everything else is forwarded verbatim through the shared httpx client,
        and the buffered upstream reply is returned.

    ``call_next`` is never invoked, so no downstream route handler runs.

    Args:
        request: The incoming request.
        call_next: Starlette's downstream dispatcher (unused).

    Returns:
        A StreamingResponse or Response. The response is a 502 when the
        upstream cannot be reached on the pass-through path.
    """
    path = request.url.path
    method = request.method

    # Upstream URL: scheme://netloc from OPENAI_BASE_URL plus the incoming path.
    # The query string is forwarded separately via ``params``.
    url = f"{TARGET_URL}{path}"
    logger.debug(f"target url: {url}")

    # Forward the client's headers, but point Host and X-Forwarded-Host at the
    # upstream host.
    headers = dict(request.headers)
    host = urlparse(url).netloc
    headers["host"] = headers["x-forwarded-host"] = host

    if method == "POST" and path == "/v1/chat/completions":
        data = await request.json()
        logger.debug(f"{data=}")
        stream_gen = await openai_stream(data=data, path=path, channel="openai")
        if data.get("stream", False):
            resp = StreamingResponse(stream_gen, media_type="text/event-stream")
        else:
            resp = Response(ujson.dumps(stream_gen, ensure_ascii=False), status_code=200,
                            headers={"content-Type": "application/json"})
        resp.headers["Access-Control-Allow-Origin"] = "*"
        return resp

    data = await _read_json_object(request)
    if data is not None and (data.get("stream") or request.query_params.get("stream")):
        # Generic streaming path; the original author noted it is rarely needed.
        logger.debug(f"{data=}")
        # openai_stream returns an awaitable (the chat branch above awaits it),
        # so it must be awaited before its result can be streamed.
        stream_gen = await openai_stream(data=data, method=method, path=path, channel="httpx")
        resp = StreamingResponse(stream_gen, media_type="text/event-stream")
        resp.headers["Access-Control-Allow-Origin"] = "*"
        return resp

    # Plain pass-through: forward the raw body and buffer the upstream reply.
    body = await request.body()
    try:
        response = await client.request(
            method=method,
            url=url,
            headers=headers,
            content=body,
            params=request.query_params,
        )
    except httpx.RequestError as exc:
        logger.error(f"upstream request to {url} failed: {exc!r}")
        return Response("Bad Gateway", status_code=502)

    relay_headers = {
        name: value
        for name, value in response.headers.items()
        if name.lower() not in _STRIPPED_RESPONSE_HEADERS
    }
    return Response(await response.aread(), status_code=response.status_code, headers=relay_headers)
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this module is executed directly.
    import uvicorn

    # Serve the proxy on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")