diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 2313e7d253..02f4529a98 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -47,7 +47,7 @@ def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.arra return spike_trains -def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array): +def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, absolute_index: bool = False): """ Similar to spike_vector_to_spike_trains but instead having the spike_trains (aka spike times) return spike indices by segment and units. @@ -61,6 +61,11 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array): List of spike vectors optained with sorting.to_spike_vector(concatenated=False) unit_ids: np.array Unit ids + absolute_index: bool, default False + It True, return absolute spike indices. If False, spike indices are relative to the segment. + When a unique spike vector is used, then absolute_index should be True. + When a list of spikes per segment is used, then absolute_index should be False. + Returns ------- spike_indices: dict[dict]: @@ -82,10 +87,16 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array): num_units = unit_ids.size spike_indices = {} + + total_spikes = 0 for segment_index, spikes in enumerate(spike_vector): indices = np.arange(spikes.size, dtype=np.int64) + if absolute_index: + indices += total_spikes + total_spikes += spikes.size unit_indices = np.array(spikes["unit_index"]).astype(np.int64, copy=False) list_of_spike_indices = vector_to_list_of_spiketrain(indices, unit_indices, num_units) + spike_indices[segment_index] = dict(zip(unit_ids, list_of_spike_indices)) return spike_indices diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 09b46362e5..aebfd1fd78 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -127,7 +127,7 @@ def _get_data(self, outputs="numpy"): elif outputs == "by_unit": unit_ids = self.sorting_analyzer.unit_ids spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids) + spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) amplitudes_by_units = {} for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): amplitudes_by_units[segment_index] = {} diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 96e01a68c4..52a91342b6 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -138,7 +138,7 @@ def _get_data(self, outputs="numpy"): elif outputs == "by_unit": unit_ids = self.sorting_analyzer.unit_ids spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids) + spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) spike_locations_by_units = {} for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): spike_locations_by_units[segment_index] = {}