# main.py

import io
import os
import sys
import tempfile

import gdown
from fastapi import (
    FastAPI,
    File,
    UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from Inference import predict
from utils.Translation import *

app = FastAPI()


# Helper to fetch the model weights from Google Drive instead of storing them in the repo
def download_file_from_google_drive(file_id, output_path):
    url = f"https://drive.google.com/uc?id={file_id}"
    gdown.download(url, output_path, quiet=False)


# download_file_from_google_drive(
#     "1wYF0uHMHWdWb6G2XOB6dLQj3LWyz8u5X", "./ASR_2_1_300.pth"
# )
# download_file_from_google_drive(
#     "19hitohi6MgNPpTvsTqvt9fmQLWPxD9ky", "./translate_v1.pth"
# )

# Load the pretrained English-to-Arabic translation model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

# Ensure stdout can print non-ASCII (e.g. Arabic) text
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
    allow_headers=["*"],  # Allow all headers
)


@app.get("/")
async def root():
    return {"message": "Hello World"}


class TranslationRequest(BaseModel):
    text: str
@app.post("/translate/auto")
async def translateOpenL(request: TranslationRequest):
response = translate_openl(request.text)
return {"translation": response}
@app.post("/translate/en")
async def translate_endpoint(request: TranslationRequest):
response = translate(request.text)
return {"translation": response}
@app.post("/translate/en-ar")
async def translate_endpoint(request: TranslationRequest):
input_text = ">>ar<<" + request.text
inputs = tokenizer.encode(input_text, return_tensors="pt")
outputs = model.generate(inputs)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"translation": response}
@app.post("/audio2text")
async def upload_audio(file: UploadFile = File(...)):
# Read the uploaded audio file into memory
contents = await file.read()
# Get the current working directory
current_dir = os.getcwd()
print(current_dir, flush=True)
# Create a temporary file in the current working directory
with tempfile.NamedTemporaryFile(
dir=current_dir, delete=False, suffix=".wav"
) as tmp_file:
tmp_file.write(contents)
tmp_file_path = tmp_file.name # Get the path of the temp file
try:
# Pass the path of the saved file to the predict function
print(f"Temporary file created at: {tmp_file_path}", flush=True)
result = predict(tmp_file_path)
finally:
# Clean up the temporary file after prediction
os.remove(tmp_file_path)
print(f"Temporary file deleted: {tmp_file_path}", flush=True)
return {"text": result}