Skip to content

Commit

Permalink
Merge pull request #1993 from alejoe91/extended-template-metrics
Browse files Browse the repository at this point in the history
Extend and refactor waveform metrics
  • Loading branch information
samuelgarcia authored Oct 16, 2023
2 parents 8a5f544 + 3a1f540 commit 3f11037
Show file tree
Hide file tree
Showing 4 changed files with 712 additions and 77 deletions.
27 changes: 23 additions & 4 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,14 +833,30 @@ def select_units(self, unit_ids, new_folder=None, use_relative_path: bool = Fals
sparsity = ChannelSparsity(mask, unit_ids, self.channel_ids)
else:
sparsity = None
we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity)
we.set_params(**self._params)
if self.has_recording():
we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity)
else:
we = WaveformExtractor(
recording=None,
sorting=sorting,
folder=None,
sparsity=sparsity,
rec_attributes=self._rec_attributes,
allow_unfiltered=True,
)
we._params = self._params
# copy memory objects
if self.has_waveforms():
we._memory_objects = {"wfs_arrays": {}, "sampled_indices": {}}
for unit_id in unit_ids:
we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id]
we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][unit_id]
if self.format == "memory":
we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id]
we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][
unit_id
]
else:
we._memory_objects["wfs_arrays"][unit_id] = self.get_waveforms(unit_id)
we._memory_objects["sampled_indices"][unit_id] = self.get_sampled_indices(unit_id)

# finally select extensions data
for ext_name in self.get_available_extension_names():
Expand Down Expand Up @@ -2016,6 +2032,9 @@ def set_params(self, **params):
params = self._set_params(**params)
self._params = params

if self.waveform_extractor.is_read_only():
return

params_to_save = params.copy()
if "sparsity" in params and params["sparsity"] is not None:
assert isinstance(
Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/postprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .template_metrics import (
TemplateMetricsCalculator,
compute_template_metrics,
calculate_template_metrics,
get_template_metric_names,
)

Expand Down
Loading

0 comments on commit 3f11037

Please sign in to comment.