diff --git a/.github/workflows/deepinterpolation.yml b/.github/workflows/deepinterpolation.yml new file mode 100644 index 0000000000..7f5e3e54f9 --- /dev/null +++ b/.github/workflows/deepinterpolation.yml @@ -0,0 +1,51 @@ +name: Testing deepinterpolation + +on: + pull_request: + types: [synchronize, opened, reopened] + branches: + - main + +concurrency: # Cancel previous workflows on the same pull request + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-and-test: + name: Test on ${{ matrix.os }} OS + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.9' + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v35 + - name: DeepInterpolation changes + id: modules-changed + run: | + for file in ${{ steps.changed-files.outputs.all_changed_files }}; do + if [[ $file == *"/deepinterpolation/"* ]]; then + echo "DeepInterpolation changed" + echo "DEEPINTERPOLATION_CHANGED=true" >> $GITHUB_OUTPUT + fi + done + - name: Install dependencies + if: ${{ steps.modules-changed.outputs.DEEPINTERPOLATION_CHANGED == 'true' }} + run: | + python -m pip install -U pip # Official recommended way + # install deepinterpolation + pip install tensorflow==2.7.0 + pip install deepinterpolation@git+https://github.com/AllenInstitute/deepinterpolation.git + pip install protobuf==3.20.* + pip install -e .[full,test_core] + - name: Test DeepInterpolation with pytest + if: ${{ steps.modules-changed.outputs.DEEPINTERPOLATION_CHANGED == 'true' }} + run: | + pytest -v src/spikeinterface/preprocessing/deepinterpolation + shell: bash # Necessary for pipeline to work on Windows diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index 8f88e84039..dad42e021b 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -136,7 +136,7 @@ jobs: run: ./.github/run_tests.sh "extractors and not streaming_extractors" - name: Test preprocessing if: ${{ steps.modules-changed.outputs.PREPROCESSING_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} - run: ./.github/run_tests.sh preprocessing + run: ./.github/run_tests.sh "preprocessing and not deepinterpolation" - name: Test postprocessing if: ${{ steps.modules-changed.outputs.POSTPROCESSING_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} run: ./.github/run_tests.sh postprocessing diff --git a/src/spikeinterface/preprocessing/deepinterpolation/__init__.py b/src/spikeinterface/preprocessing/deepinterpolation/__init__.py index 2ff166a81e..13d9c69def 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/__init__.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/__init__.py @@ -1 +1,2 @@ from .deepinterpolation import DeepInterpolatedRecording, deepinterpolate +from .train import train_deepinterpolation
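With the `train.py` export added above, the subpackage exposes both the inference wrapper and a training entry point. A minimal sketch of the resulting import surface; only the import path and names are taken from this diff, and the signature of `train_deepinterpolation` is not shown here, so the annotations are illustrative:

```python
# Sketch of the public API surface introduced by this diff (names from __init__.py).
from spikeinterface.preprocessing.deepinterpolation import (
    DeepInterpolatedRecording,  # class-based API (deepinterpolation.py)
    deepinterpolate,            # function-based API, defined at the end of deepinterpolation.py
    train_deepinterpolation,    # training entry point from the new train.py (not shown in this diff)
)
```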
diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py index f82ea8a5df..948ac9f224 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py @@ -1,6 +1,7 @@ import numpy as np -import os +from typing import Optional +from .tf_utils import has_tf, import_tf from ...core import BaseRecording from ...core.core_tools import define_function_from_class from ..basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -8,305 +9,111 @@ from spikeinterface.core import get_random_data_chunks -def import_tf(use_gpu=True, disable_tf_logger=True): - import tensorflow as tf - - if not use_gpu: - os.environ["CUDA_VISIBLE_DEVICES"] = "-1" - - if disable_tf_logger: - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - tf.get_logger().setLevel("ERROR") - - tf.compat.v1.disable_eager_execution() - gpus = tf.config.list_physical_devices("GPU") - if gpus: - try: - # Currently, memory growth needs to be the same across GPUs - for gpu in gpus: - tf.config.experimental.set_memory_growth(gpu, True) - except RuntimeError as e: - # Memory growth must be set before GPUs have been initialized - print(e) - return tf - - -def has_tf(use_gpu=True, disable_tf_logger=True): - try: - import_tf(use_gpu, disable_tf_logger) - return True - except ImportError: - return False +class DeepInterpolatedRecording(BasePreprocessor): + """ + DeepInterpolatedRecording is a wrapper around a recording extractor that allows applying a deepinterpolation model. + For more information, see: + Lecoq et al. (2021) Removing independent noise in systems neuroscience data using DeepInterpolation. + Nature Methods. 18: 1401-1408. doi: 10.1038/s41592-021-01285-2. -def define_input_generator_class(use_gpu, disable_tf_logger=True): - """Define DeepInterpolationInputGenerator class at run-time + Parts of this code are adapted from https://github.com/AllenInstitute/deepinterpolation. Parameters ---------- - use_gpu : bool - Whether to load TF with GPU capabilities - disable_tf_logger : bool, optional - If True, tensorflow logging is disabled, by default True + recording: BaseRecording + The recording extractor to be deepinterpolated + model_path: str + Path to the deepinterpolation h5 model + pre_frame: int + Number of frames before the frame to be predicted + post_frame: int + Number of frames after the frame to be predicted + pre_post_omission: int + Number of frames to be omitted before and after the frame to be predicted + batch_size: int + Batch size to be used for the prediction + predict_workers: int + Number of workers to be used for the tensorflow `predict` function. + Multiple workers can be used to speed up the prediction by pre-fetching the data.
+ use_gpu: bool + If True, the gpu will be used for the prediction + disable_tf_logger: bool + If True, the tensorflow logger will be disabled + memory_gpu: int + The amount of memory to be used by the gpu Returns ------- - class - The defined DeepInterpolationInputGenerator class + recording: DeepInterpolatedRecording + The deepinterpolated recording extractor object """ - tf = import_tf(use_gpu, disable_tf_logger) - - class DeepInterpolationInputGenerator(tf.keras.utils.Sequence): - def __init__( - self, - recording, - start_frame, - end_frame, - batch_size, - pre_frames, - post_frames, - pre_post_omission, - local_mean, - local_std, - ): - self.recording = recording - self.start_frame = start_frame - self.end_frame = end_frame - - self.batch_size = batch_size - self.last_batch_size = (end_frame - start_frame) - (self.__len__() - 1) * batch_size - - self.pre_frames = pre_frames - self.post_frames = post_frames - self.pre_post_omission = pre_post_omission - - self.local_mean = local_mean - self.local_std = local_std - - def __len__(self): - return -((self.end_frame - self.start_frame) // -self.batch_size) - - def __getitem__(self, idx): - n_batches = self.__len__() - if idx < n_batches - 1: - start_frame = self.start_frame + self.batch_size * idx - self.pre_frames - self.pre_post_omission - end_frame = self.start_frame + self.batch_size * (idx + 1) + self.post_frames + self.pre_post_omission - traces = self.recording.get_traces( - start_frame=start_frame, end_frame=end_frame, channel_indices=slice(None) - ) - batch_size = self.batch_size - else: - start_frame = self.end_frame - self.last_batch_size - self.pre_frames - self.pre_post_omission - end_frame = self.end_frame + self.post_frames + self.pre_post_omission - traces = self.recording.get_traces( - start_frame=start_frame, end_frame=end_frame, channel_indices=slice(None) - ) - batch_size = self.last_batch_size - - shape = (traces.shape[0], int(384 / 2), 2) - traces = np.reshape(traces, newshape=shape) - - di_input = np.zeros((batch_size, 384, 2, self.pre_frames + self.post_frames)) - di_label = np.zeros((batch_size, 384, 2, 1)) - for index_frame in range( - self.pre_frames + self.pre_post_omission, batch_size + self.pre_frames + self.pre_post_omission - ): - di_input[index_frame - self.pre_frames - self.pre_post_omission] = self.reshape_input_forward( - index_frame, traces - ) - di_label[index_frame - self.pre_frames - self.pre_post_omission] = self.reshape_label_forward( - traces[index_frame] - ) - return (di_input, di_label) - - def reshape_input_forward(self, index_frame, raw_data): - """Reshapes the frames surrounding the target frame to the form expected by model; - also subtracts mean and divides by std. 
- - Parameters - ---------- - index_frame : int - index of the frame to be predicted - raw_data : ndarray; (frames, 192, 2) - a chunk of data used to generate the input - - Returns - ------- - input_full : ndarray; (1, 384, 2, pre_frames+post_frames) - input to trained network to predict the center frame - """ - # currently only works for recordings with 384 channels - nb_probes = 384 - - # We reorganize to follow true geometry of probe for convolution - input_full = np.zeros([1, nb_probes, 2, self.pre_frames + self.post_frames], dtype="float32") - - input_index = np.arange( - index_frame - self.pre_frames - self.pre_post_omission, - index_frame + self.post_frames + self.pre_post_omission + 1, - ) - input_index = input_index[input_index != index_frame] - - for index_padding in np.arange(self.pre_post_omission + 1): - input_index = input_index[input_index != index_frame - index_padding] - input_index = input_index[input_index != index_frame + index_padding] - - data_img_input = raw_data[input_index, :, :] - - data_img_input = np.swapaxes(data_img_input, 1, 2) - data_img_input = np.swapaxes(data_img_input, 0, 2) - - even = np.arange(0, nb_probes, 2) - odd = even + 1 - - data_img_input = (data_img_input.astype("float32") - self.local_mean) / self.local_std - - input_full[0, even, 0, :] = data_img_input[:, 0, :] - input_full[0, odd, 1, :] = data_img_input[:, 1, :] - return input_full - def reshape_label_forward(self, label): - """Reshapes the target frame to the form expected by model. - - Parameters - ---------- - label : ndarray, (1, 192, 2) - - Returns - ------- - reshaped_label : ndarray, (1, 384, 2, 1) - target frame after reshaping - """ - # currently only works for recordings with 384 channels - nb_probes = 384 - - input_full = np.zeros([1, nb_probes, 2, 1], dtype="float32") - - data_img_input = np.expand_dims(label, axis=0) - data_img_input = np.swapaxes(data_img_input, 1, 2) - data_img_input = np.swapaxes(data_img_input, 0, 2) - - even = np.arange(0, nb_probes, 2) - odd = even + 1 - - data_img_input = (data_img_input.astype("float32") - self.local_mean) / self.local_std - - input_full[0, even, 0, :] = data_img_input[:, 0, :] - input_full[0, odd, 1, :] = data_img_input[:, 1, :] - return input_full - - return DeepInterpolationInputGenerator - - -class DeepInterpolatedRecording(BasePreprocessor): name = "deepinterpolate" def __init__( self, - recording: BaseRecording, + recording, model_path: str, - pre_frames: int = 30, - post_frames: int = 30, + pre_frame: int = 30, + post_frame: int = 30, pre_post_omission: int = 1, - batch_size=128, + batch_size: int = 128, use_gpu: bool = True, + predict_workers: int = 1, disable_tf_logger: bool = True, - **random_chunk_kwargs, + memory_gpu: Optional[int] = None, ): - """Applies DeepInterpolation, a neural network based denoising method, to the recording. - - Notes - ----- - * Currently this only works on Neuropixels 1.0-like recordings with 384 channels. - If the recording has fewer number of channels, consider matching the channel count with - `ZeroChannelPaddedRecording`. - * The specified model must have the same input dimensions as the model from original paper. - * Will use GPU if available. - * Inference (application of model) is done lazily, i.e. only when `get_traces` is called. - - For more information, see: - Lecoq et al. (2021) Removing independent noise in systems neuroscience data using DeepInterpolation. - Nature Methods. 18: 1401-1408. doi: 10.1038/s41592-021-01285-2. 
- - Parts of this code is adapted from https://github.com/AllenInstitute/deepinterpolation. + assert has_tf( + use_gpu, disable_tf_logger, memory_gpu + ), "To use DeepInterpolation, you first need to install `tensorflow`." - Parameters - ---------- - recording : si.BaseRecording - model_path : str - Path to pre-trained model - pre_frames : int - Number of frames before target frame used for training and inference - post_frames : int - Number of frames after target frame used for training and inference - pre_post_omission : int - Number of frames around the target frame to omit - batch_size : int, optional - Number of frames per batch to infer (adjust based on hardware); by default 128 - disable_tf_logger : bool, optional - If True, tensorflow logging is disabled, by default True - random_chunk_kwargs: keyword arguments for get_random_data_chunks - """ - - assert has_tf(use_gpu, disable_tf_logger), "To use DeepInterpolation, you first need to install `tensorflow`." - assert recording.get_num_channels() <= 384, ( - "DeepInterpolation only works on Neuropixels 1.0-like " - "recordings with 384 channels. This recording has too many " - "channels." - ) - assert recording.get_num_channels() == 384, ( - "DeepInterpolation only works on Neuropixels 1.0-like " - "recordings with 384 channels. " - "This recording has too few channels. Try matching the channel " - "count with `ZeroChannelPaddedRecording`." - ) - self.tf = import_tf(use_gpu, disable_tf_logger) + self.tf = import_tf(use_gpu, disable_tf_logger, memory_gpu=memory_gpu) # try move model load here with spawn BasePreprocessor.__init__(self, recording) # first time retrieving traces check that dimensions are ok self.tf.keras.backend.clear_session() - self.model = self.tf.keras.models.load_model(filepath=model_path) - # check input shape for the last dimension - config = self.model.get_config() - input_shape = config["layers"][0]["config"]["batch_input_shape"] - assert input_shape[-1] == pre_frames + post_frames, ( - "The sum of `pre_frames` and `post_frames` must match " "the last dimension of the model." 
- ) - - local_data = get_random_data_chunks(recording, **random_chunk_kwargs) - if isinstance(recording, ZeroChannelPaddedRecording): - local_data = local_data[:, recording.channel_mapping] - - local_mean = np.mean(local_data.flatten()) - local_std = np.std(local_data.flatten()) + model = self.tf.keras.models.load_model(filepath=model_path) + + # check shape (this will need to be done at inference) + network_input_shape = model.get_config()["layers"][0]["config"]["batch_input_shape"] + desired_shape = network_input_shape[1:3] + assert ( + desired_shape[0] * desired_shape[1] == recording.get_num_channels() + ), "The desired shape of the network input must match the number of channels in the recording" + assert ( + network_input_shape[-1] == pre_frame + post_frame + ), "The desired shape of the network input must match the pre and post frames" + self.model = model # add segment for segment in recording._recording_segments: recording_segment = DeepInterpolatedRecordingSegment( segment, self.model, - pre_frames, - post_frames, + pre_frame, + post_frame, pre_post_omission, - local_mean, - local_std, + desired_shape, batch_size, - use_gpu, - disable_tf_logger, + predict_workers, ) self.add_recording_segment(recording_segment) self._preferred_mp_context = "spawn" self._kwargs = dict( recording=recording, - model_path=model_path, - pre_frames=pre_frames, - post_frames=post_frames, + model_path=str(model_path), + pre_frame=pre_frame, + post_frame=post_frame, pre_post_omission=pre_post_omission, batch_size=batch_size, - **random_chunk_kwargs, + predict_workers=predict_workers, + use_gpu=use_gpu, + disable_tf_logger=disable_tf_logger, + memory_gpu=memory_gpu, ) self.extra_requirements.extend(["tensorflow"]) @@ -316,30 +123,26 @@ def __init__( self, recording_segment, model, - pre_frames, - post_frames, + pre_frame, + post_frame, pre_post_omission, - local_mean, - local_std, + desired_shape, batch_size, - use_gpu, - disable_tf_logger, + predict_workers, ): BasePreprocessorSegment.__init__(self, recording_segment) self.model = model - self.pre_frames = pre_frames - self.post_frames = post_frames + self.pre_frame = pre_frame + self.post_frame = post_frame self.pre_post_omission = pre_post_omission - self.local_mean = local_mean - self.local_std = local_std self.batch_size = batch_size - self.use_gpu = use_gpu - - # creating class dynamically to use the imported TF with GPU enabled/disabled based on the use_gpu flag - self.DeepInterpolationInputGenerator = define_input_generator_class(use_gpu, disable_tf_logger) + self.desired_shape = desired_shape + self.predict_workers = predict_workers def get_traces(self, start_frame, end_frame, channel_indices): + from .generators import SpikeInterfaceRecordingSegmentGenerator + n_frames = self.parent_recording_segment.get_num_samples() if start_frame == None: @@ -350,16 +153,16 @@ def get_traces(self, start_frame, end_frame, channel_indices): # for frames that lack full training data (i.e. 
pre and post frames including omissions), # just return uninterpolated - if start_frame < self.pre_frames + self.pre_post_omission: - true_start_frame = self.pre_frames + self.pre_post_omission + if start_frame < self.pre_frame + self.pre_post_omission: + true_start_frame = self.pre_frame + self.pre_post_omission array_to_append_front = self.parent_recording_segment.get_traces( start_frame=0, end_frame=true_start_frame, channel_indices=channel_indices ) else: true_start_frame = start_frame - if end_frame > n_frames - self.post_frames - self.pre_post_omission: - true_end_frame = n_frames - self.post_frames - self.pre_post_omission + if end_frame > n_frames - self.post_frame - self.pre_post_omission: + true_end_frame = n_frames - self.post_frame - self.pre_post_omission array_to_append_back = self.parent_recording_segment.get_traces( start_frame=true_end_frame, end_frame=n_frames, channel_indices=channel_indices ) @@ -367,22 +170,24 @@ true_end_frame = end_frame # instantiate an input generator that can be passed directly to model.predict - input_generator = self.DeepInterpolationInputGenerator( - recording=self.parent_recording_segment, + batch_size = min(self.batch_size, true_end_frame - true_start_frame) + input_generator = SpikeInterfaceRecordingSegmentGenerator( + recording_segment=self.parent_recording_segment, start_frame=true_start_frame, end_frame=true_end_frame, - pre_frames=self.pre_frames, - post_frames=self.post_frames, + pre_frame=self.pre_frame, + post_frame=self.post_frame, pre_post_omission=self.pre_post_omission, - local_mean=self.local_mean, - local_std=self.local_std, - batch_size=self.batch_size, + batch_size=batch_size, + desired_shape=self.desired_shape, ) - di_output = self.model.predict(input_generator, verbose=2) + di_output = self.model.predict(input_generator, workers=self.predict_workers, verbose=2) - out_traces = self.reshape_backward(di_output) + out_traces = input_generator.reshape_output(di_output) - if true_start_frame != start_frame: + if ( + true_start_frame != start_frame + ): # re-attach the raw traces at the borders, where there is not enough pre/post context to interpolate out_traces = np.concatenate((array_to_append_front, out_traces), axis=0) if true_end_frame != end_frame: @@ -390,31 +195,6 @@ return out_traces[:, channel_indices] - def reshape_backward(self, di_frames): - """reshapes the prediction from model back to frames - - Parameters - ---------- - di_frames : ndarray, (frames, 384, 2, 1) - predicted output of the model - - Returns - ------- - reshaped_frames : ndarray; (frames, 384) - predicted frames after reshaping - """ - # currently works only for recording with 384 channels - nb_probes = 384 - n_frames = di_frames.shape[0] - even = np.arange(0, nb_probes, 2) - odd = even + 1 - reshaped_frames = np.zeros((n_frames, 384)) - for frame in range(n_frames): - reshaped_frames[frame, 0::2] = di_frames[frame, even, 0, 0] - reshaped_frames[frame, 1::2] = di_frames[frame, odd, 1, 0] - reshaped_frames = reshaped_frames * self.local_std + self.local_mean - return reshaped_frames - # function for API deepinterpolate = define_function_from_class(source_class=DeepInterpolatedRecording, name="deepinterpolate")
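Since inference runs lazily inside `get_traces`, the wrapper composes like any other SpikeInterface preprocessor. A minimal usage sketch, assuming a trained DeepInterpolation model saved at the hypothetical path `model.h5` whose input shape matches the recording's channel count and the `pre_frame`/`post_frame` settings (parameter names are taken from the diff above):

```python
from spikeinterface.preprocessing.deepinterpolation import deepinterpolate

# `recording` is any BaseRecording whose channel count factors into the model's
# expected 2D input layout, e.g. 384 channels reshaped to (192, 2)
rec_di = deepinterpolate(
    recording,
    model_path="model.h5",  # hypothetical path to a trained h5 model
    pre_frame=30,
    post_frame=30,
    pre_post_omission=1,
    batch_size=128,
    use_gpu=False,  # CPU-only inference also works, just slower
)

# Denoising happens on demand; border frames that lack full pre/post context
# are returned uninterpolated, as implemented in get_traces above.
traces = rec_di.get_traces(start_frame=0, end_frame=30000)
```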
diff --git a/src/spikeinterface/preprocessing/deepinterpolation/generators.py b/src/spikeinterface/preprocessing/deepinterpolation/generators.py new file mode 100644 index 0000000000..8200340ac1 --- /dev/null +++ b/src/spikeinterface/preprocessing/deepinterpolation/generators.py @@ -0,0 +1,311 @@ +from __future__ import annotations +import tempfile +import json +from typing import Optional +import numpy as np + +from ...core import load_extractor, concatenate_recordings, BaseRecording, BaseRecordingSegment + +from .tf_utils import has_tf, import_tf + +from deepinterpolation.generator_collection import SequentialGenerator + + +class SpikeInterfaceRecordingGenerator(SequentialGenerator): + """ + This generator is used when dealing with a SpikeInterface recording. + The desired shape controls the reshaping of the input data before convolutions. + """ + + def __init__( + self, + recordings: BaseRecording | list[BaseRecording], + pre_frame: int = 30, + post_frame: int = 30, + pre_post_omission: int = 1, + desired_shape: tuple = (192, 2), + batch_size: int = 100, + steps_per_epoch: int = 10, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + total_samples: int = -1, + ): + if not isinstance(recordings, list): + recordings = [recordings] + self.recordings = recordings + if len(recordings) > 1: + assert all( + r.get_num_channels() == recordings[0].get_num_channels() for r in recordings[1:] + ), "All recordings must have the same number of channels" + + total_num_samples = sum(r.get_total_samples() for r in recordings) + # In case of multiple recordings and/or multiple segments, we calculate the frame periods to be excluded (borders) + exclude_intervals = [] + pre_extended = pre_frame + pre_post_omission + post_extended = post_frame + pre_post_omission + for i, recording in enumerate(recordings): + total_samples_pre = sum(r.get_total_samples() for r in recordings[:i]) + for segment_index in range(recording.get_num_segments()): + # exclude samples at the borders between segments + num_samples_segment_pre = sum(recording.get_num_samples(s) for s in np.arange(segment_index)) + if num_samples_segment_pre > 0: + exclude_intervals.append( + ( + total_samples_pre + num_samples_segment_pre - pre_extended - 1, + total_samples_pre + num_samples_segment_pre + post_extended, + ) + ) + # exclude samples at the borders between recordings + if total_samples_pre > 0: + exclude_intervals.append((total_samples_pre - pre_extended - 1, total_samples_pre + post_extended)) + + total_valid_samples = ( + total_num_samples - np.sum([end - start for start, end in exclude_intervals]) - pre_extended - post_extended + ) + self.total_samples = int(total_valid_samples) if total_samples == -1 else total_samples + assert len(desired_shape) == 2, "desired_shape should be 2D" + assert ( + desired_shape[0] * desired_shape[1] == recordings[0].get_num_channels() + ), f"The product of desired_shape dimensions should be the number of channels: {recordings[0].get_num_channels()}" + self.desired_shape = desired_shape + + start_frame = start_frame if start_frame is not None else 0 + end_frame = end_frame if end_frame is not None else total_num_samples + assert end_frame > start_frame, "end_frame must be greater than start_frame" + + sequential_generator_params = dict() + sequential_generator_params["steps_per_epoch"] = steps_per_epoch + sequential_generator_params["pre_frame"] = pre_frame + sequential_generator_params["post_frame"] = post_frame + sequential_generator_params["batch_size"] = batch_size + sequential_generator_params["start_frame"] = start_frame + sequential_generator_params["end_frame"] = end_frame + sequential_generator_params["total_samples"] = self.total_samples
sequential_generator_params["pre_post_omission"] = pre_post_omission + + json_path = tempfile.mktemp(suffix=".json") + with open(json_path, "w") as f: + json.dump(sequential_generator_params, f) + super().__init__(json_path) + + self._update_end_frame(total_num_samples) + + # self.list_samples will exclude the border intervals on the concat recording + self.recording_concat = concatenate_recordings(recordings) + self.exclude_intervals = exclude_intervals + self._calculate_list_samples(total_num_samples, exclude_intervals=exclude_intervals) + + self._kwargs = dict( + recordings=recordings, + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + desired_shape=desired_shape, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + start_frame=start_frame, + end_frame=end_frame, + ) + + # this is overridden to exclude samples from borders + def _calculate_list_samples(self, total_frame_per_movie, exclude_intervals=[]): + # We first cut if start and end frames are too close to the edges. + self.start_sample = np.max([self.pre_frame + self.pre_post_omission, self.start_frame]) + self.end_sample = np.min( + [ + self.end_frame, + total_frame_per_movie - 1 - self.post_frame - self.pre_post_omission, + ] + ) + + if (self.end_sample - self.start_sample + 1) < self.batch_size: + raise Exception( + "Not enough frames to construct one " + + str(self.batch_size) + + " frame(s) batch between " + + str(self.start_sample) + + " and " + + str(self.end_sample) + + " frame number." + ) + + # +1 to make sure end_samples is included + list_samples_all = np.arange(self.start_sample, self.end_sample + 1) + + if len(exclude_intervals) > 0: + for start, end in exclude_intervals: + list_samples_all = list_samples_all[(list_samples_all <= start) | (list_samples_all >= end)] + self.list_samples = list_samples_all + else: + self.list_samples = list_samples_all + + if self.randomize: + np.random.shuffle(self.list_samples) + + # We cut the number of samples if asked to + if self.total_samples > 0 and self.total_samples < len(self.list_samples): + self.list_samples = self.list_samples[0 : self.total_samples] + + def __getitem__(self, index): + # This is to ensure we are going through + # the entire data when steps_per_epoch start_frame, "end_frame must be greater than start_frame" + self.total_samples = end_frame - start_frame + + sequential_generator_params = dict() + sequential_generator_params["steps_per_epoch"] = steps_per_epoch + sequential_generator_params["pre_frame"] = pre_frame + sequential_generator_params["post_frame"] = post_frame + sequential_generator_params["batch_size"] = batch_size + sequential_generator_params["start_frame"] = start_frame + sequential_generator_params["end_frame"] = end_frame + sequential_generator_params["total_samples"] = self.total_samples + sequential_generator_params["pre_post_omission"] = pre_post_omission + + json_path = tempfile.mktemp(suffix=".json") + with open(json_path, "w") as f: + json.dump(sequential_generator_params, f) + super().__init__(json_path) + + self._update_end_frame(num_segment_samples) + # IMPORTANT: this is used for inference, so we don't want to shuffle + self.randomize = False + self._calculate_list_samples(num_segment_samples) + + def __getitem__(self, index): + # This is to ensure we are going through + # the entire data when steps_per_epoch