From 000e6c4002d4e398051e675b853b02bb7a76e3dd Mon Sep 17 00:00:00 2001
From: marwan2232004 <118024824+marwan2232004@users.noreply.github.com>
Date: Wed, 30 Oct 2024 16:44:49 +0300
Subject: [PATCH] feat: allow longer audios

---
 Inference.py              | 50 +++++++++++++++++++++++----------------
 utils/Audio_Processing.py | 22 +++++++----------
 utils/MMS.py              |  3 ++-
 utils/NLP.py              |  1 -
 4 files changed, 40 insertions(+), 36 deletions(-)

diff --git a/Inference.py b/Inference.py
index 657ad50..e294258 100644
--- a/Inference.py
+++ b/Inference.py
@@ -1,3 +1,4 @@
+import librosa
 import numpy as np
 from utils.Audio_Processing import preprocess_audio
 from utils.Constants import *
@@ -17,25 +18,9 @@ def predict(audio_file):
 
     device = get_device()
 
-    processed_audios = []
-    mel_spec, duration = preprocess_audio(audio_file)
-    processed_audios.append(mel_spec)
-    padded_audios = [
-        (
-            mel_spec.shape[-1],
-            np.pad(
-                mel_spec,
-                ((0, 0), (0, N_FRAMES - mel_spec.shape[-1])),
-                mode="constant",
-            ),
-        )
-        for mel_spec in processed_audios
-    ]
-
     char2idx, idx2char, vocab_size = preprocess_vocab()
 
     # load model
-
     model = MMS(
         vocab_size=vocab_size,
         max_encoder_seq_len=math.ceil(N_FRAMES / 2),
@@ -47,12 +32,37 @@ def predict(audio_file):
         dim_feedforward=2048,
     )
 
-    model.load_state_dict(torch.load(model_path, weights_only=False,map_location=torch.device('cpu')))
+    model.load_state_dict(torch.load(model_path, weights_only=False, map_location=device))
     model.to(device)
     model.eval()
 
-    result = greedyDecoder(
-        model, padded_audios[0][1], padded_audios[0][0], char2idx, idx2char, device
-    )
+    audio_data, _ = librosa.load(audio_file, sr=SAMPLE_RATE)  # Load the audio
+    n_chunks = math.ceil(audio_data.shape[0] / N_SAMPLES)  # Get the number of chunks
+    # divide the audio into segments of 15 secs
+    chunk_size = audio_data.shape[0] if n_chunks == 1 else N_SAMPLES
+    audio_segments = [audio_data[i * chunk_size: min(audio_data.shape[0], (i + 1) * chunk_size)]
+                      for i in range(n_chunks)]
+    result = ""
+    for audio_segment in audio_segments:
+
+        mel_spectrogram = preprocess_audio(audio_segment)
+
+        processed_audios = [mel_spectrogram]
+
+        padded_audios = [
+            (
+                mel_spec.shape[-1],
+                np.pad(
+                    mel_spec,
+                    ((0, 0), (0, N_FRAMES - mel_spec.shape[-1])),
+                    mode="constant",
+                ),
+            )
+            for mel_spec in processed_audios
+        ]
+
+        result += " " + greedyDecoder(
+            model, padded_audios[0][1], padded_audios[0][0], char2idx, idx2char, device
+        )
     return result
 
diff --git a/utils/Audio_Processing.py b/utils/Audio_Processing.py
index 0fe19f0..c49bad2 100644
--- a/utils/Audio_Processing.py
+++ b/utils/Audio_Processing.py
@@ -16,23 +16,17 @@ def pad_or_trim(array, length=N_SAMPLES, axis=-1, padding=True):
 
 
 # Function to load and preprocess audio
-def preprocess_audio(file_path):
-    audio_data, _ = librosa.load(file_path, sr=SAMPLE_RATE)
+def preprocess_audio(audio_data):
+    spectrogram = librosa.stft(y=audio_data, n_fft=N_FFT, hop_length=HOP_LENGTH)
 
-    duration = librosa.get_duration(y=audio_data, sr=SAMPLE_RATE)
+    spectrogram_mag, _ = librosa.magphase(spectrogram)
 
-    modified_audio = pad_or_trim(audio_data, padding=False)
-
-    sgram = librosa.stft(y=modified_audio, n_fft=N_FFT, hop_length=HOP_LENGTH)
-
-    sgram_mag, _ = librosa.magphase(sgram)
-
-    mel_scale_sgram = librosa.feature.melspectrogram(
-        S=sgram_mag, sr=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS
+    mel_scale_spectrogram = librosa.feature.melspectrogram(
+        S=spectrogram_mag, sr=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS
     )
-    mel_sgram = librosa.amplitude_to_db(mel_scale_sgram, ref=np.min)
+    mel_spectrogram = librosa.amplitude_to_db(mel_scale_spectrogram, ref=np.min)
 
-    del audio_data, modified_audio, sgram, mel_scale_sgram
+    del spectrogram, mel_scale_spectrogram, spectrogram_mag
 
-    return mel_sgram, duration
\ No newline at end of file
+    return mel_spectrogram
diff --git a/utils/MMS.py b/utils/MMS.py
index b387cda..670382e 100644
--- a/utils/MMS.py
+++ b/utils/MMS.py
@@ -7,8 +7,9 @@
 
 def get_device():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device("cpu")
     print(f"Using device: {device}")
-    return torch.device("cpu")
+    return device
 
 
 def get_conv_Lout(L_in, conv):
diff --git a/utils/NLP.py b/utils/NLP.py
index 7091560..06fc11d 100644
--- a/utils/NLP.py
+++ b/utils/NLP.py
@@ -21,7 +21,6 @@ def preprocess_vocab():
     idx2char = {idx: char for idx, char in enumerate(vocab)}
 
     vocab_size = len(vocab)
-    print(f"Vocabulary size: {vocab_size}")
 
     return char2idx, idx2char, vocab_size
 
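Note on the segmentation introduced in Inference.py: instead of padding or trimming the whole file to a single N_FRAMES window, the audio is now cut into fixed-size chunks, each chunk is mel-transformed and padded independently, and the per-chunk transcripts are concatenated. A minimal, self-contained sketch of that segmentation follows; SAMPLE_RATE and N_SAMPLES here are assumptions (16 kHz and 15-second windows, inferred from the "15 secs" comment), since the real constants live in utils/Constants.py, which this patch does not touch.

import math

import librosa

SAMPLE_RATE = 16000            # assumed; the real value is defined in utils/Constants.py
N_SAMPLES = 15 * SAMPLE_RATE   # assumed 15-second window, per the "15 secs" comment

def split_into_chunks(audio_file):
    """Load an audio file and cut it into chunks of at most N_SAMPLES samples."""
    audio_data, _ = librosa.load(audio_file, sr=SAMPLE_RATE)
    n_chunks = math.ceil(audio_data.shape[0] / N_SAMPLES)
    # Python slicing clamps at the array end, so the final chunk may be shorter;
    # the min(...) guard in the patch's list comprehension is redundant for the
    # same reason, though harmless.
    return [audio_data[i * N_SAMPLES:(i + 1) * N_SAMPLES] for i in range(n_chunks)]

Two small consequences of the decoding loop as written: each greedyDecoder output is prefixed with " ", so predict() now returns a transcript with a leading space (stripping result before returning would avoid this), and np.pad assumes every chunk yields at most N_FRAMES mel frames, so the chunk length and N_FRAMES must stay in sync or the negative pad width will raise.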