diff --git a/src/roiextractors/segmentationextractor.py b/src/roiextractors/segmentationextractor.py index 2989f9c3..6ae3fab1 100644 --- a/src/roiextractors/segmentationextractor.py +++ b/src/roiextractors/segmentationextractor.py @@ -272,50 +272,6 @@ def get_summary_images(self, names: Optional[list[str]] = None) -> dict: """ pass - # TODO: Refactor _times methods from ImagingExtractor and SegmentationExtractor into a BaseExtractor class - def set_times(self, times: ArrayType): - """Set the recording times in seconds for each frame. - - Parameters - ---------- - times: array-like - The times in seconds for each frame - - Notes - ----- - Operates on _times attribute of the SegmentationExtractor object. - """ - assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!" - self._times = np.array(times, dtype=np.float64) - - def has_time_vector(self) -> bool: - """Detect if the SegmentationExtractor has a time vector set or not. - - Returns - ------- - has_time_vector: bool - True if the SegmentationExtractor has a time vector set, otherwise False. - """ - return self._times is not None - - def frame_to_time(self, frames: Union[IntType, ArrayType]) -> Union[FloatType, ArrayType]: - """Get the timing of frames in unit of seconds. - - Parameters - ---------- - frames: int or array-like - The frame or frames to be converted to times - - Returns - ------- - times: float or array-like - The corresponding times in seconds - """ - if self._times is None: - return frames / self.get_sampling_frequency() - else: - return self._times[frames] - def frame_slice(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None): """Return a new ImagingExtractor ranging from the start_frame to the end_frame. diff --git a/tests/mixins/segmentation_extractor_mixin.py b/tests/mixins/segmentation_extractor_mixin.py index cedec02f..055b32a4 100644 --- a/tests/mixins/segmentation_extractor_mixin.py +++ b/tests/mixins/segmentation_extractor_mixin.py @@ -1,20 +1,25 @@ import pytest import numpy as np +from .base_extractor_mixin import BaseExtractorMixin -class SegmentationExtractorMixin: - def test_get_image_size(self, segmentation_extractor, expected_image_masks): - image_size = segmentation_extractor.get_image_size() - assert image_size == (expected_image_masks.shape[0], expected_image_masks.shape[1]) - def test_get_num_frames(self, segmentation_extractor, expected_roi_response_traces): - num_frames = segmentation_extractor.get_num_frames() - first_expected_roi_response_trace = list(expected_roi_response_traces.values())[0] - assert num_frames == first_expected_roi_response_trace.shape[0] +class SegmentationExtractorMixin(BaseExtractorMixin): + @pytest.fixture(scope="function") + def extractor(self, segmentation_extractor): + return segmentation_extractor - def test_get_sampling_frequency(self, segmentation_extractor, expected_sampling_frequency): - sampling_frequency = segmentation_extractor.get_sampling_frequency() - assert sampling_frequency == expected_sampling_frequency + @pytest.fixture(scope="function") + def extractor2(self, segmentation_extractor2): + return segmentation_extractor2 + + @pytest.fixture(scope="function") + def expected_image_size(self, expected_image_masks): + return expected_image_masks.shape[:2] + + @pytest.fixture(scope="function") + def expected_num_frames(self, expected_roi_response_traces): + return list(expected_roi_response_traces.values())[0].shape[0] def test_get_roi_ids(self, segmentation_extractor, expected_roi_ids): roi_ids = segmentation_extractor.get_roi_ids() diff --git a/tests/test_minimal/test_numpy_segmentation_extractor.py b/tests/test_minimal/test_numpy_segmentation_extractor.py index 2d9ee984..0b92e77e 100644 --- a/tests/test_minimal/test_numpy_segmentation_extractor.py +++ b/tests/test_minimal/test_numpy_segmentation_extractor.py @@ -129,6 +129,35 @@ def segmentation_extractor( background_response_traces=expected_background_response_traces, ) + @pytest.fixture(scope="function") + def segmentation_extractor2( + self, + expected_image_masks, + expected_roi_response_traces, + expected_summary_images, + expected_roi_ids, + expected_roi_locations, + expected_accepted_list, + expected_rejected_list, + expected_background_response_traces, + expected_background_ids, + expected_background_image_masks, + expected_sampling_frequency, + ): + return NumpySegmentationExtractor( + image_masks=expected_image_masks, + roi_response_traces=expected_roi_response_traces, + summary_images=expected_summary_images, + roi_ids=expected_roi_ids, + roi_locations=expected_roi_locations, + accepted_roi_ids=expected_accepted_list, + rejected_roi_ids=expected_rejected_list, + sampling_frequency=expected_sampling_frequency, + background_ids=expected_background_ids, + background_image_masks=expected_background_image_masks, + background_response_traces=expected_background_response_traces, + ) + class TestNumpySegmentationExtractorFromFile(SegmentationExtractorMixin, FrameSliceSegmentationExtractorMixin): @pytest.fixture(scope="function") @@ -180,3 +209,53 @@ def segmentation_extractor( sampling_frequency=expected_sampling_frequency, background_ids=expected_background_ids, ) + + @pytest.fixture(scope="function") + def segmentation_extractor2( + self, + expected_image_masks, + expected_roi_response_traces, + expected_summary_images, + expected_roi_ids, + expected_roi_locations, + expected_accepted_list, + expected_rejected_list, + expected_background_response_traces, + expected_background_ids, + expected_background_image_masks, + expected_sampling_frequency, + tmp_path, + ): + name_to_ndarray = dict( + image_masks=expected_image_masks, + background_image_masks=expected_background_image_masks, + ) + name_to_file_path = {} + for name, ndarray in name_to_ndarray.items(): + file_path = tmp_path / f"{name}.npy" + file_path.parent.mkdir(parents=True, exist_ok=True) + np.save(file_path, ndarray) + name_to_file_path[name] = file_path + name_to_dict_of_ndarrays = dict( + roi_response_traces=expected_roi_response_traces, + background_response_traces=expected_background_response_traces, + summary_images=expected_summary_images, + ) + name_to_dict_of_file_paths = {} + for name, dict_of_ndarrays in name_to_dict_of_ndarrays.items(): + name_to_dict_of_file_paths[name] = {} + for key, ndarray in dict_of_ndarrays.items(): + file_path = tmp_path / f"{name}_{key}.npy" + np.save(file_path, ndarray) + name_to_dict_of_file_paths[name][key] = file_path + + return NumpySegmentationExtractor( + **name_to_file_path, + **name_to_dict_of_file_paths, + roi_ids=expected_roi_ids, + roi_locations=expected_roi_locations, + accepted_roi_ids=expected_accepted_list, + rejected_roi_ids=expected_rejected_list, + sampling_frequency=expected_sampling_frequency, + background_ids=expected_background_ids, + )