Skip to content

Commit

Permalink
Merge pull request #2736 from alejoe91/fix-aggregate-channels
Browse files Browse the repository at this point in the history
Fix performance issue for `aggregate_channels`
  • Loading branch information
alejoe91 authored Apr 19, 2024
2 parents 9811b8d + eada92a commit a823e15
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/spikeinterface/core/channelsaggregationrecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_traces(
self,
start_frame: int | None = None,
end_frame: int | None = None,
channel_indices: list | None = None,
channel_indices: list | slice | None = None,
) -> np.ndarray:
return_all_channels = False
if channel_indices is None:
Expand All @@ -142,11 +142,17 @@ def get_traces(
# in case channel_indices is slice, it has step 1
step = channel_indices.step if channel_indices.step is not None else 1
channel_indices = list(range(channel_indices.start, channel_indices.stop, step))
recording_id_channels_map = {}
for channel_idx in channel_indices:
segment = self._parent_segments[self._channel_map[channel_idx]["recording_id"]]
recording_id = self._channel_map[channel_idx]["recording_id"]
channel_index_recording = self._channel_map[channel_idx]["channel_index"]
if recording_id not in recording_id_channels_map:
recording_id_channels_map[recording_id] = []
recording_id_channels_map[recording_id].append(channel_index_recording)
for recording_id, channel_indices_recording in recording_id_channels_map.items():
segment = self._parent_segments[recording_id]
traces_recording = segment.get_traces(
channel_indices=[channel_index_recording], start_frame=start_frame, end_frame=end_frame
channel_indices=channel_indices_recording, start_frame=start_frame, end_frame=end_frame
)
traces.append(traces_recording)
else:
Expand Down

0 comments on commit a823e15

Please sign in to comment.