Commit

feat: allow longer audios
marwan2232004 committed Oct 30, 2024
1 parent 43cbe01 commit 000e6c4
Showing 4 changed files with 40 additions and 36 deletions.
50 changes: 30 additions & 20 deletions Inference.py
@@ -1,3 +1,4 @@
+import librosa
 import numpy as np
 from utils.Audio_Processing import preprocess_audio
 from utils.Constants import *
@@ -17,25 +18,9 @@
 def predict(audio_file):
     device = get_device()
 
-    processed_audios = []
-    mel_spec, duration = preprocess_audio(audio_file)
-    processed_audios.append(mel_spec)
-    padded_audios = [
-        (
-            mel_spec.shape[-1],
-            np.pad(
-                mel_spec,
-                ((0, 0), (0, N_FRAMES - mel_spec.shape[-1])),
-                mode="constant",
-            ),
-        )
-        for mel_spec in processed_audios
-    ]
-
     char2idx, idx2char, vocab_size = preprocess_vocab()
 
     # load model
-
     model = MMS(
         vocab_size=vocab_size,
         max_encoder_seq_len=math.ceil(N_FRAMES / 2),
@@ -47,12 +32,37 @@ def predict(audio_file):
         dim_feedforward=2048,
     )
 
-    model.load_state_dict(torch.load(model_path, weights_only=False,map_location=torch.device('cpu')))
+    model.load_state_dict(torch.load(model_path, weights_only=False, map_location=device))
     model.to(device)
     model.eval()
 
-    result = greedyDecoder(
-        model, padded_audios[0][1], padded_audios[0][0], char2idx, idx2char, device
-    )
+    audio_data, _ = librosa.load(audio_file, sr=SAMPLE_RATE)  # Load the audio
+    n_chunks = math.ceil(audio_data.shape[0] / N_SAMPLES)  # Get the number of chunks
+    # divide the audio into segments of 15 secs
+    chunk_size = audio_data.shape[0] if n_chunks == 1 else N_SAMPLES
+    audio_segments = [audio_data[i * chunk_size: min(audio_data.shape[0], (i + 1) * chunk_size)]
+                      for i in range(n_chunks)]
+    result = ""
+    for audio_segment in audio_segments:
+
+        mel_spectrogram = preprocess_audio(audio_segment)
+
+        processed_audios = [mel_spectrogram]
+
+        padded_audios = [
+            (
+                mel_spec.shape[-1],
+                np.pad(
+                    mel_spec,
+                    ((0, 0), (0, N_FRAMES - mel_spec.shape[-1])),
+                    mode="constant",
+                ),
+            )
+            for mel_spec in processed_audios
+        ]
+
+        result += " " + greedyDecoder(
+            model, padded_audios[0][1], padded_audios[0][0], char2idx, idx2char, device
+        )
 
     return result
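
The substance of this commit is the loop above: rather than padding or trimming a single clip to the model's fixed input length, predict() now splits the waveform into consecutive non-overlapping chunks, decodes each chunk separately, and joins the partial transcripts with spaces. Below is a minimal, self-contained sketch of the splitting logic; SAMPLE_RATE and N_SAMPLES are assumed stand-ins for the real constants in utils/Constants.py (the in-diff comment suggests N_SAMPLES corresponds to 15 seconds of audio).

```python
import math

import numpy as np

SAMPLE_RATE = 16000           # assumption; the real value lives in utils/Constants.py
N_SAMPLES = 15 * SAMPLE_RATE  # assumption: 15-second chunks, per the comment in the diff


def split_into_chunks(audio_data: np.ndarray) -> list:
    """Split a 1-D waveform into consecutive chunks of at most N_SAMPLES samples."""
    n_chunks = math.ceil(audio_data.shape[0] / N_SAMPLES)
    chunk_size = audio_data.shape[0] if n_chunks == 1 else N_SAMPLES
    return [
        audio_data[i * chunk_size: min(audio_data.shape[0], (i + 1) * chunk_size)]
        for i in range(n_chunks)
    ]


# 40 s of silence -> chunks of 15 s, 15 s, and 10 s
waveform = np.zeros(40 * SAMPLE_RATE, dtype=np.float32)
print([len(c) / SAMPLE_RATE for c in split_into_chunks(waveform)])  # [15.0, 15.0, 10.0]
```

One trade-off worth noting: a hard boundary can cut through a word, and the greedy decoder sees each chunk in isolation, so words straddling a boundary may be transcribed poorly; overlapping windows or silence-aware splitting would avoid that at the cost of extra decoding passes.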
22 changes: 8 additions & 14 deletions utils/Audio_Processing.py
@@ -16,23 +16,17 @@ def pad_or_trim(array, length=N_SAMPLES, axis=-1, padding=True):


 # Function to load and preprocess audio
-def preprocess_audio(file_path):
-    audio_data, _ = librosa.load(file_path, sr=SAMPLE_RATE)
+def preprocess_audio(audio_data):
+    spectrogram = librosa.stft(y=audio_data, n_fft=N_FFT, hop_length=HOP_LENGTH)
 
-    duration = librosa.get_duration(y=audio_data, sr=SAMPLE_RATE)
+    spectrogram_mag, _ = librosa.magphase(spectrogram)
 
-    modified_audio = pad_or_trim(audio_data, padding=False)
-
-    sgram = librosa.stft(y=modified_audio, n_fft=N_FFT, hop_length=HOP_LENGTH)
-
-    sgram_mag, _ = librosa.magphase(sgram)
-
-    mel_scale_sgram = librosa.feature.melspectrogram(
-        S=sgram_mag, sr=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS
+    mel_scale_spectrogram = librosa.feature.melspectrogram(
+        S=spectrogram_mag, sr=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS
     )
 
-    mel_sgram = librosa.amplitude_to_db(mel_scale_sgram, ref=np.min)
+    mel_spectrogram = librosa.amplitude_to_db(mel_scale_spectrogram, ref=np.min)
 
-    del audio_data, modified_audio, sgram, mel_scale_sgram
+    del spectrogram, mel_scale_spectrogram, spectrogram_mag
 
-    return mel_sgram, duration
+    return mel_spectrogram
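
With loading moved up into predict(), preprocess_audio() now has a narrower contract: waveform in, log-mel spectrogram out, and the duration return value is gone. A hedged sketch of the new shape of the function, with assumed parameter values standing in for utils/Constants.py; the pad_or_trim call could be dropped because the caller already guarantees each chunk is at most N_SAMPLES long, and padding up to N_FRAMES still happens in Inference.py.

```python
import librosa
import numpy as np

# Assumed parameter values; the real ones are defined in utils/Constants.py.
SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MELS = 16000, 400, 160, 80


def preprocess_audio(audio_data: np.ndarray) -> np.ndarray:
    """Waveform -> magnitude STFT -> mel scale -> dB, mirroring the refactored function."""
    spectrogram = librosa.stft(y=audio_data, n_fft=N_FFT, hop_length=HOP_LENGTH)
    spectrogram_mag, _ = librosa.magphase(spectrogram)
    mel = librosa.feature.melspectrogram(
        S=spectrogram_mag, sr=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS
    )
    return librosa.amplitude_to_db(mel, ref=np.min)


# The caller now owns loading and chunking; each chunk is preprocessed independently.
t = np.linspace(0, 3, 3 * SAMPLE_RATE, endpoint=False)
chunk = np.sin(2 * np.pi * 440 * t).astype(np.float32)
print(preprocess_audio(chunk).shape)  # (N_MELS, n_frames)
```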
3 changes: 2 additions & 1 deletion utils/MMS.py
@@ -7,8 +7,9 @@
 
 def get_device():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device("cpu")
     print(f"Using device: {device}")
-    return torch.device("cpu")
+    return device
 
 
 def get_conv_Lout(L_in, conv):
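
get_device() now prints and returns the same device object instead of printing one value and returning another; judging by the added override it still pins inference to CPU, which lines up with the map_location change in Inference.py. A small sketch of why map_location matters when loading a checkpoint; the Linear module and model.pt path are hypothetical stand-ins, not artifacts from this repository.

```python
import torch

# Hypothetical stand-ins: a tiny module and checkpoint path, not from this repo.
model = torch.nn.Linear(4, 2)
torch.save(model.state_dict(), "model.pt")

device = torch.device("cpu")  # what get_device() yields after this commit
# map_location remaps tensors saved on another device (e.g. a training GPU)
# onto the local one, so the checkpoint also loads on CPU-only machines.
state_dict = torch.load("model.pt", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device).eval()
```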
1 change: 0 additions & 1 deletion utils/NLP.py
@@ -21,7 +21,6 @@ def preprocess_vocab():
     idx2char = {idx: char for idx, char in enumerate(vocab)}
 
     vocab_size = len(vocab)
-    print(f"Vocabulary size: {vocab_size}")
     return char2idx, idx2char, vocab_size
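
For context, preprocess_vocab() hands the decoder its two character-level lookup tables plus the vocabulary size; removing the print just silences a debug log. A toy version of that contract, with a made-up vocabulary in place of the project's real character set:

```python
# Toy vocabulary; the real one comes from the project's character set.
vocab = sorted(set("abc def"))
char2idx = {char: idx for idx, char in enumerate(vocab)}
idx2char = {idx: char for idx, char in enumerate(vocab)}
vocab_size = len(vocab)

# Round-trip: encode a string to indices and decode it back.
encoded = [char2idx[c] for c in "cafe"]
assert "".join(idx2char[i] for i in encoded) == "cafe"
```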


