Add option to use recording lists, test_recording, and multi-segment + correct usage of total_samples
alejoe91 committed Sep 6, 2023
1 parent 9c312e9 commit 5e3ae74
Showing 4 changed files with 201 additions and 102 deletions.
@@ -181,8 +181,6 @@ def get_traces(self, start_frame, end_frame, channel_indices):
batch_size=batch_size,
desired_shape=self.desired_shape,
)
input_generator.randomize = False
input_generator._calculate_list_samples(input_generator.total_samples)
di_output = self.model.predict(input_generator, workers=self.predict_workers, verbose=2)

out_traces = input_generator.reshape_output(di_output)
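For context, a minimal sketch of the inference call after this change, assuming the generator built here is the SpikeInterfaceRecordingSegmentGenerator added in generators.py and that the model, recording segment, and frame/shape parameters already exist in scope:

input_generator = SpikeInterfaceRecordingSegmentGenerator(
    recording_segment=recording_segment,
    start_frame=start_frame,
    end_frame=end_frame,
    pre_frame=pre_frame,
    post_frame=post_frame,
    desired_shape=desired_shape,
    batch_size=batch_size,
)
# shuffling is disabled and the sample list is computed inside __init__,
# so the removed randomize/_calculate_list_samples calls are no longer needed here
di_output = model.predict(input_generator, workers=predict_workers, verbose=2)
out_traces = input_generator.reshape_output(di_output)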
178 changes: 114 additions & 64 deletions src/spikeinterface/preprocessing/deepinterpolation/generators.py
@@ -1,9 +1,10 @@
from __future__ import annotations
import tempfile
import json
from typing import Any
from typing import Optional
import numpy as np

from ...core import load_extractor
from ...core import load_extractor, concatenate_recordings, BaseRecording, BaseRecordingSegment

from .tf_utils import has_tf, import_tf

@@ -13,36 +14,65 @@


class SpikeInterfaceRecordingGenerator(SequentialGenerator):
"""This generator is used when dealing with a SpikeInterface recording.
The desired shape controls the reshaping of the input data before convolutions."""
"""
This generator is used when dealing with a SpikeInterface recording.
The desired shape controls the reshaping of the input data before convolutions.
"""

def __init__(
self,
recording,
pre_frame=30,
post_frame=30,
pre_post_omission=1,
desired_shape=(192, 2),
batch_size=100,
steps_per_epoch=10,
start_frame=None,
end_frame=None,
total_samples=-1,
recordings: BaseRecording | list[BaseRecording],
pre_frame: int = 30,
post_frame: int = 30,
pre_post_omission: int = 1,
desired_shape: tuple = (192, 2),
batch_size: int = 100,
steps_per_epoch: int = 10,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
total_samples: int = -1,
):
assert recording.get_num_segments() == 1, "Only supported for mono-segment recordings"

self.recording = recording
num_samples = recording.get_num_samples()
self.total_samples = num_samples if total_samples == -1 else total_samples
if not isinstance(recordings, list):
recordings = [recordings]
self.recordings = recordings
if len(recordings) > 1:
assert (
r.get_num_channels() == recordings[0].get_num_channels() for r in recordings[1:]
), "All recordings must have the same number of channels"

total_num_samples = np.sum(r.get_total_samples() for r in recordings)
# In case of multiple recordings and/or multiple segments, we calculate the frame periods to be excluded (borders)
exclude_intervals = []
pre_extended = pre_frame + pre_post_omission
post_extended = post_frame + pre_post_omission
for i, recording in enumerate(recordings):
total_samples_pre = np.sum(r.get_total_samples() for r in recordings[:i])
for segment_index in range(recording.get_num_segments()):
# exclude samples at the border of the recordings
num_samples_segment_pre = np.sum(recording.get_num_samples(s) for s in np.arange(segment_index))
if num_samples_segment_pre > 0:
exclude_intervals.append(
(
total_samples_pre + num_samples_segment_pre - pre_extended - 1,
total_samples_pre + num_samples_segment_pre + post_extended,
)
)
# exclude samples at the border of the recordings
if total_samples_pre > 0:
exclude_intervals.append((total_samples_pre - pre_extended - 1, total_samples_pre + post_extended))

total_valid_samples = (
total_num_samples - np.sum([end - start for start, end in exclude_intervals]) - pre_extended - post_extended
)
self.total_samples = int(total_valid_samples) if total_samples == -1 else total_samples
assert len(desired_shape) == 2, "desired_shape should be 2D"
assert (
desired_shape[0] * desired_shape[1] == recording.get_num_channels()
), f"The product of desired_shape dimensions should be the number of channels: {recording.get_num_channels()}"
self.desired_shape = desired_shape

start_frame = start_frame if start_frame is not None else 0
end_frame = end_frame if end_frame is not None else recording.get_num_samples()

end_frame = end_frame if end_frame is not None else total_num_samples
assert end_frame > start_frame, "end_frame must be greater than start_frame"

sequential_generator_params = dict()
@@ -60,12 +90,15 @@ def __init__(
json.dump(sequential_generator_params, f)
super().__init__(json_path)

self._update_end_frame(num_samples)
self._calculate_list_samples(num_samples)
self.last_batch_size = np.mod(self.end_frame - self.start_frame, self.batch_size)
self._update_end_frame(total_num_samples)

# self.list_samples will exclude the border intervals on the concat recording
self.recording_concat = concatenate_recordings(recordings)
self.exclude_intervals = exclude_intervals
self._calculate_list_samples(total_num_samples, exclude_intervals=exclude_intervals)

self._kwargs = dict(
recording=recording,
recordings=recordings,
pre_frame=pre_frame,
post_frame=post_frame,
pre_post_omission=pre_post_omission,
@@ -76,6 +109,45 @@ def __init__(
end_frame=end_frame,
)

# this is overridden to exclude samples from borders
def _calculate_list_samples(self, total_frame_per_movie, exclude_intervals=[]):
# We first cut if start and end frames are too close to the edges.
self.start_sample = np.max([self.pre_frame + self.pre_post_omission, self.start_frame])
self.end_sample = np.min(
[
self.end_frame,
total_frame_per_movie - 1 - self.post_frame - self.pre_post_omission,
]
)

if (self.end_sample - self.start_sample + 1) < self.batch_size:
raise Exception(
"Not enough frames to construct one "
+ str(self.batch_size)
+ " frame(s) batch between "
+ str(self.start_sample)
+ " and "
+ str(self.end_sample)
+ " frame number."
)

# +1 to make sure end_samples is included
list_samples_all = np.arange(self.start_sample, self.end_sample + 1)

if len(exclude_intervals) > 0:
for start, end in exclude_intervals:
list_samples_all = list_samples_all[(list_samples_all <= start) | (list_samples_all >= end)]
self.list_samples = list_samples_all
else:
self.list_samples = list_samples_all

if self.randomize:
np.random.shuffle(self.list_samples)

# We cut the number of samples if asked to
if self.total_samples > 0 and self.total_samples < len(self.list_samples):
self.list_samples = self.list_samples[0 : self.total_samples]

def __getitem__(self, index):
# This is to ensure we are going through
# the entire data when steps_per_epoch<self.__len__
@@ -106,7 +178,7 @@ def __data_generation__(self, index_frame):

start_frame = index_frame - self.pre_frame - self.pre_post_omission
end_frame = index_frame + self.post_frame + self.pre_post_omission + 1
full_traces = self.recording.get_traces(start_frame=start_frame, end_frame=end_frame).astype("float32")
full_traces = self.recording_concat.get_traces(start_frame=start_frame, end_frame=end_frame).astype("float32")

if full_traces.shape[0] == 0:
print(f"Error! {index_frame}-{start_frame}-{end_frame}", flush=True)
@@ -140,28 +212,26 @@ class SpikeInterfaceRecordingSegmentGenerator(SequentialGenerator):

def __init__(
self,
recording_segment,
start_frame,
end_frame,
pre_frame=30,
post_frame=30,
pre_post_omission=1,
desired_shape=(192, 2),
batch_size=100,
steps_per_epoch=10,
recording_segment: BaseRecordingSegment,
start_frame: int,
end_frame: int,
pre_frame: int = 30,
post_frame: int = 30,
pre_post_omission: int = 1,
desired_shape: tuple = (192, 2),
batch_size: int = 100,
steps_per_epoch: int = 10,
):
"Initialization"

self.recording_segment = recording_segment
self.total_samples = recording_segment.get_num_samples()
self.num_channels = int(desired_shape[0] * desired_shape[1])
assert len(desired_shape) == 2, "desired_shape should be 2D"
self.desired_shape = desired_shape

num_segment_samples = recording_segment.get_num_samples()
start_frame = start_frame if start_frame is not None else 0
end_frame = end_frame if end_frame is not None else self.total_samples

end_frame = end_frame if end_frame is not None else num_segment_samples
assert end_frame > start_frame, "end_frame must be greater than start_frame"
self.total_samples = end_frame - start_frame

sequential_generator_params = dict()
sequential_generator_params["steps_per_epoch"] = steps_per_epoch
@@ -178,30 +248,10 @@ def __init__(
json.dump(sequential_generator_params, f)
super().__init__(json_path)

self._update_end_frame(self.total_samples)
self._calculate_list_samples(self.total_samples)
self.last_batch_size = np.mod(self.end_frame - self.start_frame, self.batch_size)

def __len__(self):
"Denotes the total number of batches"
if self.last_batch_size == 0:
return int(len(self.list_samples) // self.batch_size)
else:
self.has_partial_batch = True
return int(len(self.list_samples) // self.batch_size) + 1

def generate_batch_indexes(self, index):
# This is to ensure we are going through
# the entire data when steps_per_epoch<self.__len__
if self.steps_per_epoch > 0:
index = index + self.steps_per_epoch * self.epoch_index
# Generate indexes of the batch
indexes = slice(index * self.batch_size, (index + 1) * self.batch_size)

if index == len(self) - 1 and self.last_batch_size > 0:
indexes = slice(-self.last_batch_size, len(self.list_samples))
shuffle_indexes = self.list_samples[indexes]
return shuffle_indexes
self._update_end_frame(num_segment_samples)
# IMPORTANT: this is used for inference, so we don't want to shuffle
self.randomize = False
self._calculate_list_samples(num_segment_samples)

def __getitem__(self, index):
# This is to ensure we are going through
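A minimal usage sketch of the new list-of-recordings support, assuming tensorflow and deepinterpolation are installed; the toy recording comes from spikeinterface's generate_recording, with the channel count chosen to match the default desired_shape of (192, 2):

from spikeinterface import generate_recording
from spikeinterface.preprocessing.deepinterpolation.generators import SpikeInterfaceRecordingGenerator

# toy recording: 384 channels (= 192 * 2), two segments of 2 s each
rec = generate_recording(num_channels=384, durations=[2.0, 2.0])

# a list of recordings is now accepted; frames spanning segment or recording
# borders are collected in exclude_intervals and dropped from list_samples
gen = SpikeInterfaceRecordingGenerator([rec, rec], desired_shape=(192, 2))
print(gen.exclude_intervals)  # 2 segment borders + 1 recording border -> 3 intervals
print(gen.total_samples)      # border and edge frames are subtracted from the total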
@@ -3,17 +3,20 @@
from pathlib import Path

import probeinterface as pi
from spikeinterface import download_dataset, generate_recording
from spikeinterface import download_dataset, generate_recording, append_recordings, concatenate_recordings
from spikeinterface.extractors import read_mearec, read_spikeglx, read_openephys
from spikeinterface.preprocessing import depth_order, zscore

from spikeinterface.preprocessing.deepinterpolation import train_deepinterpolation, deepinterpolate
from spikeinterface.preprocessing.deepinterpolation import train_deepinterpolation, deepinterpolate


if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "preprocessing"
cache_folder = pytest.global_test_folder / "deepinterpolation"
else:
cache_folder = Path("cache_folder") / "preprocessing"
cache_folder = Path("cache_folder") / "deepinterpolation"
if not cache_folder.is_dir():
cache_folder.mkdir(parents=True)


def recording_and_shape():
@@ -34,6 +37,27 @@ def recording_and_shape_fixture():
return recording_and_shape()


def test_deepinterpolation_generator_borders(recording_and_shape_fixture):
"""Test that the generator avoids borders in multi-segment and recording lists cases"""
from spikeinterface.preprocessing.deepinterpolation.generators import SpikeInterfaceRecordingGenerator

recording, desired_shape = recording_and_shape_fixture
recording_multi_segment = append_recordings([recording, recording, recording])
recording_list = [recording, recording, recording]
recording_multi_list = [recording_multi_segment, recording_multi_segment, recording_multi_segment]

gen_multi_segment = SpikeInterfaceRecordingGenerator(recording_multi_segment, desired_shape=desired_shape)
gen_list = SpikeInterfaceRecordingGenerator(recording_list, desired_shape=desired_shape)
gen_multi_list = SpikeInterfaceRecordingGenerator(recording_multi_list, desired_shape=desired_shape)

# exclude in between segments
assert len(gen_multi_segment.exclude_intervals) == 2
# exclude in between recordings
assert len(gen_list.exclude_intervals) == 2
# exclude in between recordings and segments
assert len(gen_multi_list.exclude_intervals) == 2 * len(recording_multi_list) + 2
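
A minimal sketch of the interval counting behind these assertions, assuming the 3-recording / 3-segment setup built above: each recording contributes one excluded interval per internal segment border, and every recording after the first adds one more interval for the recording border.

n_recordings = 3                                      # recordings in recording_multi_list
n_segments = 3                                        # segments per recording
segment_borders = n_recordings * (n_segments - 1)     # 3 * 2 = 6
recording_borders = n_recordings - 1                  # 2
assert segment_borders + recording_borders == 2 * n_recordings + 2  # 6 + 2 = 8, as asserted above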


def test_deepinterpolation_training(recording_and_shape_fixture):
recording, desired_shape = recording_and_shape_fixture

@@ -45,10 +69,10 @@ def test_deepinterpolation_training(recording_and_shape_fixture):
model_name="training",
train_start_s=1,
train_end_s=10,
train_duration_s=0.1,
train_duration_s=0.02,
test_start_s=0,
test_end_s=1,
test_duration_s=0.05,
test_duration_s=0.01,
pre_frame=20,
post_frame=20,
run_uid="si_test",
@@ -73,10 +97,10 @@ def test_deepinterpolation_transfer(recording_and_shape_fixture, tmp_path):
existing_model_path=existing_model_path,
train_start_s=1,
train_end_s=10,
train_duration_s=0.1,
train_duration_s=0.02,
test_start_s=0,
test_end_s=1,
test_duration_s=0.05,
test_duration_s=0.01,
pre_frame=20,
post_frame=20,
pre_post_omission=1,
@@ -94,6 +118,7 @@ def test_deepinterpolation_inference(recording_and_shape_fixture):
recording_di = deepinterpolate(
recording, model_path=existing_model_path, pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=1
)
print(recording_di)
traces_original_first = recording.get_traces(start_frame=0, end_frame=100)
traces_di_first = recording_di.get_traces(start_frame=0, end_frame=100)
assert traces_di_first.shape == (100, recording.get_num_channels())
@@ -123,9 +148,19 @@ def test_deepinterpolation_inference_multi_job(recording_and_shape_fixture):
pre_post_omission=1,
use_gpu=False,
)
print(recording_di)
recording_di_slice = recording_di.frame_slice(start_frame=0, end_frame=int(0.5 * recording.sampling_frequency))

recording_di_slice.save(folder=Path(cache_folder) / "di_slice", n_jobs=2, chunk_duration="50ms")
traces_chunks = recording_di_slice.get_traces()
traces_full = recording_di_slice.get_traces()
np.testing.assert_array_equal(traces_chunks, traces_full)


if __name__ == "__main__":
recording_shape = recording_and_shape()
# test_deepinterpolation_training(recording_shape)
# test_deepinterpolation_transfer()
test_deepinterpolation_inference(recording_shape)
# test_deepinterpolation_inference_multi_job()
# test_deepinterpolation_generator_borders(recording_shape)