Skip to content

Commit

Permalink
add dataset tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Gautzilla committed Dec 16, 2024
1 parent ac1d2e2 commit c6b5746
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 3 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,5 @@ select = ["ALL"]
"D", # Docstring-related stuff
"SLF001", # Access to private variables
"BLE001", # Blind exceptions
"PLR0913", # Too many arguments in methods
]
2 changes: 1 addition & 1 deletion src/OSmOSE/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
OSMOSE_PATH = namedtuple("path_list", __global_path_dict.keys())(**__global_path_dict)

TIMESTAMP_FORMAT_AUDIO_FILE = "%Y-%m-%dT%H:%M:%S.%f%z"
TIMESTAMP_FORMAT_TEST_FILES = "%y%m%d%H%M%S"
TIMESTAMP_FORMAT_TEST_FILES = "%y%m%d%H%M%S%f"
TIMESTAMP_FORMAT_EXPORTED_FILES = "%Y_%m_%d_%H_%M_%S"
FPDEFAULT = 0o664 # Default file permissions
DPDEFAULT = stat.S_ISGID | 0o775 # Default directory permissions
Expand Down
6 changes: 4 additions & 2 deletions src/OSmOSE/data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@

from __future__ import annotations

from pathlib import Path
from typing import Generic, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar

from pandas import Timedelta, Timestamp, date_range

from OSmOSE.data.base_data import BaseData
from OSmOSE.data.base_file import BaseFile

if TYPE_CHECKING:
from pathlib import Path

TData = TypeVar("TData", bound=BaseData)
TFile = TypeVar("TFile", bound=BaseFile)

Expand Down
94 changes: 94 additions & 0 deletions tests/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from OSmOSE.config import TIMESTAMP_FORMAT_TEST_FILES
from OSmOSE.data.audio_data import AudioData
from OSmOSE.data.audio_dataset import AudioDataset
from OSmOSE.data.audio_file import AudioFile
from OSmOSE.data.audio_item import AudioItem
from OSmOSE.utils.audio_utils import generate_sample_audio
Expand Down Expand Up @@ -399,3 +400,96 @@ def test_audio_resample_sample_count(
data = AudioData.from_files(audio_files, begin=start, end=stop)
data.sample_rate = sample_rate
assert data.get_value().shape[0] == expected_nb_samples


@pytest.mark.parametrize(
("audio_files", "begin", "end", "duration", "expected_audio_data"),
[
pytest.param(
{
"duration": 1,
"sample_rate": 48_000,
"nb_files": 1,
"date_begin": pd.Timestamp("2024-01-01 12:00:00"),
"series_type": "increase",
},
None,
None,
None,
generate_sample_audio(1, 48_000),
id="one_entire_file",
),
pytest.param(
{
"duration": 1,
"sample_rate": 48_000,
"nb_files": 3,
"date_begin": pd.Timestamp("2024-01-01 12:00:00"),
"series_type": "increase",
},
None,
None,
pd.Timedelta(seconds=1),
generate_sample_audio(
nb_files=3, nb_samples=48_000, series_type="increase"
),
id="multiple_consecutive_files",
),
pytest.param(
{
"duration": 1,
"sample_rate": 48_000,
"nb_files": 2,
"inter_file_duration": 1,
"date_begin": pd.Timestamp("2024-01-01 12:00:00"),
"series_type": "increase",
},
None,
None,
pd.Timedelta(seconds=1),
[
generate_sample_audio(nb_files=1, nb_samples=96_000)[0][0:48_000],
generate_sample_audio(
nb_files=1, nb_samples=48_000, min_value=0.0, max_value=0.0
)[0],
generate_sample_audio(nb_files=1, nb_samples=96_000)[0][48_000:],
],
id="two_separated_files",
),
pytest.param(
{
"duration": 1,
"sample_rate": 48_000,
"nb_files": 3,
"inter_file_duration": -0.5,
"date_begin": pd.Timestamp("2024-01-01 12:00:00"),
"series_type": "repeat",
},
None,
None,
pd.Timedelta(seconds=1),
generate_sample_audio(nb_files=2, nb_samples=48_000),
id="overlapping_files",
),
],
indirect=["audio_files"],
)
def test_audio_dataset_from_folder(
tmp_path: Path,
audio_files: tuple[list[Path], pytest.fixtures.Subrequest],
begin: pd.Timestamp | None,
end: pd.Timestamp | None,
duration: pd.Timedelta | None,
expected_audio_data: list[tuple[int, bool]],
) -> None:
dataset = AudioDataset.from_folder(
tmp_path,
strptime_format=TIMESTAMP_FORMAT_TEST_FILES,
begin=begin,
end=end,
data_duration=duration,
)
assert all(
np.array_equal(data.get_value(), expected)
for (data, expected) in zip(dataset.data, expected_audio_data)
)

0 comments on commit c6b5746

Please sign in to comment.