-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
30 lines (25 loc) · 990 Bytes
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import logging

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# ASGI application object: the /translate route below is registered on it via
# @app.post, and the __main__ guard at the bottom passes it to uvicorn.run.
app = FastAPI()
class Query(BaseModel):
    # JSON request body for POST /translate, e.g. {"text": "..."}.
    # Documented with comments rather than a docstring, because pydantic would
    # copy a class docstring into the generated OpenAPI schema.
    # The source-language sentence to translate; validated as a string by pydantic.
    text: str
# Hub checkpoint id (Darija -> MSA, going by its name). The tokenizer and the
# model are each built exactly once, when this module is imported, and are
# then shared by every request handled by the endpoint below.
model_name = "Saidtaoussi/AraT5_Darija_to_MSA"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),  # evaluated first, as before
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)
@app.post("/translate")
def translate_text(query: Query):
    """Translate ``query.text`` with the module-level seq2seq model.

    Declared with plain ``def`` rather than ``async def`` on purpose:
    tokenization and ``model.generate`` are blocking calls. FastAPI runs
    ``def`` endpoints in a worker threadpool. A blocking call inside an
    ``async def`` endpoint runs on the event loop itself and stalls every
    other in-flight request until it finishes.

    Args:
        query: Parsed JSON request body.

    Returns:
        ``{"response": <decoded model output>}``. Returns ``{"response": ""}``
        when the input text is empty or whitespace-only.

    Raises:
        HTTPException: 500 if tokenization, generation or decoding fails.
            The full traceback is logged server-side. The client receives a
            generic message rather than raw internal exception text.
    """
    # Nothing to translate: skip the model round-trip entirely.
    if not query.text.strip():
        return {"response": ""}

    try:
        inputs = tokenizer(query.text, return_tensors="pt", padding=True)
        # NOTE(review): no max_new_tokens / max_length is passed, so output
        # length is capped only by the checkpoint's generation config.
        # transformers falls back to max_length=20 when that config sets
        # none. Confirm long sentences are not being cut off.
        translated = model.generate(**inputs)
        output_text = tokenizer.decode(translated[0], skip_special_tokens=True)
    except Exception as exc:
        # Top-level request boundary: record the real cause for operators,
        # then map it to a generic 500 so internals are not leaked to clients.
        logging.getLogger(__name__).exception("Translation failed")
        raise HTTPException(status_code=500, detail="Translation failed") from exc

    return {"response": output_text}
if __name__ == "__main__":
    # Local imports: these are only needed when this file is executed directly
    # (not when an external ASGI server imports `app`).
    import os

    import uvicorn

    # Defaults reproduce the previous hard-coded bind (all interfaces, port
    # 8000). HOST / PORT environment variables override them without code
    # edits, e.g. HOST=127.0.0.1 for local-only development. A non-integer
    # PORT raises ValueError at startup.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )