Skip to content

Commit

Permalink
Merge pull request #2559 from zm711/zarr-json-serializability
Browse files Browse the repository at this point in the history
Add `check_json` to `zarr` for `SortingAnalyzer` for sorting_provenance file writing
  • Loading branch information
alejoe91 authored Mar 11, 2024
2 parents e0a7ad0 + 6d2ffe3 commit 6d382a2
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@
# SortingAnalyzer and AnalyzerExtension
from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, create_sorting_analyzer, load_sorting_analyzer
from .analyzer_extension_core import (
SelectRandomSpikes,
compute_select_random_spikes,
ComputeWaveforms,
compute_waveforms,
ComputeTemplates,
Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""
Implement AnalyzerExtension that are essential and imported in core
* SelectRandomSpikes
* ComputeWaveforms
* ComputeTemplates
Theses two classes replace the WaveformExtractor
It also implement:
* ComputeFastTemplates which is equivalent but without extacting waveforms.
* ComputeNoiseLevels which is very convinient to have
* ComputeNoiseLevels which is very convenient to have
"""

import numpy as np
Expand Down Expand Up @@ -105,6 +106,7 @@ def get_selected_indices_in_spike_train(self, unit_id, segment_index):
return selected_spikes_in_spike_train


compute_select_random_spikes = SelectRandomSpikes.function_factory()
register_result_extension(SelectRandomSpikes)


Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes):
# sorting provenance
sort_dict = sorting.to_dict(relative_to=folder, recursive=True)
if sorting.check_serializability("json"):
zarr_sort = np.array([sort_dict], dtype=object)
zarr_sort = np.array([check_json(sort_dict)], dtype=object)
zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.JSON())
elif sorting.check_serializability("pickle"):
zarr_sort = np.array([sort_dict], dtype=object)
Expand Down Expand Up @@ -507,7 +507,7 @@ def load_from_zarr(cls, folder, recording=None):
# sparsity
if "sparsity_mask" in zarr_root.attrs:
# sparsity = zarr_root.attrs["sparsity"]
sparsity = ChannelSparsity(zarr_root["sparsity_mask"], self.unit_ids, rec_attributes["channel_ids"])
sparsity = ChannelSparsity(zarr_root["sparsity_mask"], cls.unit_ids, rec_attributes["channel_ids"])
else:
sparsity = None

Expand Down

0 comments on commit 6d382a2

Please sign in to comment.