Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 28, 2023
1 parent 00f91eb commit 7ba84ad
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 128 deletions.
24 changes: 20 additions & 4 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,14 +811,30 @@ def select_units(self, unit_ids, new_folder=None, use_relative_path: bool = Fals
sparsity = ChannelSparsity(mask, unit_ids, self.channel_ids)
else:
sparsity = None
we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity)
we.set_params(**self._params)
if self.has_recording():
we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity)
else:
we = WaveformExtractor(
recording=None,
sorting=sorting,
folder=None,
sparsity=sparsity,
rec_attributes=self._rec_attributes,
allow_unfiltered=True,
)
we._params = self._params
# copy memory objects
if self.has_waveforms():
we._memory_objects = {"wfs_arrays": {}, "sampled_indices": {}}
for unit_id in unit_ids:
we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id]
we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][unit_id]
if self.format == "memory":
we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id]
we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][
unit_id
]
else:
we._memory_objects["wfs_arrays"][unit_id] = self.get_waveforms(unit_id)
we._memory_objects["sampled_indices"][unit_id] = self.get_sampled_indices(unit_id)

# finally select extensions data
for ext_name in self.get_available_extension_names():
Expand Down
Loading

0 comments on commit 7ba84ad

Please sign in to comment.