From c4c4ebb3c23cfa7cccec9b723b412b9f2c2c2e3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Fri, 28 Jul 2023 15:35:02 +0200 Subject: [PATCH 1/2] Use spike_vector in `count_num_spikes_per_unit` --- src/spikeinterface/core/basesorting.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 56f46f0a38..b411ef5505 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -278,12 +278,24 @@ def count_num_spikes_per_unit(self): Dictionary with unit_ids as key and number of spikes as values """ num_spikes = {} - for unit_id in self.unit_ids: - n = 0 - for segment_index in range(self.get_num_segments()): - st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - n += st.size - num_spikes[unit_id] = n + + if self._cached_spike_trains is not None: + for unit_id in self.unit_ids: + n = 0 + for segment_index in range(self.get_num_segments()): + st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + n += st.size + num_spikes[unit_id] = n + else: + spike_vector = self.to_spike_vector() + unit_indices, counts = np.unique(spike_vector["unit_index"], return_counts=True) + for unit_index, unit_id in enumerate(self.unit_ids): + if unit_index in unit_indices: + idx = np.argmax(unit_indecex == unit_index) + num_spikes[unit_id] = counts[idx] + else: # This unit has no spikes, hence it's not in the counts array. + num_spikes[unit_id] = 0 + return num_spikes def count_total_num_spikes(self): From f74046b713d85af87afca7af66428c0156571507 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Fri, 28 Jul 2023 16:19:15 +0200 Subject: [PATCH 2/2] Typo --- src/spikeinterface/core/basesorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index b411ef5505..52f71c2399 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -291,7 +291,7 @@ def count_num_spikes_per_unit(self): unit_indices, counts = np.unique(spike_vector["unit_index"], return_counts=True) for unit_index, unit_id in enumerate(self.unit_ids): if unit_index in unit_indices: - idx = np.argmax(unit_indecex == unit_index) + idx = np.argmax(unit_indices == unit_index) num_spikes[unit_id] = counts[idx] else: # This unit has no spikes, hence it's not in the counts array. num_spikes[unit_id] = 0