Skip to content

Commit

Permalink
api update
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiesun committed Nov 1, 2024
1 parent 18b356c commit 57dddbd
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 16 deletions.
94 changes: 84 additions & 10 deletions common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
import asyncio
import logging
import tempfile
import uuid
import aiohttp
import jwt
from openai import AsyncAzureOpenAI
import os
import os

from common.azure_blob import generate_blob_rl_sas_url, upload_blob_text, upload_blobfile


log = logging.getLogger(__name__)



def md5hash(s: str) -> str:
import hashlib
Expand All @@ -29,6 +40,14 @@ def get_openai_client():
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)

def get_openai_imagine_client():
return AsyncAzureOpenAI(
azure_endpoint=os.getenv("IMAGINE_AZURE_OPENAI_ENDPOINT"),
azure_deployment=os.getenv("IMAGINE_AZURE_OPENAI_MODEL_DEPLOYMENT_NAME"),
api_key=os.getenv("IMAGINE_AZURE_OPENAI_API_KEY"),
api_version=os.getenv("IMAGINE_AZURE_OPENAI_API_VERSION"),
)

async def openai_async_text_generate(sysmsg, prompt, model: str) -> str:
"""OpenAI API"""
client = get_openai_client()
Expand Down Expand Up @@ -88,15 +107,70 @@ async def openai_agenerate_image(
quality: str = "standard",
size: str = "1024x1024",
style: str = "vivid",
container_name: str = "images",
expiry_hours: int = 48,
):
client = get_openai_client()
response = await client.images.generate(
model="dall-e-3",
prompt=prompt,
size=size,
quality=quality,
style=style,
n=1,
)
return [d.url for d in response.data]
client = get_openai_imagine_client()
try:
response = await client.images.generate(
model="dall-e-3",
prompt=prompt,
size=size,
quality=quality,
style=style,
n=1,
)
except Exception as e:
raise RuntimeError(f"生成图片失败: {str(e)}")

# 检查生成图片的状态
if not response or not response.data:
raise RuntimeError("生成图片失败: 未返回有效数据")

# 获取生成的图片 URL
image_urls = [d.url for d in response.data]

log.info(f"openai gen image URLs: {image_urls}")

# 上传到 Azure Blob 并返回 Blob URL
blob_urls = []
for image_url in image_urls:
# 下载图片内容
async with aiohttp.ClientSession() as session:
async with session.get(image_url) as resp:
image_content = await resp.read()

# 使用临时文件保存图片内容
temp_dir = tempfile.gettempdir()
temp_filename = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")

with open(temp_filename, "wb") as f:
f.write(image_content)

# 生成随机的 blob 名称
blob_name = f"{uuid.uuid4()}.jpg"
async def upload_images():
try:
await upload_blobfile(
container_name=container_name,
blob_name=blob_name,
filename=temp_filename,
overwrite=True,
expiry_hours=expiry_hours,
)
finally:
if os.path.exists(temp_filename):
os.remove(temp_filename)

asyncio.create_task(upload_images())

# 生成 Blob URL
blob_url = generate_blob_rl_sas_url(
container_name=container_name,
blob_name=blob_name,
expiry_hours=expiry_hours
)
blob_urls.append(blob_url)

return blob_urls

12 changes: 6 additions & 6 deletions common/azure_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,30 @@ def get_blob_service(conn_str: str = None):
return BlobServiceClient.from_connection_string(conn_str=_conn_str)


def generate_blob_rl_sas(container, blob_name, permission, expiry_hours):
def generate_blob_rl_sas(container_name, blob_name, permission, expiry_hours):
_conn_str = os.environ.get("AZURE_BLOB_CONNECT_STR")
account_name, account_key = parse_account_info(_conn_str)
sas_blob = generate_blob_sas(
account_name=account_name,
blob_name=blob_name,
container_name=container.container_name,
container_name=container_name,
account_key=account_key,
permission=permission,
expiry=datetime.now(UTC) + timedelta(hours=expiry_hours),
)
return sas_blob


def generate_blob_rl_sas_url(container, blob_name, expiry_hours):
def generate_blob_rl_sas_url(container_name, blob_name, expiry_hours):
_conn_str = os.environ.get("AZURE_BLOB_CONNECT_STR")
account_name, account_key = parse_account_info(_conn_str)
sas_blob = generate_blob_rl_sas(
container=container,
container_name,
blob_name=blob_name,
permission="rl",
expiry_hours=expiry_hours,
)
return f"https://{account_name}.blob.core.windows.net/{container.container_name}/{blob_name}?{sas_blob}"
return f"https://{account_name}.blob.core.windows.net/{container_name}/{blob_name}?{sas_blob}"


async def upload_blob_text(
Expand Down Expand Up @@ -110,7 +110,7 @@ async def upload_blobfile(
url
+ "?"
+ generate_blob_rl_sas(
container=container_client,
container_client.container_name,
blob_name=blob_name,
permission="r",
expiry_hours=expiry_hours,
Expand Down
62 changes: 62 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from common import (
md5hash,
openai_agenerate_image,
openai_analyze_image,
openai_async_text_generate,
validate_api_key,
Expand Down Expand Up @@ -295,6 +296,67 @@ async def openai_analyze_image_api(
result={},
)


class ImageGenerate(BaseModel):
prompt: str = Field(
..., description="The user's input prompt for generating an image."
)
quality: str = Field(
"standard",
description="The quality of the generated image. Defaults to 'standard'.",
)
size: str = Field(
"1024x1024",
description="The size of the generated image. Defaults to '1024x1024'.",
)
style: str = Field(
"vivid",
description="The style of the generated image. Defaults to 'vivid'.",
)
container_name: str = Field(
"images",
description="The Azure Blob container name where the image will be stored. Defaults to 'images'.",
)
expiry_hours: int = Field(
24 * 365,
description="The number of hours the image URL will be valid. Defaults to 48 hours.",
)

@app.api_route(
"/api/openai/image/generate",
methods=["POST"],
summary="image generate",
description="generate an image using openai",
)
async def openai_generate_image_api(
req: ImageGenerate,
td: TokenData = Depends(verify_api_key),
):
logging.info("openai_generate_image HTTP trigger function processed a request.")
try:
blob_urls = await openai_agenerate_image(
prompt=req.prompt,
quality=req.quality,
size=req.size,
style=req.style,
container_name=req.container_name,
expiry_hours=req.expiry_hours,
)
response = {"data": blob_urls}
return RestResult(
code=0,
msg="ok",
result=response,
)
except Exception as e:
return RestResult(
code=500,
msg=str(e),
result={},
)



if __name__ == "__main__":
import uvicorn
webport = int(os.environ.get("WEB_PORT", 8000))
Expand Down

0 comments on commit 57dddbd

Please sign in to comment.