Skip to content

Commit

Permalink
refactored segmentation_extractor to use baseextractormixin
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson committed Sep 30, 2024
1 parent a1f0c74 commit e6d2c2c
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 55 deletions.
44 changes: 0 additions & 44 deletions src/roiextractors/segmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,50 +272,6 @@ def get_summary_images(self, names: Optional[list[str]] = None) -> dict:
"""
pass

# TODO: Refactor _times methods from ImagingExtractor and SegmentationExtractor into a BaseExtractor class
def set_times(self, times: ArrayType):
"""Set the recording times in seconds for each frame.
Parameters
----------
times: array-like
The times in seconds for each frame
Notes
-----
Operates on _times attribute of the SegmentationExtractor object.
"""
assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!"
self._times = np.array(times, dtype=np.float64)

def has_time_vector(self) -> bool:
"""Detect if the SegmentationExtractor has a time vector set or not.
Returns
-------
has_time_vector: bool
True if the SegmentationExtractor has a time vector set, otherwise False.
"""
return self._times is not None

def frame_to_time(self, frames: Union[IntType, ArrayType]) -> Union[FloatType, ArrayType]:
"""Get the timing of frames in unit of seconds.
Parameters
----------
frames: int or array-like
The frame or frames to be converted to times
Returns
-------
times: float or array-like
The corresponding times in seconds
"""
if self._times is None:
return frames / self.get_sampling_frequency()
else:
return self._times[frames]

def frame_slice(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None):
"""Return a new ImagingExtractor ranging from the start_frame to the end_frame.
Expand Down
27 changes: 16 additions & 11 deletions tests/mixins/segmentation_extractor_mixin.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import pytest
import numpy as np

from .base_extractor_mixin import BaseExtractorMixin

class SegmentationExtractorMixin:
def test_get_image_size(self, segmentation_extractor, expected_image_masks):
image_size = segmentation_extractor.get_image_size()
assert image_size == (expected_image_masks.shape[0], expected_image_masks.shape[1])

def test_get_num_frames(self, segmentation_extractor, expected_roi_response_traces):
num_frames = segmentation_extractor.get_num_frames()
first_expected_roi_response_trace = list(expected_roi_response_traces.values())[0]
assert num_frames == first_expected_roi_response_trace.shape[0]
class SegmentationExtractorMixin(BaseExtractorMixin):
@pytest.fixture(scope="function")
def extractor(self, segmentation_extractor):
return segmentation_extractor

def test_get_sampling_frequency(self, segmentation_extractor, expected_sampling_frequency):
sampling_frequency = segmentation_extractor.get_sampling_frequency()
assert sampling_frequency == expected_sampling_frequency
@pytest.fixture(scope="function")
def extractor2(self, segmentation_extractor2):
return segmentation_extractor2

@pytest.fixture(scope="function")
def expected_image_size(self, expected_image_masks):
return expected_image_masks.shape[:2]

@pytest.fixture(scope="function")
def expected_num_frames(self, expected_roi_response_traces):
return list(expected_roi_response_traces.values())[0].shape[0]

def test_get_roi_ids(self, segmentation_extractor, expected_roi_ids):
roi_ids = segmentation_extractor.get_roi_ids()
Expand Down
79 changes: 79 additions & 0 deletions tests/test_minimal/test_numpy_segmentation_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,35 @@ def segmentation_extractor(
background_response_traces=expected_background_response_traces,
)

@pytest.fixture(scope="function")
def segmentation_extractor2(
self,
expected_image_masks,
expected_roi_response_traces,
expected_summary_images,
expected_roi_ids,
expected_roi_locations,
expected_accepted_list,
expected_rejected_list,
expected_background_response_traces,
expected_background_ids,
expected_background_image_masks,
expected_sampling_frequency,
):
return NumpySegmentationExtractor(
image_masks=expected_image_masks,
roi_response_traces=expected_roi_response_traces,
summary_images=expected_summary_images,
roi_ids=expected_roi_ids,
roi_locations=expected_roi_locations,
accepted_roi_ids=expected_accepted_list,
rejected_roi_ids=expected_rejected_list,
sampling_frequency=expected_sampling_frequency,
background_ids=expected_background_ids,
background_image_masks=expected_background_image_masks,
background_response_traces=expected_background_response_traces,
)


class TestNumpySegmentationExtractorFromFile(SegmentationExtractorMixin, FrameSliceSegmentationExtractorMixin):
@pytest.fixture(scope="function")
Expand Down Expand Up @@ -180,3 +209,53 @@ def segmentation_extractor(
sampling_frequency=expected_sampling_frequency,
background_ids=expected_background_ids,
)

@pytest.fixture(scope="function")
def segmentation_extractor2(
self,
expected_image_masks,
expected_roi_response_traces,
expected_summary_images,
expected_roi_ids,
expected_roi_locations,
expected_accepted_list,
expected_rejected_list,
expected_background_response_traces,
expected_background_ids,
expected_background_image_masks,
expected_sampling_frequency,
tmp_path,
):
name_to_ndarray = dict(
image_masks=expected_image_masks,
background_image_masks=expected_background_image_masks,
)
name_to_file_path = {}
for name, ndarray in name_to_ndarray.items():
file_path = tmp_path / f"{name}.npy"
file_path.parent.mkdir(parents=True, exist_ok=True)
np.save(file_path, ndarray)
name_to_file_path[name] = file_path
name_to_dict_of_ndarrays = dict(
roi_response_traces=expected_roi_response_traces,
background_response_traces=expected_background_response_traces,
summary_images=expected_summary_images,
)
name_to_dict_of_file_paths = {}
for name, dict_of_ndarrays in name_to_dict_of_ndarrays.items():
name_to_dict_of_file_paths[name] = {}
for key, ndarray in dict_of_ndarrays.items():
file_path = tmp_path / f"{name}_{key}.npy"
np.save(file_path, ndarray)
name_to_dict_of_file_paths[name][key] = file_path

return NumpySegmentationExtractor(
**name_to_file_path,
**name_to_dict_of_file_paths,
roi_ids=expected_roi_ids,
roi_locations=expected_roi_locations,
accepted_roi_ids=expected_accepted_list,
rejected_roi_ids=expected_rejected_list,
sampling_frequency=expected_sampling_frequency,
background_ids=expected_background_ids,
)

0 comments on commit e6d2c2c

Please sign in to comment.