Skip to content

Commit

Permalink
Merge pull request #36 from macaodha/fix/GH-31-negative-dimension-are…
Browse files Browse the repository at this point in the history
…-not-allowed

fix: Resolve detect Command Failure with Specific Audio Files (GH-31)
  • Loading branch information
mbsantiago authored Nov 10, 2024
2 parents 697b5db + 505cca2 commit 7dc2869
Show file tree
Hide file tree
Showing 16 changed files with 416 additions and 60 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,6 @@ experiments/*
!batdetect2_notebook.ipynb
!batdetect2/models/*.pth.tar
!tests/data/*.wav
!tests/data/**/*.wav
notebooks/lightning_logs
example_data/preprocessed
7 changes: 6 additions & 1 deletion batdetect2/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
__version__ = '1.0.8'
import logging

numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)

__version__ = "1.0.8"
2 changes: 0 additions & 2 deletions batdetect2/train/train_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import glob
import json
import os
import random

import numpy as np

Expand Down
3 changes: 2 additions & 1 deletion batdetect2/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Types used in the code base."""

from typing import List, NamedTuple, Optional, Union

import numpy as np
Expand All @@ -17,7 +18,7 @@


try:
from typing import NotRequired
from typing import NotRequired # type: ignore
except ImportError:
from typing_extensions import NotRequired

Expand Down
201 changes: 147 additions & 54 deletions batdetect2/utils/audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
import torch

from batdetect2.detector import parameters

from . import wavfile

__all__ = [
Expand All @@ -15,18 +17,42 @@
]


def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
noverlap = np.floor(fft_overlap * nfft)
return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
def time_to_x_coords(
time_in_file: float,
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
) -> float:
nfft = np.floor(window_duration * samplerate) # int() uses floor
noverlap = np.floor(window_overlap * nfft)
return (time_in_file * samplerate - noverlap) / (nfft - noverlap)


def x_coords_to_time(
x_pos: int,
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
) -> float:
n_fft = np.floor(window_duration * samplerate)
n_overlap = np.floor(window_overlap * n_fft)
n_step = n_fft - n_overlap
return ((x_pos * n_step) + n_overlap) / samplerate
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window


# NOTE this is also defined in post_process
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
nfft = np.floor(fft_win_length * sampling_rate)
noverlap = np.floor(fft_overlap * nfft)
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
def x_coord_to_sample(
x_pos: int,
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
resize_factor: float = parameters.RESIZE_FACTOR,
) -> int:
n_fft = np.floor(window_duration * samplerate)
n_overlap = np.floor(window_overlap * n_fft)
n_step = n_fft - n_overlap
x_pos = int(x_pos / resize_factor)
return int((x_pos * n_step) + n_overlap)


def generate_spectrogram(
Expand Down Expand Up @@ -184,55 +210,118 @@ def load_audio(
return sampling_rate, audio_raw


def compute_spectrogram_width(
length: int,
samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
resize_factor: float = parameters.RESIZE_FACTOR,
) -> int:
n_fft = int(window_duration * samplerate)
n_overlap = int(window_overlap * n_fft)
n_step = n_fft - n_overlap
width = (length - n_overlap) // n_step
return int(width * resize_factor)


def pad_audio(
audio_raw,
fs,
ms,
overlap_perc,
resize_factor,
divide_factor,
fixed_width=None,
audio: np.ndarray,
samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
resize_factor: float = parameters.RESIZE_FACTOR,
divide_factor: int = parameters.SPEC_DIVIDE_FACTOR,
fixed_width: Optional[int] = None,
):
# Adds zeros to the end of the raw data so that the generated sepctrogram
# will be evenly divisible by `divide_factor`
# Also deals with very short audio clips and fixed_width during training
"""Pad audio to be evenly divisible by `divide_factor`.
This function pads the audio signal with zeros to ensure that the
generated spectrogram length will be evenly divisible by `divide_factor`.
This is important for the model to work correctly.
This `divide_factor` comes from the model architecture as it downscales
the spectrogram by this factor, so the input must be divisible by this
integer number.
Parameters
----------
audio : np.ndarray
The audio signal.
samplerate : int
The sampling rate of the audio signal.
window_size : float
The window size in seconds used for the spectrogram computation.
window_overlap : float
The overlap between windows in the spectrogram computation.
resize_factor : float
This factor is used to resize the spectrogram after the STFT
computation. Default is 0.5 which means that the spectrogram will be
reduced by half. Important to take into account for the final size of
the spectrogram.
divide_factor : int
The factor by which the spectrogram will be divided.
fixed_width : int, optional
If provided, the audio will be padded or cut so that the resulting
spectrogram width will be equal to this value.
Returns
-------
np.ndarray
The padded audio signal.
"""
spec_width = compute_spectrogram_width(
audio.shape[0],
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
resize_factor=resize_factor,
)

# This code could be clearer, clean up
nfft = int(ms * fs)
noverlap = int(overlap_perc * nfft)
step = nfft - noverlap
min_size = int(divide_factor * (1.0 / resize_factor))
spec_width = (audio_raw.shape[0] - noverlap) // step
spec_width_rs = spec_width * resize_factor

if fixed_width is not None and spec_width < fixed_width:
# too small
# used during training to ensure all the batches are the same size
diff = fixed_width * step + noverlap - audio_raw.shape[0]
audio_raw = np.hstack(
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
if fixed_width:
target_samples = x_coord_to_sample(
fixed_width,
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
resize_factor=resize_factor,
)

elif fixed_width is not None and spec_width > fixed_width:
# too big
# used during training to ensure all the batches are the same size
diff = fixed_width * step + noverlap - audio_raw.shape[0]
audio_raw = audio_raw[:diff]

elif (
spec_width_rs < min_size
or (np.floor(spec_width_rs) % divide_factor) != 0
):
# need to be at least min_size
div_amt = np.ceil(spec_width_rs / float(divide_factor))
div_amt = np.maximum(1, div_amt)
target_size = int(div_amt * divide_factor * (1.0 / resize_factor))
diff = target_size * step + noverlap - audio_raw.shape[0]
audio_raw = np.hstack(
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
)
if spec_width < fixed_width:
# need to be at least min_size
diff = target_samples - audio.shape[0]
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))

if spec_width > fixed_width:
return audio[:target_samples]

return audio_raw
return audio

min_width = int(divide_factor / resize_factor)

if spec_width < min_width:
target_samples = x_coord_to_sample(
min_width,
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
resize_factor=resize_factor,
)
diff = target_samples - audio.shape[0]
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))

if (spec_width % divide_factor) == 0:
return audio

target_width = int(np.ceil(spec_width / divide_factor)) * divide_factor
target_samples = x_coord_to_sample(
target_width,
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
resize_factor=resize_factor,
)
diff = target_samples - audio.shape[0]
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))


def gen_mag_spectrogram(x, fs, ms, overlap_perc):
Expand All @@ -247,7 +336,11 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):

# compute spec
spec, _ = librosa.core.spectrum._spectrogram(
y=x, power=1, n_fft=nfft, hop_length=step, center=False
y=x,
power=1,
n_fft=nfft,
hop_length=step,
center=False,
)

# remove DC component and flip vertical orientation
Expand Down
4 changes: 2 additions & 2 deletions batdetect2/utils/detector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
try:
from numpy.exceptions import AxisError
except ImportError:
from numpy import AxisError
from numpy import AxisError # type: ignore

import batdetect2.detector.compute_features as feats
import batdetect2.detector.post_process as pp
Expand Down Expand Up @@ -759,7 +759,7 @@ def process_file(

# Get original sampling rate
file_samp_rate = librosa.get_samplerate(audio_file)
orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)

# load audio file
sampling_rate, audio_full = au.load_audio(
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ batdetect2 = "batdetect2.cli:cli"

[tool.uv]
dev-dependencies = [
"debugpy>=1.8.8",
"hypothesis>=6.118.7",
"pyright>=1.1.388",
"pytest>=7.2.2",
"ruff>=0.7.3",
Expand Down
17 changes: 17 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pathlib import Path

import pytest


@pytest.fixture
def data_dir() -> Path:
dir = Path(__file__).parent / "data"
assert dir.exists()
return dir


@pytest.fixture
def contrib_dir(data_dir) -> Path:
dir = data_dir / "contrib"
assert dir.exists()
return dir
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit 7dc2869

Please sign in to comment.