From 22d025d1d1ac61cd6269114ac6dff48915dcccc2 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 1 Jul 2024 21:21:42 +0100 Subject: [PATCH 1/5] Add time vector case to 'get_durations'. --- src/spikeinterface/core/baserecording.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index aab7577b31..39531dd204 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -233,8 +233,14 @@ def get_duration(self, segment_index=None) -> float: The duration in seconds """ segment_index = self._check_segment_index(segment_index) - segment_num_samples = self.get_num_samples(segment_index=segment_index) - segment_duration = segment_num_samples / self.get_sampling_frequency() + + if self.has_time_vector(segment_index): + times = self.get_times(segment_index) + segment_duration = times[-1] - times[0] + (1 / self.get_sampling_frequency()) + else: + segment_num_samples = self.get_num_samples(segment_index=segment_index) + segment_duration = segment_num_samples / self.get_sampling_frequency() + return segment_duration def get_total_duration(self) -> float: @@ -246,7 +252,7 @@ def get_total_duration(self) -> float: float The duration in seconds """ - duration = self.get_total_samples() / self.get_sampling_frequency() + duration = sum([self.get_duration(idx) for idx in range(self.get_num_segments())]) return duration def get_memory_size(self, segment_index=None) -> int: From 27360da836fa28d038beef3f8f686861b0033654 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 12 Jul 2024 13:13:30 +0100 Subject: [PATCH 2/5] Extending sorting_analyzer 'get_total_duration' to use recording if available. --- src/spikeinterface/core/sortinganalyzer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index fc20029ce6..5e1856e7cd 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -821,7 +821,10 @@ def get_total_samples(self) -> int: return s def get_total_duration(self) -> float: - duration = self.get_total_samples() / self.sampling_frequency + if self.has_recording() or self.has_temporary_recording(): + duration = self.recording.get_total_duration() + else: + duration = self.get_total_samples() / self.sampling_frequency return duration def get_num_channels(self) -> int: From 31b94df9d7afb6f5efe2e081ecac1ae52a8280da Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 15 Jul 2024 13:59:35 +0100 Subject: [PATCH 3/5] Start adding tests. --- .../core/tests/test_time_handling.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 049d5ab6e5..cd329e32b1 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -243,6 +243,47 @@ def test_slice_recording(self, time_type, bounds): assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8) + def test_get_durations(self, time_vector_recording, t_start_recording): + """ + Test the `get_durations` functions that return the total duration + for a segment. Test that it is correct after adding both `t_start` + or `time_vector` to the recording. + """ + raw_recording, tvector_recording, all_time_vectors = time_vector_recording + _, tstart_recording, all_t_starts = t_start_recording + + ts = 1 / raw_recording.get_sampling_frequency() + + all_raw_durations = [] + all_vector_durations = [] + for segment_index in range(raw_recording.get_num_segments()): + + # Test before `t_start` and `t_start` (`t_start` is just an offset, + # should not affect duration). + raw_duration = all_t_starts[segment_index][-1] - all_t_starts[segment_index][0] + ts + + assert np.isclose(raw_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8) + assert np.isclose(tstart_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8) + + # Test the duration from the time vector. + vector_duration = all_time_vectors[segment_index][-1] - all_time_vectors[segment_index][0] + ts + + assert tvector_recording.get_duration(segment_index) == vector_duration + + all_raw_durations.append(raw_duration) + all_vector_durations.append(vector_duration) + + # Finally test the total recording duration + assert np.isclose(tstart_recording.get_total_duration(), sum(all_raw_durations), rtol=0, atol=1e-8) + assert np.isclose(tvector_recording.get_total_duration(), sum(all_vector_durations), rtol=0, atol=1e-8) + + # def test_sorting_analyzer_get_durations(self, time_vector_recording): + # """ """ + # breakpoint() + # sorting = si.generate_sorting() + # sorting_analyzer = si.create_sorting_analyzer(sorting, recording=None) + # si.sorting_an + # Helpers #### def _check_times_match(self, recording, all_times): """ From d69d578921b21e9be9b2b112fbe0b9d82b8be4ee Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 15 Jul 2024 15:11:25 +0200 Subject: [PATCH 4/5] Use num_samples instead of durations in hybrid tools --- src/spikeinterface/generation/hybrid_tools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 8f2ef0ec21..2806754c9d 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -400,6 +400,7 @@ def generate_hybrid_recording( num_segments = recording.get_num_segments() dtype = recording.dtype durations = np.array([recording.get_duration(segment_index) for segment_index in range(num_segments)]) + num_samples = np.array([recording.get_num_samples(segment_index) for segment_index in range(num_segments)]) channel_locations = probe.contact_positions assert ( @@ -548,7 +549,7 @@ def generate_hybrid_recording( displacement_vectors=displacement_vectors, displacement_sampling_frequency=displacement_sampling_frequency, displacement_unit_factor=displacement_unit_factor, - num_samples=(np.array(durations) * sampling_frequency).astype("int64"), + num_samples=num_samples.astype("int64"), amplitude_factor=amplitude_factor, ) From 6b8c540755b0d67f240aa6e0a465c79bddfb67a6 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 15 Jul 2024 14:23:54 +0100 Subject: [PATCH 5/5] Add test for sorting analzyer total duration. --- .../core/tests/test_time_handling.py | 38 ++++++++++++++++--- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index cd329e32b1..1b570091be 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -277,12 +277,38 @@ def test_get_durations(self, time_vector_recording, t_start_recording): assert np.isclose(tstart_recording.get_total_duration(), sum(all_raw_durations), rtol=0, atol=1e-8) assert np.isclose(tvector_recording.get_total_duration(), sum(all_vector_durations), rtol=0, atol=1e-8) - # def test_sorting_analyzer_get_durations(self, time_vector_recording): - # """ """ - # breakpoint() - # sorting = si.generate_sorting() - # sorting_analyzer = si.create_sorting_analyzer(sorting, recording=None) - # si.sorting_an + def test_sorting_analyzer_get_durations_from_recording(self, time_vector_recording): + """ + Test that when a recording is set on `sorting_analyzer`, the + total duration is propagated from the recording to the + `sorting_analyzer.get_total_duration()` function. + """ + _, times_recording, _ = time_vector_recording + + sorting = si.generate_sorting( + durations=[times_recording.get_duration(s) for s in range(times_recording.get_num_segments())] + ) + sorting_analyzer = si.create_sorting_analyzer(sorting, recording=times_recording) + + assert np.array_equal(sorting_analyzer.get_total_duration(), times_recording.get_total_duration()) + + def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording): + """ + Test when the `sorting_analzyer` does not have a recording set, + the total duration is calculated on the fly from num samples and + sampling frequency (thus matching `raw_recording` with no times set + that uses the same method to calculate the total duration). + """ + raw_recording, _, _ = time_vector_recording + + sorting = si.generate_sorting( + durations=[raw_recording.get_duration(s) for s in range(raw_recording.get_num_segments())] + ) + sorting_analyzer = si.create_sorting_analyzer(sorting, recording=raw_recording) + + sorting_analyzer._recording = None + + assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration()) # Helpers #### def _check_times_match(self, recording, all_times):