Skip to content

Commit

Permalink
Testing Whisper pretrained model
Browse files Browse the repository at this point in the history
  • Loading branch information
MrSa3dola committed Nov 7, 2024
1 parent 29d66ef commit 0710713
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ __pycache__/
ASR.pth
Test Arabic.mp3
test_main.http
*.wav
*.mp3
res.txt
# C extensions
*.so
*.pth
Expand Down
33 changes: 33 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import io
import sys

# import dotenv

# dotenv.load_dotenv()
import gdown
from fastapi import (
FastAPI,
Expand Down Expand Up @@ -105,3 +108,33 @@ async def upload_audio(file: UploadFile = File(...)):
print(f"Temporary file deleted: {tmp_file_path}", flush=True)

return {"text": result}


# @app.post("/asr_openai")
# async def asr_openai(file: UploadFile = File(...)):
# import whisper

# # Read the uploaded audio file into memory
# contents = await file.read()

# # Get the current working directory
# current_dir = os.getcwd()
# print(current_dir, flush=True)

# # Create a temporary file in the current working directory
# with tempfile.NamedTemporaryFile(
# dir=current_dir, delete=False, suffix=".wav"
# ) as tmp_file:
# tmp_file.write(contents)
# tmp_file_path = tmp_file.name # Get the path of the temp file

# try:
# # Pass the path of the saved file to the predict function
# print(f"Temporary file created at: {tmp_file_path}", flush=True)
# model = whisper.load_model("base")
#         result = model.transcribe(tmp_file_path)  # was "audio.wav", which ignored the uploaded temp file
# finally:
# # Clean up the temporary file after prediction
# os.remove(tmp_file_path)
# print(f"Temporary file deleted: {tmp_file_path}", flush=True)
# return {"text": result}
16 changes: 7 additions & 9 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import torch
import whisper
import sys
import io

# Load the OpenAI Whisper "turbo" checkpoint and transcribe an Arabic clip.
model = whisper.load_model("turbo")
# NOTE(review): assumes test6.mp3 exists in the current working directory — confirm.
result = model.transcribe("test6.mp3", language="ar")

def get_device():
    """Return the best available torch device: CUDA if present, else CPU.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``.
    """
    # Prefer the GPU when one is usable; fall back to CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")
    # Bug fix: previously this always returned torch.device("cpu"),
    # silently discarding the CUDA device it had just detected and printed.
    return device


get_device()
# Persist the transcription; UTF-8 is required for the Arabic text produced above.
with open("res.txt", "w", encoding="utf-8") as f:
    f.write(result["text"])

0 comments on commit 0710713

Please sign in to comment.