Skip to content

Commit

Permalink
Merge branch 'main' into spike_retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 authored Sep 1, 2023
2 parents b2e737e + ee2237b commit aa09cc3
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,24 @@ def count_num_spikes_per_unit(self):
Dictionary with unit_ids as key and number of spikes as values
"""
num_spikes = {}
for unit_id in self.unit_ids:
n = 0
for segment_index in range(self.get_num_segments()):
st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
n += st.size
num_spikes[unit_id] = n

if self._cached_spike_trains is not None:
for unit_id in self.unit_ids:
n = 0
for segment_index in range(self.get_num_segments()):
st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
n += st.size
num_spikes[unit_id] = n
else:
spike_vector = self.to_spike_vector()
unit_indices, counts = np.unique(spike_vector["unit_index"], return_counts=True)
for unit_index, unit_id in enumerate(self.unit_ids):
if unit_index in unit_indices:
idx = np.argmax(unit_indices == unit_index)
num_spikes[unit_id] = counts[idx]
else: # This unit has no spikes, hence it's not in the counts array.
num_spikes[unit_id] = 0

return num_spikes

def count_total_num_spikes(self):
Expand Down

0 comments on commit aa09cc3

Please sign in to comment.