diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index 66b9b19..63643b6 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -85,6 +85,7 @@ def load_model( model_path: str = DEFAULT_MODEL_PATH, load_weights: bool = True, device: Optional[torch.device] = None, + weights_only: bool = True, ) -> Tuple[DetectionModel, ModelParameters]: """Load model from file. @@ -105,7 +106,11 @@ def load_model( if not os.path.isfile(model_path): raise FileNotFoundError("Model file not found.") - net_params = torch.load(model_path, map_location=device) + net_params = torch.load( + model_path, + map_location=device, + weights_only=weights_only, + ) params = net_params["params"] diff --git a/tests/conftest.py b/tests/conftest.py index 06f9ddc..fbebc98 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,31 @@ from pathlib import Path +from typing import List import pytest +@pytest.fixture +def example_data_dir() -> Path: + pkg_dir = Path(__file__).parent.parent + example_data_dir = pkg_dir / "example_data" + assert example_data_dir.exists() + return example_data_dir + + +@pytest.fixture +def example_audio_dir(example_data_dir: Path) -> Path: + example_audio_dir = example_data_dir / "audio" + assert example_audio_dir.exists() + return example_audio_dir + + +@pytest.fixture +def example_audio_files(example_audio_dir: Path) -> List[Path]: + audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]")) + assert len(audio_files) == 3 + return audio_files + + @pytest.fixture def data_dir() -> Path: dir = Path(__file__).parent / "data" diff --git a/tests/test_api.py b/tests/test_api.py index d28c733..e828c9e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,14 +1,13 @@ """Test bat detect module API.""" -from pathlib import Path - import os from glob import glob +from pathlib import Path import numpy as np +import soundfile as sf import torch from torch import nn -import soundfile as sf from batdetect2 import api @@ -267,7 +266,6 @@ def test_process_file_with_spec_slices(): assert len(results["spec_slices"]) == len(detections) - def test_process_file_with_empty_predictions_does_not_fail( tmp_path: Path, ): diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..3519c38 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,78 @@ +"""Test suite for model functions.""" + +import warnings +from pathlib import Path +from typing import List + +import numpy as np +from hypothesis import given, settings +from hypothesis import strategies as st + +from batdetect2 import api +from batdetect2.detector import parameters + + +def test_can_import_model_without_warnings(): + with warnings.catch_warnings(): + warnings.simplefilter("error") + api.load_model() + + +@settings(deadline=None, max_examples=5) +@given(duration=st.floats(min_value=0.1, max_value=2)) +def test_can_import_model_without_pickle(duration: float): + # NOTE: remove this test once no other issues are found This is a temporary + # test to check that change in model loading did not impact model behaviour + # in any way. + + samplerate = parameters.TARGET_SAMPLERATE_HZ + audio = np.random.rand(int(duration * samplerate)) + + model_without_pickle, model_params_without_pickle = api.load_model( + weights_only=True + ) + model_with_pickle, model_params_with_pickle = api.load_model( + weights_only=False + ) + + assert model_params_without_pickle == model_params_with_pickle + + predictions_without_pickle, _, _ = api.process_audio( + audio, + model=model_without_pickle, + ) + predictions_with_pickle, _, _ = api.process_audio( + audio, + model=model_with_pickle, + ) + + assert predictions_without_pickle == predictions_with_pickle + + +def test_can_import_model_without_pickle_on_test_data( + example_audio_files: List[Path], +): + # NOTE: remove this test once no other issues are found This is a temporary + # test to check that change in model loading did not impact model behaviour + # in any way. + + model_without_pickle, model_params_without_pickle = api.load_model( + weights_only=True + ) + model_with_pickle, model_params_with_pickle = api.load_model( + weights_only=False + ) + + assert model_params_without_pickle == model_params_with_pickle + + for audio_file in example_audio_files: + audio = api.load_audio(str(audio_file)) + predictions_without_pickle, _, _ = api.process_audio( + audio, + model=model_without_pickle, + ) + predictions_with_pickle, _, _ = api.process_audio( + audio, + model=model_with_pickle, + ) + assert predictions_without_pickle == predictions_with_pickle