Skip to content

Commit

Permalink
Merge pull request #1895 from DradeAW/faster_count_spikes
Browse files Browse the repository at this point in the history
Use spike_vector in `count_num_spikes_per_unit`
  • Loading branch information
samuelgarcia authored Sep 1, 2023
2 parents 2e549a9 + de925f3 commit ee2237b
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,24 @@ def count_num_spikes_per_unit(self):
Dictionary with unit_ids as key and number of spikes as values
"""
num_spikes = {}
for unit_id in self.unit_ids:
n = 0
for segment_index in range(self.get_num_segments()):
st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
n += st.size
num_spikes[unit_id] = n

if self._cached_spike_trains is not None:
for unit_id in self.unit_ids:
n = 0
for segment_index in range(self.get_num_segments()):
st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
n += st.size
num_spikes[unit_id] = n
else:
spike_vector = self.to_spike_vector()
unit_indices, counts = np.unique(spike_vector["unit_index"], return_counts=True)
for unit_index, unit_id in enumerate(self.unit_ids):
if unit_index in unit_indices:
idx = np.argmax(unit_indices == unit_index)
num_spikes[unit_id] = counts[idx]
else: # This unit has no spikes, hence it's not in the counts array.
num_spikes[unit_id] = 0

return num_spikes

def count_total_num_spikes(self):
Expand Down

0 comments on commit ee2237b

Please sign in to comment.