From a1f0c74bcb3af96e02f7d0816df7f750e177c872 Mon Sep 17 00:00:00 2001 From: pauladkisson Date: Mon, 30 Sep 2024 10:11:05 -0700 Subject: [PATCH] refactored imagingextractor to use baseimagingextractormixin --- tests/mixins/base_extractor_mixin.py | 71 ++++++++++++++++++++++ tests/mixins/imaging_extractor_mixin.py | 80 +++++-------------------- 2 files changed, 86 insertions(+), 65 deletions(-) create mode 100644 tests/mixins/base_extractor_mixin.py diff --git a/tests/mixins/base_extractor_mixin.py b/tests/mixins/base_extractor_mixin.py new file mode 100644 index 00000000..9c988214 --- /dev/null +++ b/tests/mixins/base_extractor_mixin.py @@ -0,0 +1,71 @@ +import pytest +import numpy as np + + +class BaseExtractorMixin: + def test_get_image_size(self, extractor, expected_image_size): + image_size = extractor.get_image_size() + assert image_size == expected_image_size + + def test_get_num_frames(self, extractor, expected_num_frames): + num_frames = extractor.get_num_frames() + assert num_frames == expected_num_frames + + def test_get_sampling_frequency(self, extractor, expected_sampling_frequency): + sampling_frequency = extractor.get_sampling_frequency() + assert sampling_frequency == expected_sampling_frequency + + @pytest.mark.parametrize("sampling_frequency", [1, 2, 3]) + def test_frame_to_time_no_times(self, extractor, sampling_frequency): + extractor._times = None + extractor._sampling_frequency = sampling_frequency + times = extractor.frame_to_time(frames=[0, 1]) + expected_times = np.array([0, 1]) / sampling_frequency + assert np.array_equal(times, expected_times) + + def test_frame_to_time_with_times(self, extractor): + expected_times = np.array([0, 1]) + extractor._times = expected_times + times = extractor.frame_to_time(frames=[0, 1]) + + assert np.array_equal(times, expected_times) + + @pytest.mark.parametrize("sampling_frequency", [1, 2, 3]) + def test_time_to_frame_no_times(self, extractor, sampling_frequency): + extractor._times = None + extractor._sampling_frequency = sampling_frequency + times = np.array([0, 1]) / sampling_frequency + frames = extractor.time_to_frame(times=times) + expected_frames = np.array([0, 1]) + assert np.array_equal(frames, expected_frames) + + def test_time_to_frame_with_times(self, extractor): + extractor._times = np.array([0, 1]) + times = np.array([0, 1]) + frames = extractor.time_to_frame(times=times) + expected_frames = np.array([0, 1]) + assert np.array_equal(frames, expected_frames) + + def test_set_times(self, extractor): + times = np.arange(extractor.get_num_frames()) + extractor.set_times(times) + assert np.array_equal(extractor._times, times) + + def test_set_times_invalid_length(self, extractor): + with pytest.raises(AssertionError): + extractor.set_times(np.arange(extractor.get_num_frames() + 1)) + + @pytest.mark.parametrize("times", [None, np.array([0, 1])]) + def test_has_time_vector(self, times, extractor): + extractor._times = times + if times is None: + assert not extractor.has_time_vector() + else: + assert extractor.has_time_vector() + + def test_copy_times(self, extractor, extractor2): + expected_times = np.arange(extractor.get_num_frames()) + extractor._times = expected_times + extractor2.copy_times(extractor) + assert np.array_equal(extractor2._times, expected_times) + assert extractor2._times is not expected_times diff --git a/tests/mixins/imaging_extractor_mixin.py b/tests/mixins/imaging_extractor_mixin.py index fdcce09b..aa2c2c59 100644 --- a/tests/mixins/imaging_extractor_mixin.py +++ b/tests/mixins/imaging_extractor_mixin.py @@ -1,19 +1,24 @@ import pytest import numpy as np +from .base_extractor_mixin import BaseExtractorMixin -class ImagingExtractorMixin: - def test_get_image_size(self, imaging_extractor, expected_video): - image_size = imaging_extractor.get_image_size() - assert image_size == (expected_video.shape[1], expected_video.shape[2]) +class ImagingExtractorMixin(BaseExtractorMixin): + @pytest.fixture(scope="function") + def extractor(self, imaging_extractor): + return imaging_extractor - def test_get_num_frames(self, imaging_extractor, expected_video): - num_frames = imaging_extractor.get_num_frames() - assert num_frames == expected_video.shape[0] + @pytest.fixture(scope="function") + def extractor2(self, imaging_extractor2): + return imaging_extractor2 - def test_get_sampling_frequency(self, imaging_extractor, expected_sampling_frequency): - sampling_frequency = imaging_extractor.get_sampling_frequency() - assert sampling_frequency == expected_sampling_frequency + @pytest.fixture(scope="function") + def expected_image_size(self, expected_video): + return expected_video.shape[1], expected_video.shape[2] + + @pytest.fixture(scope="function") + def expected_num_frames(self, expected_video): + return expected_video.shape[0] def test_get_dtype(self, imaging_extractor, expected_video): dtype = imaging_extractor.get_dtype() @@ -60,61 +65,6 @@ def test_get_frames_invalid_frame_idxs(self, imaging_extractor): with pytest.raises(AssertionError): imaging_extractor.get_frames(frame_idxs=[0.5]) - @pytest.mark.parametrize("sampling_frequency", [1, 2, 3]) - def test_frame_to_time_no_times(self, imaging_extractor, sampling_frequency): - imaging_extractor._times = None - imaging_extractor._sampling_frequency = sampling_frequency - times = imaging_extractor.frame_to_time(frames=[0, 1]) - expected_times = np.array([0, 1]) / sampling_frequency - assert np.array_equal(times, expected_times) - - def test_frame_to_time_with_times(self, imaging_extractor): - expected_times = np.array([0, 1]) - imaging_extractor._times = expected_times - times = imaging_extractor.frame_to_time(frames=[0, 1]) - - assert np.array_equal(times, expected_times) - - @pytest.mark.parametrize("sampling_frequency", [1, 2, 3]) - def test_time_to_frame_no_times(self, imaging_extractor, sampling_frequency): - imaging_extractor._times = None - imaging_extractor._sampling_frequency = sampling_frequency - times = np.array([0, 1]) / sampling_frequency - frames = imaging_extractor.time_to_frame(times=times) - expected_frames = np.array([0, 1]) - assert np.array_equal(frames, expected_frames) - - def test_time_to_frame_with_times(self, imaging_extractor): - imaging_extractor._times = np.array([0, 1]) - times = np.array([0, 1]) - frames = imaging_extractor.time_to_frame(times=times) - expected_frames = np.array([0, 1]) - assert np.array_equal(frames, expected_frames) - - def test_set_times(self, imaging_extractor): - times = np.arange(imaging_extractor.get_num_frames()) - imaging_extractor.set_times(times) - assert np.array_equal(imaging_extractor._times, times) - - def test_set_times_invalid_length(self, imaging_extractor): - with pytest.raises(AssertionError): - imaging_extractor.set_times(np.arange(imaging_extractor.get_num_frames() + 1)) - - @pytest.mark.parametrize("times", [None, np.array([0, 1])]) - def test_has_time_vector(self, times, imaging_extractor): - imaging_extractor._times = times - if times is None: - assert not imaging_extractor.has_time_vector() - else: - assert imaging_extractor.has_time_vector() - - def test_copy_times(self, imaging_extractor, imaging_extractor2): - expected_times = np.arange(imaging_extractor.get_num_frames()) - imaging_extractor._times = expected_times - imaging_extractor2.copy_times(imaging_extractor) - assert np.array_equal(imaging_extractor2._times, expected_times) - assert imaging_extractor2._times is not expected_times - def test_eq(self, imaging_extractor, imaging_extractor2): assert imaging_extractor == imaging_extractor2