Skip to content

Commit

Permalink
Merge pull request #3120 from JoeZiminski/fix_save_to_memory_t_start
Browse files Browse the repository at this point in the history
Fix `t_starts` not propagated to `save_to_memory`.
  • Loading branch information
alejoe91 authored Jul 9, 2024
2 parents f273020 + 30606cc commit 00e1cf9
Show file tree
Hide file tree
Showing 3 changed files with 292 additions and 57 deletions.
27 changes: 19 additions & 8 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,24 +498,35 @@ def time_to_sample_index(self, time_s, segment_index=None):
rs = self._recording_segments[segment_index]
return rs.time_to_sample_index(time_s)

def _save(self, format="binary", verbose: bool = False, **save_kwargs):
def _get_t_starts(self):
# handle t_starts
t_starts = []
has_time_vectors = []
for segment_index, rs in enumerate(self._recording_segments):
for rs in self._recording_segments:
d = rs.get_times_kwargs()
t_starts.append(d["t_start"])
has_time_vectors.append(d["time_vector"] is not None)

if all(t_start is None for t_start in t_starts):
t_starts = None
return t_starts

def _get_time_vectors(self):
time_vectors = []
for rs in self._recording_segments:
d = rs.get_times_kwargs()
time_vectors.append(d["time_vector"])
if all(time_vector is None for time_vector in time_vectors):
time_vectors = None
return time_vectors

def _save(self, format="binary", verbose: bool = False, **save_kwargs):
kwargs, job_kwargs = split_job_kwargs(save_kwargs)

if format == "binary":
folder = kwargs["folder"]
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
dtype = kwargs.get("dtype", None) or self.get_dtype()
t_starts = self._get_t_starts()

write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs)

Expand Down Expand Up @@ -572,11 +583,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
probegroup = self.get_probegroup()
cached.set_probegroup(probegroup)

for segment_index, rs in enumerate(self._recording_segments):
d = rs.get_times_kwargs()
time_vector = d["time_vector"]
if time_vector is not None:
cached._recording_segments[segment_index].time_vector = time_vector
time_vectors = self._get_time_vectors()
if time_vectors is not None:
for segment_index, time_vector in enumerate(time_vectors):
if time_vector is not None:
cached.set_times(time_vector, segment_index=segment_index)

return cached

Expand Down
12 changes: 8 additions & 4 deletions src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
@staticmethod
def from_recording(source_recording, **job_kwargs):
traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs)

t_starts = source_recording._get_t_starts()

if shms[0] is not None:
# if the computation was done in parallel then traces_list is shared array
# this can lead to problem
Expand All @@ -91,13 +94,14 @@ def from_recording(source_recording, **job_kwargs):
for shm in shms:
shm.close()
shm.unlink()
# TODO later : propagte t_starts ?

recording = NumpyRecording(
traces_list,
source_recording.get_sampling_frequency(),
t_starts=None,
t_starts=t_starts,
channel_ids=source_recording.channel_ids,
)
return recording


class NumpyRecordingSegment(BaseRecordingSegment):
Expand Down Expand Up @@ -206,15 +210,15 @@ def __del__(self):
def from_recording(source_recording, **job_kwargs):
traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs)

# TODO later : propagte t_starts ?
t_starts = source_recording._get_t_starts()

recording = SharedMemoryRecording(
shm_names=[shm.name for shm in shms],
shape_list=[traces.shape for traces in traces_list],
dtype=source_recording.dtype,
sampling_frequency=source_recording.sampling_frequency,
channel_ids=source_recording.channel_ids,
t_starts=None,
t_starts=t_starts,
main_shm_owner=True,
)

Expand Down
Loading

0 comments on commit 00e1cf9

Please sign in to comment.