Skip to content

Commit

Permalink
refactor: (sync) refactor sync endpoints with improvements in validat…
Browse files Browse the repository at this point in the history
…ion and storage path usage
  • Loading branch information
yufeikang committed May 27, 2024
1 parent 87a7c9f commit d59309e
Showing 1 changed file with 32 additions and 51 deletions.
83 changes: 32 additions & 51 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@
import json
import logging
import os
from datetime import UTC, datetime
from pathlib import Path

import google.generativeai as genai
import httpx
import openai
from fastapi import FastAPI, Request, Response
from fastapi import FastAPI, Query, Request, Response
from fastapi.responses import StreamingResponse
from google.generativeai import GenerativeModel

from app.utils import (
ProxyRequest,
pass_through_request,
json_dumps,
pass_through_request,
process_custom_mapping,
)
from datetime import datetime

logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))

Expand All @@ -37,6 +37,10 @@

MAX_TOKENS = os.environ.get("MAX_TOKENS", 1024)

SYNC_DIR = os.environ.get("SYNC_DIR", "./sync")
if not os.path.exists(SYNC_DIR):
os.makedirs(SYNC_DIR)


def _get_default_model_dict(model_name: str):
return {
Expand Down Expand Up @@ -491,13 +495,16 @@ def check_auth(request: Request):
return False
return True


def get_current_user_email(request: Request):
bearer_token = request.headers.get("Authorization", "").split(" ")[1]
return USER_SESSION.get(bearer_token)


def get_current_utc_time():
# 获取当前时间
current_time = datetime.utcnow()
# 获取当前UTC时间并转换为ISO 8601格式,末尾手动添加'Z'表示UTC时间
return datetime.now(UTC).isoformat(timespec="milliseconds") + "Z"

# 转换为ISO 8601格式,末尾添加'Z'表示UTC时间
iso_format_time = current_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
return iso_format_time

@app.on_event("shutdown")
async def shutdown_event():
Expand Down Expand Up @@ -597,28 +604,14 @@ async def proxy_models(request: Request):
headers=response.headers,
)


@app.api_route("/api/v1/me/sync", methods=["GET"])
async def proxy_sync_get(request: Request, after: str = Query(None)):
bearer_token = request.headers.get("Authorization", "").split(" ")[1]
email = USER_SESSION[bearer_token]
try:
raycast_data = await request.json()
if not check_auth(request):
return Response(status_code=401)
except json.decoder.JSONDecodeError:
if bearer_token not in USER_SESSION:
logger.warn(f"User not in session: {bearer_token}")
return False
if email not in ALLOWED_USERS:
logger.debug(f"Allowed users: {ALLOWED_USERS}")
logger.warn(f"User not allowed: {email}")
return Response(status_code=401)



if not os.path.exists("./sync"):
os.makedirs("./sync")
if os.path.exists(f"./sync/{email}.json"):
if not check_auth(request):
return Response(status_code=401)
email = get_current_user_email(request)
target = f"{SYNC_DIR}/{email}.json"
if os.path.exists(target):
with open(f"./sync/{email}.json", "r") as f:
data = json.loads(f.read())

Expand All @@ -640,27 +633,15 @@ async def proxy_sync_get(request: Request, after: str = Query(None)):

@app.api_route("/api/v1/me/sync", methods=["PUT"])
async def proxy_sync_put(request: Request):
global email
bearer_token = request.headers.get("Authorization", "").split(" ")[1]
email = USER_SESSION[bearer_token]
try:
raycast_data = await request.json()
if not check_auth(request):
return Response(status_code=401)
except json.decoder.JSONDecodeError:
if bearer_token not in USER_SESSION:
logger.warn(f"User not in session: {bearer_token}")
return False
if email not in ALLOWED_USERS:
logger.debug(f"Allowed users: {ALLOWED_USERS}")
logger.warn(f"User not allowed: {email}")
return Response(status_code=401)

# 检查是否存在 ./sync 目录
if not os.path.exists("./sync"):
os.makedirs("./sync")
if not check_auth(request):
return Response(status_code=401)
email = get_current_user_email(request)

data = await request.body()
if not os.path.exists(f"./sync/{email}.json"):

target = f"{SYNC_DIR}/{email}.json"

if not os.path.exists(target):
# 移除 request.body 中的 deleted 字段
data = json.loads(data)
data["deleted"] = []
Expand All @@ -670,11 +651,11 @@ async def proxy_sync_put(request: Request):
item["created_at"] = item["client_updated_at"]
item["updated_at"] = updated_time
data = json.dumps(data)
with open(f"./sync/{email}.json", "w") as f:
with open(target, "w") as f:
f.write(data)

else:
with open(f"./sync/{email}.json", "r") as f:
with open(target, "r") as f:
old_data = json.loads(f.read())
new_data = json.loads(data)
# 查找 old_data["updated"] 字段中是否存在 id 与 new_data["deleted"] 字段的列表中的 id 相同的元素
Expand All @@ -700,7 +681,7 @@ async def proxy_sync_put(request: Request):
"deleted": [],
}

with open(f"./sync/{email}.json", "w") as f:
with open(target, "w") as f:
f.write(json.dumps(new_data))

return Response(json.dumps({"updated_at": updated_time}))
Expand Down

0 comments on commit d59309e

Please sign in to comment.