From eada92a7aae01ddf14b872dcec47ebd50c3d8591 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 19 Apr 2024 11:20:09 +0200 Subject: [PATCH] Fix performance issue for aggregate channels --- .../core/channelsaggregationrecording.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 3b764725a2..9ec7ca8d4d 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -127,7 +127,7 @@ def get_traces( self, start_frame: int | None = None, end_frame: int | None = None, - channel_indices: list | None = None, + channel_indices: list | slice | None = None, ) -> np.ndarray: return_all_channels = False if channel_indices is None: @@ -142,11 +142,17 @@ def get_traces( # in case channel_indices is slice, it has step 1 step = channel_indices.step if channel_indices.step is not None else 1 channel_indices = list(range(channel_indices.start, channel_indices.stop, step)) + recording_id_channels_map = {} for channel_idx in channel_indices: - segment = self._parent_segments[self._channel_map[channel_idx]["recording_id"]] + recording_id = self._channel_map[channel_idx]["recording_id"] channel_index_recording = self._channel_map[channel_idx]["channel_index"] + if recording_id not in recording_id_channels_map: + recording_id_channels_map[recording_id] = [] + recording_id_channels_map[recording_id].append(channel_index_recording) + for recording_id, channel_indices_recording in recording_id_channels_map.items(): + segment = self._parent_segments[recording_id] traces_recording = segment.get_traces( - channel_indices=[channel_index_recording], start_frame=start_frame, end_frame=end_frame + channel_indices=channel_indices_recording, start_frame=start_frame, end_frame=end_frame ) traces.append(traces_recording) else: