diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 56f46f0a38..52f71c2399 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -278,12 +278,24 @@ def count_num_spikes_per_unit(self): Dictionary with unit_ids as key and number of spikes as values """ num_spikes = {} - for unit_id in self.unit_ids: - n = 0 - for segment_index in range(self.get_num_segments()): - st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - n += st.size - num_spikes[unit_id] = n + + if self._cached_spike_trains is not None: + for unit_id in self.unit_ids: + n = 0 + for segment_index in range(self.get_num_segments()): + st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + n += st.size + num_spikes[unit_id] = n + else: + spike_vector = self.to_spike_vector() + unit_indices, counts = np.unique(spike_vector["unit_index"], return_counts=True) + for unit_index, unit_id in enumerate(self.unit_ids): + if unit_index in unit_indices: + idx = np.argmax(unit_indices == unit_index) + num_spikes[unit_id] = counts[idx] + else: # This unit has no spikes, hence it's not in the counts array. + num_spikes[unit_id] = 0 + return num_spikes def count_total_num_spikes(self):