Correct use of total_samples and expose training/test duration
alejoe91 committed Sep 5, 2023
1 parent 1b46ae6 commit e136048
Showing 2 changed files with 27 additions and 26 deletions.
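In short: the recording generator gains a total_samples argument (previously it always used the full recording length), and train_deepinterpolation exposes it through new train_duration_s / test_duration_s parameters. As a minimal sketch of the updated call, here are the changed keyword arguments with illustrative values; the arguments before model_name (e.g. the recording and output folder) are not visible in this diff and are deliberately left out:

# Hypothetical keyword arguments for the updated train_deepinterpolation call.
# Values are illustrative; arguments not touched by this commit are omitted.
training_kwargs = dict(
    model_name="si_deepinterp",
    train_start_s=0.0,
    train_end_s=600.0,
    train_duration_s=60.0,      # new: cap training to 60 s of samples within [0, 600] s
    test_start_s=600.0,
    test_end_s=610.0,
    test_duration_s=5.0,        # new: cap testing to 5 s of samples within [600, 610] s
    desired_shape=(192, 2),     # product must equal the recording's channel count
    loss="mean_squared_error",  # new default introduced by this commit
)
# train_deepinterpolation(recording, ..., **training_kwargs)  # leading args elided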
31 changes: 6 additions & 25 deletions src/spikeinterface/preprocessing/deepinterpolation/generators.py
@@ -27,20 +27,21 @@ def __init__(
         steps_per_epoch=10,
         start_frame=None,
         end_frame=None,
+        total_samples=-1,
     ):
         "Initialization"
         assert recording.get_num_segments() == 1, "Only supported for mono-segment recordings"
 
         self.recording = recording
-        self.total_samples = recording.get_num_samples()
+        num_samples = recording.get_num_samples()
+        self.total_samples = num_samples if total_samples == -1 else total_samples
         assert len(desired_shape) == 2, "desired_shape should be 2D"
         assert (
             desired_shape[0] * desired_shape[1] == recording.get_num_channels()
         ), f"The product of desired_shape dimensions should be the number of channels: {recording.get_num_channels()}"
         self.desired_shape = desired_shape
 
         start_frame = start_frame if start_frame is not None else 0
-        end_frame = end_frame if end_frame is not None else self.total_samples
+        end_frame = end_frame if end_frame is not None else recording.get_num_samples()
 
         assert end_frame > start_frame, "end_frame must be greater than start_frame"
 
@@ -59,8 +60,8 @@ def __init__(
             json.dump(sequential_generator_params, f)
         super().__init__(json_path)
 
-        self._update_end_frame(self.total_samples)
-        self._calculate_list_samples(self.total_samples)
+        self._update_end_frame(num_samples)
+        self._calculate_list_samples(num_samples)
         self.last_batch_size = np.mod(self.end_frame - self.start_frame, self.batch_size)
 
         self._kwargs = dict(
@@ -75,26 +76,6 @@ def __init__(
             end_frame=end_frame,
         )
 
-    def __len__(self):
-        "Denotes the total number of batches"
-        if self.last_batch_size == 0:
-            return int(len(self.list_samples) // self.batch_size)
-        else:
-            return int(len(self.list_samples) // self.batch_size) + 1
-
-    def generate_batch_indexes(self, index):
-        # This is to ensure we are going through
-        # the entire data when steps_per_epoch<self.__len__
-        if self.steps_per_epoch > 0:
-            index = index + self.steps_per_epoch * self.epoch_index
-        # Generate indexes of the batch
-        indexes = slice(index * self.batch_size, (index + 1) * self.batch_size)
-
-        if index == len(self) - 1 and self.last_batch_size > 0:
-            indexes = slice(-self.last_batch_size, len(self.list_samples))
-        shuffle_indexes = self.list_samples[indexes]
-        return shuffle_indexes
-
     def __getitem__(self, index):
         # This is to ensure we are going through
         # the entire data when steps_per_epoch<self.__len__
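With the __len__ and generate_batch_indexes overrides deleted, the class presumably falls back to the implementations inherited via super().__init__, and total_samples becomes the knob that limits how many samples the generator draws. A hedged construction sketch, using only the arguments visible in this diff (the recording object and all values are assumptions):

# Sketch of constructing the generator with the new argument; only parameters
# visible in this diff are used, and all values are illustrative assumptions.
gen = SpikeInterfaceRecordingGenerator(
    recording,               # mono-segment SpikeInterface recording (assumed)
    steps_per_epoch=10,
    start_frame=None,        # defaults to 0
    end_frame=None,          # now defaults to recording.get_num_samples()
    total_samples=-1,        # new: -1 keeps the old behavior (use all samples)
    desired_shape=(192, 2),  # product must equal recording.get_num_channels()
)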
22 changes: 21 additions & 1 deletion src/spikeinterface/preprocessing/deepinterpolation/train.py
@@ -23,8 +23,10 @@ def train_deepinterpolation(
     model_name: str,
     train_start_s: float,
     train_end_s: float,
+    train_duration_s: float,
     test_start_s: float,
     test_end_s: float,
+    test_duration_s: float,
     desired_shape: tuple[int, int],
     pre_frame: int = 30,
     post_frame: int = 30,
@@ -37,7 +39,7 @@
     apply_learning_decay: int = 0,
     nb_times_through_data: int = 1,
     learning_rate: float = 0.0001,
-    loss: str = "mean_absolute_error",
+    loss: str = "mean_squared_error",
     nb_workers: int = -1,
     caching_validation: bool = False,
     run_uid: str = "si",
@@ -61,10 +63,14 @@
         Start time of the training in seconds
     train_end_s : float
         End time of the training in seconds
+    train_duration_s : float
+        Duration of the training in seconds
     test_start_s : float
         Start time of the testing in seconds
     test_end_s : float
         End time of the testing in seconds
+    test_duration_s : float
+        Duration of the testing in seconds
     desired_shape : tuple
         Shape of the input to the network
     pre_frame : int
@@ -119,8 +125,10 @@
         model_name,
         train_start_s,
         train_end_s,
+        train_duration_s,
         test_start_s,
         test_end_s,
+        test_duration_s,
         pre_frame,
         post_frame,
         pre_post_omission,
@@ -156,8 +164,10 @@ def train_deepinterpolation_process(
     model_name: str,
     train_start_s: float,
     train_end_s: float,
+    train_duration_s: float | None,
     test_start_s: float,
     test_end_s: float,
+    test_duration_s: float | None,
     pre_frame: int,
     post_frame: int,
     pre_post_omission: int,
@@ -195,8 +205,16 @@
     ### Define params
     start_frame_training = int(train_start_s * fs)
     end_frame_training = int(train_end_s * fs)
+    if train_duration_s is not None:
+        total_samples_training = int(train_duration_s * fs)
+    else:
+        total_samples_training = -1
     start_frame_test = int(test_start_s * fs)
     end_frame_test = int(test_end_s * fs)
+    if test_duration_s is not None:
+        total_samples_testing = int(test_duration_s * fs)
+    else:
+        total_samples_testing = -1
 
     # Those are parameters used for the network topology
     network_params = dict()
@@ -232,6 +250,7 @@
         start_frame=start_frame_training,
         end_frame=end_frame_training,
         desired_shape=desired_shape,
+        total_samples=total_samples_training,
     )
     test_data_generator = SpikeInterfaceRecordingGenerator(
         recording,
@@ -242,6 +261,7 @@
         end_frame=end_frame_test,
         steps_per_epoch=-1,
         desired_shape=desired_shape,
+        total_samples=total_samples_testing,
     )
 
     network_json_path = trained_model_folder / "network_params.json"
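The duration arguments are converted to sample counts by multiplying by the sampling frequency, with None mapping to -1 (use all samples in the start/end range). A small worked example with an illustrative sampling rate:

# Worked example of the duration -> total_samples conversion added above.
fs = 30_000.0            # sampling frequency in Hz (illustrative)
train_duration_s = 60.0  # cap training to 60 s of data
test_duration_s = None   # None means "use all samples between start and end"

total_samples_training = int(train_duration_s * fs) if train_duration_s is not None else -1
total_samples_testing = int(test_duration_s * fs) if test_duration_s is not None else -1

print(total_samples_training)  # 1800000
print(total_samples_testing)   # -1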
