Skip to content

Commit

Permalink
Added new Translation model EN-AR
Browse files Browse the repository at this point in the history
  • Loading branch information
MrSa3dola committed Nov 1, 2024
1 parent 000e6c4 commit d8d5c4a
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,21 @@

# Function to get the model and tokenizer from Google Drive instead of putting it in the repo
def download_file_from_google_drive(file_id, output_path):
    """Download a Google Drive file to a local path using ``gdown``.

    Args:
        file_id: Google Drive file identifier; it is interpolated into the
            ``https://drive.google.com/uc?id=...`` URL.
        output_path: Destination path handed to ``gdown.download``.
    """
    # Build the URL once (the earlier duplicate assignment was a dead store).
    url = f"https://drive.google.com/uc?id={file_id}"
    # quiet=False is passed through unchanged so gdown is not silenced.
    gdown.download(url, output_path, quiet=False)


# Fetch the two .pth files from Google Drive when this module is executed.
download_file_from_google_drive('1wYF0uHMHWdWb6G2XOB6dLQj3LWyz8u5X', './ASR_2_1_300.pth')
download_file_from_google_drive('19hitohi6MgNPpTvsTqvt9fmQLWPxD9ky', './translate_v1.pth')
# NOTE(review): the commented-out calls below repeat the two downloads above
# (same file ids and output paths, double-quoted). This looks like the old and
# new sides of one change shown together -- confirm which variant is intended.
# download_file_from_google_drive(
# "1wYF0uHMHWdWb6G2XOB6dLQj3LWyz8u5X", "./ASR_2_1_300.pth"
# )
# download_file_from_google_drive(
# "19hitohi6MgNPpTvsTqvt9fmQLWPxD9ky", "./translate_v1.pth"
# )

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the tokenizer and seq2seq model for the "Helsinki-NLP/opus-mt-en-ar"
# checkpoint once at module level; the /translate/en-ar handler reads both.
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

# Re-wrap stdout's underlying byte buffer so text written to it is UTF-8 encoded.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")

Expand Down Expand Up @@ -61,6 +70,15 @@ async def translate_endpoint(request: TranslationRequest):
return {"translation": response}


@app.post("/translate/en-ar")
async def translate_endpoint(request: TranslationRequest):
    """Translate ``request.text`` using the module-level opus-mt-en-ar model.

    The text is prefixed with a target-language token, encoded with the
    module-level ``tokenizer``, passed through ``model.generate`` and the
    first generated sequence is decoded back to a string.

    Args:
        request: Request body; only its ``text`` attribute is read.

    Returns:
        A dict of the form ``{"translation": <decoded string>}``.
    """
    # NOTE(review): the diff context shows another handler that is also named
    # ``translate_endpoint``; route registration happens in the decorator, but
    # the shared name is worth confirming / making unique.

    # The opus-mt-en-ar model card requires a sentence-initial ``>>id<<`` token
    # with a valid target language ID; its IDs are three-letter codes such as
    # ``ara``, so the previous ``>>ar<<`` prefix did not match one.
    input_text = f">>ara<< {request.text}"

    # truncation=True caps the encoded input at the tokenizer's maximum length
    # instead of passing arbitrarily long request text to the model.
    inputs = tokenizer.encode(input_text, return_tensors="pt", truncation=True)

    # NOTE(review): generate() is a synchronous call executed directly inside
    # this async handler -- consider offloading it if request concurrency
    # matters.
    outputs = model.generate(inputs)

    # Only the first returned sequence is decoded; special tokens are dropped.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"translation": response}


@app.post("/audio2text")
async def upload_audio(file: UploadFile = File(...)):
# Read the uploaded audio file into memory
Expand All @@ -72,7 +90,7 @@ async def upload_audio(file: UploadFile = File(...)):

# Create a temporary file in the current working directory
with tempfile.NamedTemporaryFile(
dir=current_dir, delete=False, suffix=".wav"
dir=current_dir, delete=False, suffix=".wav"
) as tmp_file:
tmp_file.write(contents)
tmp_file_path = tmp_file.name # Get the path of the temp file
Expand Down

0 comments on commit d8d5c4a

Please sign in to comment.