diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py index 7c21b96e40..948ac9f224 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py @@ -181,8 +181,6 @@ def get_traces(self, start_frame, end_frame, channel_indices): batch_size=batch_size, desired_shape=self.desired_shape, ) - input_generator.randomize = False - input_generator._calculate_list_samples(input_generator.total_samples) di_output = self.model.predict(input_generator, workers=self.predict_workers, verbose=2) out_traces = input_generator.reshape_output(di_output) diff --git a/src/spikeinterface/preprocessing/deepinterpolation/generators.py b/src/spikeinterface/preprocessing/deepinterpolation/generators.py index a1625fe6cf..8200340ac1 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/generators.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/generators.py @@ -1,9 +1,10 @@ +from __future__ import annotations import tempfile import json -from typing import Any +from typing import Optional import numpy as np -from ...core import load_extractor +from ...core import load_extractor, concatenate_recordings, BaseRecording, BaseRecordingSegment from .tf_utils import has_tf, import_tf @@ -13,27 +14,57 @@ class SpikeInterfaceRecordingGenerator(SequentialGenerator): - """This generator is used when dealing with a SpikeInterface recording. - The desired shape controls the reshaping of the input data before convolutions.""" + """ + This generator is used when dealing with a SpikeInterface recording. + The desired shape controls the reshaping of the input data before convolutions. + """ def __init__( self, - recording, - pre_frame=30, - post_frame=30, - pre_post_omission=1, - desired_shape=(192, 2), - batch_size=100, - steps_per_epoch=10, - start_frame=None, - end_frame=None, - total_samples=-1, + recordings: BaseRecording | list[BaseRecording], + pre_frame: int = 30, + post_frame: int = 30, + pre_post_omission: int = 1, + desired_shape: tuple = (192, 2), + batch_size: int = 100, + steps_per_epoch: int = 10, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + total_samples: int = -1, ): - assert recording.get_num_segments() == 1, "Only supported for mono-segment recordings" - - self.recording = recording - num_samples = recording.get_num_samples() - self.total_samples = num_samples if total_samples == -1 else total_samples + if not isinstance(recordings, list): + recordings = [recordings] + self.recordings = recordings + if len(recordings) > 1: + assert ( + r.get_num_channels() == recordings[0].get_num_channels() for r in recordings[1:] + ), "All recordings must have the same number of channels" + + total_num_samples = np.sum(r.get_total_samples() for r in recordings) + # In case of multiple recordings and/or multiple segments, we calculate the frame periods to be excluded (borders) + exclude_intervals = [] + pre_extended = pre_frame + pre_post_omission + post_extended = post_frame + pre_post_omission + for i, recording in enumerate(recordings): + total_samples_pre = np.sum(r.get_total_samples() for r in recordings[:i]) + for segment_index in range(recording.get_num_segments()): + # exclude samples at the border of the recordings + num_samples_segment_pre = np.sum(recording.get_num_samples(s) for s in np.arange(segment_index)) + if num_samples_segment_pre > 0: + exclude_intervals.append( + ( + total_samples_pre + num_samples_segment_pre - pre_extended - 1, + total_samples_pre + num_samples_segment_pre + post_extended, + ) + ) + # exclude samples at the border of the recordings + if total_samples_pre > 0: + exclude_intervals.append((total_samples_pre - pre_extended - 1, total_samples_pre + post_extended)) + + total_valid_samples = ( + total_num_samples - np.sum([end - start for start, end in exclude_intervals]) - pre_extended - post_extended + ) + self.total_samples = int(total_valid_samples) if total_samples == -1 else total_samples assert len(desired_shape) == 2, "desired_shape should be 2D" assert ( desired_shape[0] * desired_shape[1] == recording.get_num_channels() @@ -41,8 +72,7 @@ def __init__( self.desired_shape = desired_shape start_frame = start_frame if start_frame is not None else 0 - end_frame = end_frame if end_frame is not None else recording.get_num_samples() - + end_frame = end_frame if end_frame is not None else total_num_samples assert end_frame > start_frame, "end_frame must be greater than start_frame" sequential_generator_params = dict() @@ -60,12 +90,15 @@ def __init__( json.dump(sequential_generator_params, f) super().__init__(json_path) - self._update_end_frame(num_samples) - self._calculate_list_samples(num_samples) - self.last_batch_size = np.mod(self.end_frame - self.start_frame, self.batch_size) + self._update_end_frame(total_num_samples) + + # self.list_samples will exclude the border intervals on the concat recording + self.recording_concat = concatenate_recordings(recordings) + self.exclude_intervals = exclude_intervals + self._calculate_list_samples(total_num_samples, exclude_intervals=exclude_intervals) self._kwargs = dict( - recording=recording, + recordings=recordings, pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=pre_post_omission, @@ -76,6 +109,45 @@ def __init__( end_frame=end_frame, ) + # this is overridden to exclude samples from borders + def _calculate_list_samples(self, total_frame_per_movie, exclude_intervals=[]): + # We first cut if start and end frames are too close to the edges. + self.start_sample = np.max([self.pre_frame + self.pre_post_omission, self.start_frame]) + self.end_sample = np.min( + [ + self.end_frame, + total_frame_per_movie - 1 - self.post_frame - self.pre_post_omission, + ] + ) + + if (self.end_sample - self.start_sample + 1) < self.batch_size: + raise Exception( + "Not enough frames to construct one " + + str(self.batch_size) + + " frame(s) batch between " + + str(self.start_sample) + + " and " + + str(self.end_sample) + + " frame number." + ) + + # +1 to make sure end_samples is included + list_samples_all = np.arange(self.start_sample, self.end_sample + 1) + + if len(exclude_intervals) > 0: + for start, end in exclude_intervals: + list_samples_all = list_samples_all[(list_samples_all <= start) | (list_samples_all >= end)] + self.list_samples = list_samples_all + else: + self.list_samples = list_samples_all + + if self.randomize: + np.random.shuffle(self.list_samples) + + # We cut the number of samples if asked to + if self.total_samples > 0 and self.total_samples < len(self.list_samples): + self.list_samples = self.list_samples[0 : self.total_samples] + def __getitem__(self, index): # This is to ensure we are going through # the entire data when steps_per_epoch start_frame, "end_frame must be greater than start_frame" + self.total_samples = end_frame - start_frame sequential_generator_params = dict() sequential_generator_params["steps_per_epoch"] = steps_per_epoch @@ -178,30 +248,10 @@ def __init__( json.dump(sequential_generator_params, f) super().__init__(json_path) - self._update_end_frame(self.total_samples) - self._calculate_list_samples(self.total_samples) - self.last_batch_size = np.mod(self.end_frame - self.start_frame, self.batch_size) - - def __len__(self): - "Denotes the total number of batches" - if self.last_batch_size == 0: - return int(len(self.list_samples) // self.batch_size) - else: - self.has_partial_batch = True - return int(len(self.list_samples) // self.batch_size) + 1 - - def generate_batch_indexes(self, index): - # This is to ensure we are going through - # the entire data when steps_per_epoch 0: - index = index + self.steps_per_epoch * self.epoch_index - # Generate indexes of the batch - indexes = slice(index * self.batch_size, (index + 1) * self.batch_size) - - if index == len(self) - 1 and self.last_batch_size > 0: - indexes = slice(-self.last_batch_size, len(self.list_samples)) - shuffle_indexes = self.list_samples[indexes] - return shuffle_indexes + self._update_end_frame(num_segment_samples) + # IMPORTANT: this is used for inference, so we don't want to shuffle + self.randomize = False + self._calculate_list_samples(num_segment_samples) def __getitem__(self, index): # This is to ensure we are going through diff --git a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py index 7fdbe13753..9624cffa9d 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py @@ -3,17 +3,20 @@ from pathlib import Path import probeinterface as pi -from spikeinterface import download_dataset, generate_recording +from spikeinterface import download_dataset, generate_recording, append_recordings, concatenate_recordings from spikeinterface.extractors import read_mearec, read_spikeglx, read_openephys from spikeinterface.preprocessing import depth_order, zscore from spikeinterface.preprocessing.deepinterpolation import train_deepinterpolation, deepinterpolate +from spikeinterface.preprocessing.deepinterpolation import train_deepinterpolation, deepinterpolate if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" + cache_folder = pytest.global_test_folder / "deepinterpolation" else: - cache_folder = Path("cache_folder") / "preprocessing" + cache_folder = Path("cache_folder") / "deepinterpolation" + if not cache_folder.is_dir(): + cache_folder.mkdir(parents=True) def recording_and_shape(): @@ -34,6 +37,27 @@ def recording_and_shape_fixture(): return recording_and_shape() +def test_deepinterpolation_generator_borders(recording_and_shape_fixture): + """Test that the generator avoids borders in multi-segment and recording lists cases""" + from spikeinterface.preprocessing.deepinterpolation.generators import SpikeInterfaceRecordingGenerator + + recording, desired_shape = recording_and_shape_fixture + recording_multi_segment = append_recordings([recording, recording, recording]) + recording_list = [recording, recording, recording] + recording_multi_list = [recording_multi_segment, recording_multi_segment, recording_multi_segment] + + gen_multi_segment = SpikeInterfaceRecordingGenerator(recording_multi_segment, desired_shape=desired_shape) + gen_list = SpikeInterfaceRecordingGenerator(recording_list, desired_shape=desired_shape) + gen_multi_list = SpikeInterfaceRecordingGenerator(recording_multi_list, desired_shape=desired_shape) + + # exclude in between segments + assert len(gen_multi_segment.exclude_intervals) == 2 + # exclude in between recordings + assert len(gen_list.exclude_intervals) == 2 + # exclude in between recordings and segments + assert len(gen_multi_list.exclude_intervals) == 2 * len(recording_multi_list) + 2 + + def test_deepinterpolation_training(recording_and_shape_fixture): recording, desired_shape = recording_and_shape_fixture @@ -45,10 +69,10 @@ def test_deepinterpolation_training(recording_and_shape_fixture): model_name="training", train_start_s=1, train_end_s=10, - train_duration_s=0.1, + train_duration_s=0.02, test_start_s=0, test_end_s=1, - test_duration_s=0.05, + test_duration_s=0.01, pre_frame=20, post_frame=20, run_uid="si_test", @@ -73,10 +97,10 @@ def test_deepinterpolation_transfer(recording_and_shape_fixture, tmp_path): existing_model_path=existing_model_path, train_start_s=1, train_end_s=10, - train_duration_s=0.1, + train_duration_s=0.02, test_start_s=0, test_end_s=1, - test_duration_s=0.05, + test_duration_s=0.01, pre_frame=20, post_frame=20, pre_post_omission=1, @@ -94,6 +118,7 @@ def test_deepinterpolation_inference(recording_and_shape_fixture): recording_di = deepinterpolate( recording, model_path=existing_model_path, pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=1 ) + print(recording_di) traces_original_first = recording.get_traces(start_frame=0, end_frame=100) traces_di_first = recording_di.get_traces(start_frame=0, end_frame=100) assert traces_di_first.shape == (100, recording.get_num_channels()) @@ -123,9 +148,19 @@ def test_deepinterpolation_inference_multi_job(recording_and_shape_fixture): pre_post_omission=1, use_gpu=False, ) + print(recording_di) recording_di_slice = recording_di.frame_slice(start_frame=0, end_frame=int(0.5 * recording.sampling_frequency)) recording_di_slice.save(folder=Path(cache_folder) / "di_slice", n_jobs=2, chunk_duration="50ms") traces_chunks = recording_di_slice.get_traces() traces_full = recording_di_slice.get_traces() np.testing.assert_array_equal(traces_chunks, traces_full) + + +if __name__ == "__main__": + recording_shape = recording_and_shape() + # test_deepinterpolation_training(recording_shape) + # test_deepinterpolation_transfer() + test_deepinterpolation_inference(recording_shape) + # test_deepinterpolation_inference_multi_job() + # test_deepinterpolation_generator_borders(recording_shape) diff --git a/src/spikeinterface/preprocessing/deepinterpolation/train.py b/src/spikeinterface/preprocessing/deepinterpolation/train.py index 5865153393..06813d10e2 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/train.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/train.py @@ -1,6 +1,7 @@ from __future__ import annotations import json import os +import warnings from pathlib import Path from typing import Optional @@ -18,16 +19,17 @@ def train_deepinterpolation( - recording: BaseRecording, + recordings: BaseRecording | list[BaseRecording], model_folder: str | Path, model_name: str, - train_start_s: float, - train_end_s: float, - test_start_s: float, - test_end_s: float, desired_shape: tuple[int, int], + train_start_s: Optional[float] = None, + train_end_s: Optional[float] = None, train_duration_s: Optional[float] = None, + test_start_s: Optional[float] = None, + test_end_s: Optional[float] = None, test_duration_s: Optional[float] = None, + test_recordings: Optional[BaseRecording | list[BaseRecording]] = None, pre_frame: int = 30, post_frame: int = 30, pre_post_omission: int = 1, @@ -53,24 +55,27 @@ def train_deepinterpolation( Parameters ---------- - recording : RecordingExtractor - The recording extractor to be deepinteprolated + recordings : BaseRecording | list[BaseRecording] + The recording(s) to be deepinteprolated. If a list is given, the recordings are concatenated + and samples at the border of the recordings are omitted. model_folder : str | Path Path to the folder where the model will be saved model_name : str Name of the model to be used - train_start_s : float - Start time of the training in seconds - train_end_s : float - End time of the training in seconds - train_duration_s : float - Duration of the training in seconds - test_start_s : float - Start time of the testing in seconds - test_end_s : float - End time of the testing in seconds - test_duration_s : float - Duration of the testing in seconds + train_start_s : float or None, default: None + Start time of the training in seconds. If None, the training starts at the beginning of the recording + train_end_s : float or None, default: None + End time of the training in seconds. If None, the training ends at the end of the recording + train_duration_s : float, default: None + Duration of the training in seconds. If None, the entire [train_start_s, train_end_s] is used for training + test_start_s : float or None, default: None + Start time of the testing in seconds. If None, the testing starts at the beginning of the recording + test_end_s : float or None, default: None + End time of the testing in seconds. If None, the testing ends at the end of the recording + test_duration_s : float or None, default: None + Duration of the testing in seconds, If None, the entire [test_start_s, test_end_s] is used for testing (not recommended) + test_recordings : BaseRecording | list[BaseRecording] | None, default: None + The recording(s) used for testing. If None, the training recording (or recordings) is used for testing desired_shape : tuple Shape of the input to the network pre_frame : int @@ -120,19 +125,20 @@ def train_deepinterpolation( nb_workers = os.cpu_count() args = ( - recording, + recordings, model_folder, model_name, + desired_shape, train_start_s, train_end_s, train_duration_s, test_start_s, test_end_s, test_duration_s, + test_recordings, pre_frame, post_frame, pre_post_omission, - desired_shape, existing_model_path, verbose, nb_gpus, @@ -159,19 +165,20 @@ def train_deepinterpolation( def train_deepinterpolation_process( - recording: BaseRecording, + recordings: BaseRecording | list[BaseRecording], model_folder: str | Path, model_name: str, + desired_shape: tuple[int, int], train_start_s: float, train_end_s: float, train_duration_s: float | None, test_start_s: float, test_end_s: float, test_duration_s: float | None, - pre_frame: int, - post_frame: int, - pre_post_omission: int, - desired_shape: tuple[int, int], + test_recordings: Optional[BaseRecording | list[BaseRecording]] = None, + pre_frame: int = 30, + post_frame: int = 30, + pre_post_omission: int = 1, existing_model_path: Optional[str | Path] = None, verbose: bool = True, nb_gpus: int = 1, @@ -200,7 +207,9 @@ def train_deepinterpolation_process( trained_model_folder.mkdir(exist_ok=True) # check if roughly zscored - fs = recording.sampling_frequency + if not isinstance(recordings, list): + recordings = [recordings] + fs = recordings[0].sampling_frequency ### Define params start_frame_training = int(train_start_s * fs) @@ -216,6 +225,13 @@ def train_deepinterpolation_process( else: total_samples_testing = -1 + if test_recordings is None: + test_recordings = recordings + if (start_frame_training <= start_frame_test < end_frame_training) or ( + start_frame_training < end_frame_test <= end_frame_training + ): + warnings.warn("Training and testing overlap. This is not recommended.") + # Those are parameters used for the network topology network_params = dict() network_params["type"] = "network" @@ -243,7 +259,7 @@ def train_deepinterpolation_process( # Training (from core_trainor class) training_data_generator = SpikeInterfaceRecordingGenerator( - recording, + recordings, pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=pre_post_omission, @@ -253,7 +269,7 @@ def train_deepinterpolation_process( total_samples=total_samples_training, ) test_data_generator = SpikeInterfaceRecordingGenerator( - recording, + test_recordings, pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=pre_post_omission,