Skip to content

Commit

Permalink
Added test to ensure pad_audio function, and utils, are working as ex…
Browse files Browse the repository at this point in the history
…pected
  • Loading branch information
mbsantiago committed Nov 10, 2024
1 parent a4b22d6 commit 7906842
Showing 1 changed file with 136 additions and 0 deletions.
136 changes: 136 additions & 0 deletions tests/test_audio_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import numpy as np
import torch
import torch.nn.functional as F
from hypothesis import given
from hypothesis import strategies as st

from batdetect2.detector import parameters
from batdetect2.utils import audio_utils, detector_utils


@given(duration=st.floats(min_value=0.1, max_value=2))
def test_can_compute_correct_spectrogram_width(duration: float):
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS

length = int(duration * samplerate)
audio = np.random.rand(length)

spectrogram, _ = audio_utils.generate_spectrogram(
audio,
samplerate,
params,
)

# convert to pytorch
spectrogram = torch.from_numpy(spectrogram)

# add batch and channel dimensions
spectrogram = spectrogram.unsqueeze(0).unsqueeze(0)

# resize the spec
resize_factor = params["resize_factor"]
spec_op_shape = (
int(params["spec_height"] * resize_factor),
int(spectrogram.shape[-1] * resize_factor),
)
spectrogram = F.interpolate(
spectrogram,
size=spec_op_shape,
mode="bilinear",
align_corners=False,
)

expected_width = audio_utils.compute_spectrogram_width(
length,
samplerate=parameters.TARGET_SAMPLERATE_HZ,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
)

assert spectrogram.shape[-1] == expected_width


@given(duration=st.floats(min_value=0.1, max_value=2))
def test_pad_audio_without_fixed_size(duration: float):
# Test the pad_audio function
# This function is used to pad audio with zeros to a specific length
# It is used in the generate_spectrogram function
# The function is tested with a simplepas
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS

length = int(duration * samplerate)
audio = np.random.rand(length)

# pad the audio to be divisible by divide factor
padded_audio = audio_utils.pad_audio(
audio,
samplerate=samplerate,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
divide_factor=params["spec_divide_factor"],
)

# check that the padded audio is divisible by the divide factor
expected_width = audio_utils.compute_spectrogram_width(
len(padded_audio),
samplerate=parameters.TARGET_SAMPLERATE_HZ,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
)

assert expected_width % params["spec_divide_factor"] == 0


@given(duration=st.floats(min_value=0.1, max_value=2))
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
duration: float,
):
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
length = int(duration * samplerate)
audio = np.random.rand(length)
_, spectrogram, _ = detector_utils.compute_spectrogram(
audio,
samplerate,
params,
torch.device("cpu"),
)
assert spectrogram.shape[-1] % params["spec_divide_factor"] == 0


@given(
duration=st.floats(min_value=0.1, max_value=2),
width=st.integers(min_value=128, max_value=1024),
)
def test_pad_audio_with_fixed_width(duration: float, width: int):
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS

length = int(duration * samplerate)
audio = np.random.rand(length)

# pad the audio to be divisible by divide factor
padded_audio = audio_utils.pad_audio(
audio,
samplerate=samplerate,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
divide_factor=params["spec_divide_factor"],
fixed_width=width,
)

# check that the padded audio is divisible by the divide factor
expected_width = audio_utils.compute_spectrogram_width(
len(padded_audio),
samplerate=parameters.TARGET_SAMPLERATE_HZ,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
)
assert expected_width == width

0 comments on commit 7906842

Please sign in to comment.