From d085b3212c8f48e959aee4b1dd860d7bac43ac0c Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 11 Nov 2024 11:46:06 +0000 Subject: [PATCH 1/3] Added weights_only argument to model loading function --- batdetect2/utils/detector_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index 66b9b19..63643b6 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -85,6 +85,7 @@ def load_model( model_path: str = DEFAULT_MODEL_PATH, load_weights: bool = True, device: Optional[torch.device] = None, + weights_only: bool = True, ) -> Tuple[DetectionModel, ModelParameters]: """Load model from file. @@ -105,7 +106,11 @@ def load_model( if not os.path.isfile(model_path): raise FileNotFoundError("Model file not found.") - net_params = torch.load(model_path, map_location=device) + net_params = torch.load( + model_path, + map_location=device, + weights_only=weights_only, + ) params = net_params["params"] From 394c66a2ee8c8b934217d8336b914bffc08b229a Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 11 Nov 2024 11:46:27 +0000 Subject: [PATCH 2/3] Added test to validate that changing model loading behaviour did not change model predictions --- tests/test_api.py | 6 ++---- tests/test_model.py | 47 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) create mode 100644 tests/test_model.py diff --git a/tests/test_api.py b/tests/test_api.py index d28c733..e828c9e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,14 +1,13 @@ """Test bat detect module API.""" -from pathlib import Path - import os from glob import glob +from pathlib import Path import numpy as np +import soundfile as sf import torch from torch import nn -import soundfile as sf from batdetect2 import api @@ -267,7 +266,6 @@ def test_process_file_with_spec_slices(): assert len(results["spec_slices"]) == len(detections) - def test_process_file_with_empty_predictions_does_not_fail( tmp_path: Path, ): diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..7e5d997 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,47 @@ +"""Test suite for model functions.""" + +import warnings + +import numpy as np +from hypothesis import given, settings +from hypothesis import strategies as st + +from batdetect2 import api +from batdetect2.detector import parameters + + +def test_can_import_model_without_warnings(): + with warnings.catch_warnings(): + warnings.simplefilter("error") + api.load_model() + + +@settings(deadline=None, max_examples=5) +@given(duration=st.floats(min_value=0.1, max_value=2)) +def test_can_import_model_without_pickle(duration: float): + # NOTE: remove this test once no other issues are found This is a temporary + # test to check that change in model loading did not impact model behaviour + # in any way. + + samplerate = parameters.TARGET_SAMPLERATE_HZ + audio = np.random.rand(int(duration * samplerate)) + + model_without_pickle, model_params_without_pickle = api.load_model( + weights_only=True + ) + model_with_pickle, model_params_with_pickle = api.load_model( + weights_only=False + ) + + assert model_params_without_pickle == model_params_with_pickle + + predictions_without_pickle, _, _ = api.process_audio( + audio, + model=model_without_pickle, + ) + predictions_with_pickle, _, _ = api.process_audio( + audio, + model=model_with_pickle, + ) + + assert predictions_without_pickle == predictions_with_pickle From 3477d7b5b4dbcc2a1d9b71f0fe95fba4c573bfad Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 11 Nov 2024 11:57:46 +0000 Subject: [PATCH 3/3] Run the same test with example data instead of random audio --- tests/conftest.py | 23 +++++++++++++++++++++++ tests/test_model.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 06f9ddc..fbebc98 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,31 @@ from pathlib import Path +from typing import List import pytest +@pytest.fixture +def example_data_dir() -> Path: + pkg_dir = Path(__file__).parent.parent + example_data_dir = pkg_dir / "example_data" + assert example_data_dir.exists() + return example_data_dir + + +@pytest.fixture +def example_audio_dir(example_data_dir: Path) -> Path: + example_audio_dir = example_data_dir / "audio" + assert example_audio_dir.exists() + return example_audio_dir + + +@pytest.fixture +def example_audio_files(example_audio_dir: Path) -> List[Path]: + audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]")) + assert len(audio_files) == 3 + return audio_files + + @pytest.fixture def data_dir() -> Path: dir = Path(__file__).parent / "data" diff --git a/tests/test_model.py b/tests/test_model.py index 7e5d997..3519c38 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,6 +1,8 @@ """Test suite for model functions.""" import warnings +from pathlib import Path +from typing import List import numpy as np from hypothesis import given, settings @@ -45,3 +47,32 @@ def test_can_import_model_without_pickle(duration: float): ) assert predictions_without_pickle == predictions_with_pickle + + +def test_can_import_model_without_pickle_on_test_data( + example_audio_files: List[Path], +): + # NOTE: remove this test once no other issues are found This is a temporary + # test to check that change in model loading did not impact model behaviour + # in any way. + + model_without_pickle, model_params_without_pickle = api.load_model( + weights_only=True + ) + model_with_pickle, model_params_with_pickle = api.load_model( + weights_only=False + ) + + assert model_params_without_pickle == model_params_with_pickle + + for audio_file in example_audio_files: + audio = api.load_audio(str(audio_file)) + predictions_without_pickle, _, _ = api.process_audio( + audio, + model=model_without_pickle, + ) + predictions_with_pickle, _, _ = api.process_audio( + audio, + model=model_with_pickle, + ) + assert predictions_without_pickle == predictions_with_pickle