Skip to content

Commit

Permalink
Merge pull request #3118 from JoeZiminski/add_time_vector_case_to_get…
Browse files Browse the repository at this point in the history
…_duration

Add time vector case to `get_durations`.
  • Loading branch information
samuelgarcia authored Jul 15, 2024
2 parents e2fe22e + 7714724 commit a14ee81
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 5 deletions.
12 changes: 9 additions & 3 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,14 @@ def get_duration(self, segment_index=None) -> float:
The duration in seconds
"""
segment_index = self._check_segment_index(segment_index)
segment_num_samples = self.get_num_samples(segment_index=segment_index)
segment_duration = segment_num_samples / self.get_sampling_frequency()

if self.has_time_vector(segment_index):
times = self.get_times(segment_index)
segment_duration = times[-1] - times[0] + (1 / self.get_sampling_frequency())
else:
segment_num_samples = self.get_num_samples(segment_index=segment_index)
segment_duration = segment_num_samples / self.get_sampling_frequency()

return segment_duration

def get_total_duration(self) -> float:
Expand All @@ -246,7 +252,7 @@ def get_total_duration(self) -> float:
float
The duration in seconds
"""
duration = self.get_total_samples() / self.get_sampling_frequency()
duration = sum([self.get_duration(idx) for idx in range(self.get_num_segments())])
return duration

def get_memory_size(self, segment_index=None) -> int:
Expand Down
5 changes: 4 additions & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,10 @@ def get_total_samples(self) -> int:
return s

def get_total_duration(self) -> float:
duration = self.get_total_samples() / self.sampling_frequency
if self.has_recording() or self.has_temporary_recording():
duration = self.recording.get_total_duration()
else:
duration = self.get_total_samples() / self.sampling_frequency
return duration

def get_num_channels(self) -> int:
Expand Down
67 changes: 67 additions & 0 deletions src/spikeinterface/core/tests/test_time_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,73 @@ def test_slice_recording(self, time_type, bounds):

assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)

def test_get_durations(self, time_vector_recording, t_start_recording):
"""
Test the `get_durations` functions that return the total duration
for a segment. Test that it is correct after adding both `t_start`
or `time_vector` to the recording.
"""
raw_recording, tvector_recording, all_time_vectors = time_vector_recording
_, tstart_recording, all_t_starts = t_start_recording

ts = 1 / raw_recording.get_sampling_frequency()

all_raw_durations = []
all_vector_durations = []
for segment_index in range(raw_recording.get_num_segments()):

# Test before `t_start` and `t_start` (`t_start` is just an offset,
# should not affect duration).
raw_duration = all_t_starts[segment_index][-1] - all_t_starts[segment_index][0] + ts

assert np.isclose(raw_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)
assert np.isclose(tstart_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)

# Test the duration from the time vector.
vector_duration = all_time_vectors[segment_index][-1] - all_time_vectors[segment_index][0] + ts

assert tvector_recording.get_duration(segment_index) == vector_duration

all_raw_durations.append(raw_duration)
all_vector_durations.append(vector_duration)

# Finally test the total recording duration
assert np.isclose(tstart_recording.get_total_duration(), sum(all_raw_durations), rtol=0, atol=1e-8)
assert np.isclose(tvector_recording.get_total_duration(), sum(all_vector_durations), rtol=0, atol=1e-8)

def test_sorting_analyzer_get_durations_from_recording(self, time_vector_recording):
"""
Test that when a recording is set on `sorting_analyzer`, the
total duration is propagated from the recording to the
`sorting_analyzer.get_total_duration()` function.
"""
_, times_recording, _ = time_vector_recording

sorting = si.generate_sorting(
durations=[times_recording.get_duration(s) for s in range(times_recording.get_num_segments())]
)
sorting_analyzer = si.create_sorting_analyzer(sorting, recording=times_recording)

assert np.array_equal(sorting_analyzer.get_total_duration(), times_recording.get_total_duration())

def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording):
"""
Test when the `sorting_analzyer` does not have a recording set,
the total duration is calculated on the fly from num samples and
sampling frequency (thus matching `raw_recording` with no times set
that uses the same method to calculate the total duration).
"""
raw_recording, _, _ = time_vector_recording

sorting = si.generate_sorting(
durations=[raw_recording.get_duration(s) for s in range(raw_recording.get_num_segments())]
)
sorting_analyzer = si.create_sorting_analyzer(sorting, recording=raw_recording)

sorting_analyzer._recording = None

assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration())

# Helpers ####
def _check_times_match(self, recording, all_times):
"""
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/generation/hybrid_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def generate_hybrid_recording(
num_segments = recording.get_num_segments()
dtype = recording.dtype
durations = np.array([recording.get_duration(segment_index) for segment_index in range(num_segments)])
num_samples = np.array([recording.get_num_samples(segment_index) for segment_index in range(num_segments)])
channel_locations = probe.contact_positions

assert (
Expand Down Expand Up @@ -548,7 +549,7 @@ def generate_hybrid_recording(
displacement_vectors=displacement_vectors,
displacement_sampling_frequency=displacement_sampling_frequency,
displacement_unit_factor=displacement_unit_factor,
num_samples=(np.array(durations) * sampling_frequency).astype("int64"),
num_samples=num_samples.astype("int64"),
amplitude_factor=amplitude_factor,
)

Expand Down

0 comments on commit a14ee81

Please sign in to comment.