Skip to content

Commit

Permalink
Merge pull request #3048 from samuelgarcia/fix_vector_indices
Browse files Browse the repository at this point in the history
fix spike_vector_to_indices()
  • Loading branch information
alejoe91 authored Jun 21, 2024
2 parents 44d8e33 + 391db33 commit 2dc0b74
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
13 changes: 12 additions & 1 deletion src/spikeinterface/core/sorting_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.arra
return spike_trains


def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array):
def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, absolute_index: bool = False):
"""
Similar to spike_vector_to_spike_trains but instead having the spike_trains (aka spike times) return
spike indices by segment and units.
Expand All @@ -61,6 +61,11 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array):
List of spike vectors optained with sorting.to_spike_vector(concatenated=False)
unit_ids: np.array
Unit ids
absolute_index: bool, default False
It True, return absolute spike indices. If False, spike indices are relative to the segment.
When a unique spike vector is used, then absolute_index should be True.
When a list of spikes per segment is used, then absolute_index should be False.
Returns
-------
spike_indices: dict[dict]:
Expand All @@ -82,10 +87,16 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array):

num_units = unit_ids.size
spike_indices = {}

total_spikes = 0
for segment_index, spikes in enumerate(spike_vector):
indices = np.arange(spikes.size, dtype=np.int64)
if absolute_index:
indices += total_spikes
total_spikes += spikes.size
unit_indices = np.array(spikes["unit_index"]).astype(np.int64, copy=False)
list_of_spike_indices = vector_to_list_of_spiketrain(indices, unit_indices, num_units)

spike_indices[segment_index] = dict(zip(unit_ids, list_of_spike_indices))

return spike_indices
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/spike_amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _get_data(self, outputs="numpy"):
elif outputs == "by_unit":
unit_ids = self.sorting_analyzer.unit_ids
spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
spike_indices = spike_vector_to_indices(spike_vector, unit_ids)
spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
amplitudes_by_units = {}
for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
amplitudes_by_units[segment_index] = {}
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/spike_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _get_data(self, outputs="numpy"):
elif outputs == "by_unit":
unit_ids = self.sorting_analyzer.unit_ids
spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
spike_indices = spike_vector_to_indices(spike_vector, unit_ids)
spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
spike_locations_by_units = {}
for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
spike_locations_by_units[segment_index] = {}
Expand Down

0 comments on commit 2dc0b74

Please sign in to comment.