Skip to content

Commit

Permalink
refactored imagingextractor to use baseimagingextractormixin
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson committed Sep 30, 2024
1 parent 53baa5d commit a1f0c74
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 65 deletions.
71 changes: 71 additions & 0 deletions tests/mixins/base_extractor_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest
import numpy as np


class BaseExtractorMixin:
def test_get_image_size(self, extractor, expected_image_size):
image_size = extractor.get_image_size()
assert image_size == expected_image_size

def test_get_num_frames(self, extractor, expected_num_frames):
num_frames = extractor.get_num_frames()
assert num_frames == expected_num_frames

def test_get_sampling_frequency(self, extractor, expected_sampling_frequency):
sampling_frequency = extractor.get_sampling_frequency()
assert sampling_frequency == expected_sampling_frequency

@pytest.mark.parametrize("sampling_frequency", [1, 2, 3])
def test_frame_to_time_no_times(self, extractor, sampling_frequency):
extractor._times = None
extractor._sampling_frequency = sampling_frequency
times = extractor.frame_to_time(frames=[0, 1])
expected_times = np.array([0, 1]) / sampling_frequency
assert np.array_equal(times, expected_times)

def test_frame_to_time_with_times(self, extractor):
expected_times = np.array([0, 1])
extractor._times = expected_times
times = extractor.frame_to_time(frames=[0, 1])

assert np.array_equal(times, expected_times)

@pytest.mark.parametrize("sampling_frequency", [1, 2, 3])
def test_time_to_frame_no_times(self, extractor, sampling_frequency):
extractor._times = None
extractor._sampling_frequency = sampling_frequency
times = np.array([0, 1]) / sampling_frequency
frames = extractor.time_to_frame(times=times)
expected_frames = np.array([0, 1])
assert np.array_equal(frames, expected_frames)

def test_time_to_frame_with_times(self, extractor):
extractor._times = np.array([0, 1])
times = np.array([0, 1])
frames = extractor.time_to_frame(times=times)
expected_frames = np.array([0, 1])
assert np.array_equal(frames, expected_frames)

def test_set_times(self, extractor):
times = np.arange(extractor.get_num_frames())
extractor.set_times(times)
assert np.array_equal(extractor._times, times)

def test_set_times_invalid_length(self, extractor):
with pytest.raises(AssertionError):
extractor.set_times(np.arange(extractor.get_num_frames() + 1))

@pytest.mark.parametrize("times", [None, np.array([0, 1])])
def test_has_time_vector(self, times, extractor):
extractor._times = times
if times is None:
assert not extractor.has_time_vector()
else:
assert extractor.has_time_vector()

def test_copy_times(self, extractor, extractor2):
expected_times = np.arange(extractor.get_num_frames())
extractor._times = expected_times
extractor2.copy_times(extractor)
assert np.array_equal(extractor2._times, expected_times)
assert extractor2._times is not expected_times
80 changes: 15 additions & 65 deletions tests/mixins/imaging_extractor_mixin.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
import pytest
import numpy as np
from .base_extractor_mixin import BaseExtractorMixin


class ImagingExtractorMixin:
def test_get_image_size(self, imaging_extractor, expected_video):
image_size = imaging_extractor.get_image_size()
assert image_size == (expected_video.shape[1], expected_video.shape[2])
class ImagingExtractorMixin(BaseExtractorMixin):
@pytest.fixture(scope="function")
def extractor(self, imaging_extractor):
return imaging_extractor

def test_get_num_frames(self, imaging_extractor, expected_video):
num_frames = imaging_extractor.get_num_frames()
assert num_frames == expected_video.shape[0]
@pytest.fixture(scope="function")
def extractor2(self, imaging_extractor2):
return imaging_extractor2

def test_get_sampling_frequency(self, imaging_extractor, expected_sampling_frequency):
sampling_frequency = imaging_extractor.get_sampling_frequency()
assert sampling_frequency == expected_sampling_frequency
@pytest.fixture(scope="function")
def expected_image_size(self, expected_video):
return expected_video.shape[1], expected_video.shape[2]

@pytest.fixture(scope="function")
def expected_num_frames(self, expected_video):
return expected_video.shape[0]

def test_get_dtype(self, imaging_extractor, expected_video):
dtype = imaging_extractor.get_dtype()
Expand Down Expand Up @@ -60,61 +65,6 @@ def test_get_frames_invalid_frame_idxs(self, imaging_extractor):
with pytest.raises(AssertionError):
imaging_extractor.get_frames(frame_idxs=[0.5])

@pytest.mark.parametrize("sampling_frequency", [1, 2, 3])
def test_frame_to_time_no_times(self, imaging_extractor, sampling_frequency):
imaging_extractor._times = None
imaging_extractor._sampling_frequency = sampling_frequency
times = imaging_extractor.frame_to_time(frames=[0, 1])
expected_times = np.array([0, 1]) / sampling_frequency
assert np.array_equal(times, expected_times)

def test_frame_to_time_with_times(self, imaging_extractor):
expected_times = np.array([0, 1])
imaging_extractor._times = expected_times
times = imaging_extractor.frame_to_time(frames=[0, 1])

assert np.array_equal(times, expected_times)

@pytest.mark.parametrize("sampling_frequency", [1, 2, 3])
def test_time_to_frame_no_times(self, imaging_extractor, sampling_frequency):
imaging_extractor._times = None
imaging_extractor._sampling_frequency = sampling_frequency
times = np.array([0, 1]) / sampling_frequency
frames = imaging_extractor.time_to_frame(times=times)
expected_frames = np.array([0, 1])
assert np.array_equal(frames, expected_frames)

def test_time_to_frame_with_times(self, imaging_extractor):
imaging_extractor._times = np.array([0, 1])
times = np.array([0, 1])
frames = imaging_extractor.time_to_frame(times=times)
expected_frames = np.array([0, 1])
assert np.array_equal(frames, expected_frames)

def test_set_times(self, imaging_extractor):
times = np.arange(imaging_extractor.get_num_frames())
imaging_extractor.set_times(times)
assert np.array_equal(imaging_extractor._times, times)

def test_set_times_invalid_length(self, imaging_extractor):
with pytest.raises(AssertionError):
imaging_extractor.set_times(np.arange(imaging_extractor.get_num_frames() + 1))

@pytest.mark.parametrize("times", [None, np.array([0, 1])])
def test_has_time_vector(self, times, imaging_extractor):
imaging_extractor._times = times
if times is None:
assert not imaging_extractor.has_time_vector()
else:
assert imaging_extractor.has_time_vector()

def test_copy_times(self, imaging_extractor, imaging_extractor2):
expected_times = np.arange(imaging_extractor.get_num_frames())
imaging_extractor._times = expected_times
imaging_extractor2.copy_times(imaging_extractor)
assert np.array_equal(imaging_extractor2._times, expected_times)
assert imaging_extractor2._times is not expected_times

def test_eq(self, imaging_extractor, imaging_extractor2):
assert imaging_extractor == imaging_extractor2

Expand Down

0 comments on commit a1f0c74

Please sign in to comment.