Commit

Merge pull request #3433 from alejoe91/postprocessing-tmp-recording
Fix compute analyzer pipeline with tmp recording
alejoe91 authored Sep 23, 2024
2 parents a874f2a + a071605 commit 50eadce
Showing 2 changed files with 22 additions and 19 deletions.
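This fix matters for the workflow where a SortingAnalyzer is loaded without its recording and the traces are re-attached only in memory. A minimal sketch of that workflow, assuming spikeinterface's load_sorting_analyzer / set_temporary_recording API; the folder paths and extension names are illustrative, not from this commit:

import spikeinterface.full as si

# load an analyzer that was saved without its recording (illustrative path)
analyzer = si.load_sorting_analyzer("analyzer_folder")

# re-attach the preprocessed traces in memory only; nothing is persisted
recording = si.load_extractor("preprocessed_recording_folder")
analyzer.set_temporary_recording(recording)

# extensions that need traces, including node-pipeline ones, can now be computed
analyzer.compute(["random_spikes", "waveforms", "templates"])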
24 changes: 14 additions & 10 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -232,6 +232,8 @@ def __repr__(self) -> str:
             txt += " - sparse"
         if self.has_recording():
             txt += " - has recording"
+        if self.has_temporary_recording():
+            txt += " - has temporary recording"
         ext_txt = f"Loaded {len(self.extensions)} extensions: " + ", ".join(self.extensions.keys())
         txt += "\n" + ext_txt
         return txt
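With the two added lines, the repr also reports an attached temporary recording. An illustrative (not verbatim) output, assuming an in-memory analyzer with three extensions loaded:

print(analyzer)
# SortingAnalyzer: 384 channels - 100 units - 1 segments - memory - sparse - has temporary recording
# Loaded 3 extensions: random_spikes, waveforms, templates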
@@ -350,7 +352,7 @@ def create_memory(cls, sorting, recording, sparsity, return_scaled, rec_attribute
     def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes):
         # used by create and save_as

-        assert recording is not None, "To create a SortingAnalyzer you need recording not None"
+        assert recording is not None, "To create a SortingAnalyzer you need to specify the recording"

         folder = Path(folder)
         if folder.is_dir():
@@ -1221,7 +1223,7 @@ def compute(self, input, save=True, extension_params=None, verbose=False, **kwar
                     extensions[ext_name] = ext_params
             self.compute_several_extensions(extensions=extensions, save=save, verbose=verbose, **job_kwargs)
         else:
-            raise ValueError("SortingAnalyzer.compute() need str, dict or list")
+            raise ValueError("SortingAnalyzer.compute() needs a str, dict or list")

     def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs) -> "AnalyzerExtension":
         """
@@ -1355,7 +1357,9 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job

         for extension_name, extension_params in extensions_with_pipeline.items():
             extension_class = get_extension_class(extension_name)
-            assert self.has_recording(), f"Extension {extension_name} need the recording"
+            assert (
+                self.has_recording() or self.has_temporary_recording()
+            ), f"Extension {extension_name} requires the recording"

             for variable_name in extension_class.nodepipeline_variables:
                 result_routage.append((extension_name, variable_name))
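The relaxed assertion means a node-pipeline extension only needs traces to be reachable, whether through the saved recording or a temporary one. A hedged illustration of the calling pattern (has_recording, has_temporary_recording and set_temporary_recording are the methods used in this codebase; the recording object and extension name are illustrative):

if not (analyzer.has_recording() or analyzer.has_temporary_recording()):
    analyzer.set_temporary_recording(recording)  # attach traces without persisting them
analyzer.compute("spike_amplitudes")  # a node-pipeline extension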
@@ -1603,17 +1607,17 @@ def _sort_extensions_by_dependency(extensions):
 def _get_children_dependencies(extension_name):
     """
     Extension classes have a `depend_on` attribute to declare on which class they
-    depend. For instance "templates" depend on "waveforms". "waveforms depends on "random_spikes".
+    depend on. For instance "templates" depends on "waveforms". "waveforms" depends on "random_spikes".

-    This function is making the reverse way : get all children that depend of a
+    This function is going the opposite way: it finds all children that depend on a
     particular extension.

-    This is recursive so this includes : children and so grand children and great grand children
+    The implementation is recursive so that the output includes children, grand children, great grand children, etc.

-    This function is usefull for deleting on recompute.
-    For instance recompute the "waveforms" need to delete "template"
-    This make sens if "ms_before" is change in "waveforms" because the template also depends
-    on this parameters.
+    This function is useful for deleting existing extensions on recompute.
+    For instance, recomputing the "waveforms" needs to delete the "templates", since the latter depends on the former.
+    For this particular example, if we change the "ms_before" parameter of the "waveforms", also the "templates" will
+    require recomputation as this parameter is inherited.
     """
     names = []
     children = _extension_children[extension_name]
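The body of _get_children_dependencies is only partially visible above (the hunk is truncated after the first two statements). A minimal sketch of how such a recursive child lookup can work, assuming a module-level _extension_children dict that maps each extension name to the names of its direct children:

def _get_children_dependencies(extension_name):
    # collect every extension that directly or transitively depends on `extension_name`
    names = []
    children = _extension_children[extension_name]
    for child in children:
        names.append(child)
        # recurse so grandchildren, great-grandchildren, etc. are included as well
        names.extend(_get_children_dependencies(child))
    return list(names)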
17 changes: 8 additions & 9 deletions src/spikeinterface/postprocessing/principal_component.py
@@ -359,12 +359,12 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):

         job_kwargs = fix_job_kwargs(job_kwargs)
         p = self.params
-        we = self.sorting_analyzer
-        sorting = we.sorting
+        sorting_analyzer = self.sorting_analyzer
+        sorting = sorting_analyzer.sorting
         assert (
-            we.has_recording()
-        ), "To compute PCA projections for all spikes, the waveform extractor needs the recording"
-        recording = we.recording
+            sorting_analyzer.has_recording() or sorting_analyzer.has_temporary_recording()
+        ), "To compute PCA projections for all spikes, the sorting analyzer needs the recording"
+        recording = sorting_analyzer.recording

         # assert sorting.get_num_segments() == 1
         assert p["mode"] in ("by_channel_local", "by_channel_global")
@@ -374,8 +374,9 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):

         sparsity = self.sorting_analyzer.sparsity
         if sparsity is None:
-            sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids}
-            max_channels_per_template = we.get_num_channels()
+            num_channels = recording.get_num_channels()
+            sparse_channels_indices = {unit_id: np.arange(num_channels) for unit_id in sorting_analyzer.unit_ids}
+            max_channels_per_template = num_channels
         else:
             sparse_channels_indices = sparsity.unit_id_to_channel_indices
             max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()])
@@ -449,9 +450,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
         return pca_models

     def _fit_by_channel_global(self, progress_bar):
-        # we = self.sorting_analyzer
         p = self.params
-        # unit_ids = we.unit_ids
         unit_ids = self.sorting_analyzer.unit_ids

         # there is one unique PCA accross channels
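As a usage note, the run_for_all_spikes path touched above can now be driven from an analyzer that only holds a temporary recording. A rough sketch (the extension name and the run_for_all_spikes signature follow this file; the output path and job arguments are illustrative):

# compute PCA on the sampled waveforms first
analyzer.compute(["random_spikes", "waveforms", "principal_components"])

# then project all spikes (not only the sampled ones) and write them to disk
pca_ext = analyzer.get_extension("principal_components")
pca_ext.run_for_all_spikes(file_path="all_pcs.npy", n_jobs=4, chunk_duration="1s")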
