Skip to content

Commit

Permalink
Fix serializability of InjectDriftingTemplatesRecording
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Jul 3, 2024
1 parent 2af38a3 commit f46f9c9
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import weakref
import json
import pickle
import os
import random
import string
from packaging.version import parse
Expand Down Expand Up @@ -928,13 +927,14 @@ def save_to_folder(
folder.mkdir(parents=True, exist_ok=False)

# dump provenance
provenance_file = folder / f"provenance.json"
if self.check_serializability("json"):
provenance_file = folder / f"provenance.json"
self.dump(provenance_file)
elif self.check_serializability("pickle"):
provenance_file = folder / f"provenance.pkl"
self.dump(provenance_file)
else:
provenance_file.write_text(
json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8"
)
warnings.warn("The extractor is not serializable to file. The provenance will not be saved.")

self.save_metadata_to_folder(folder)

Expand Down Expand Up @@ -1001,7 +1001,6 @@ def save_to_zarr(
cached: ZarrExtractor
Saved copy of the extractor.
"""
import zarr
from .zarrextractors import read_zarr

save_kwargs.pop("format", None)
Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1738,6 +1738,8 @@ def __init__(
)
self.add_recording_segment(recording_segment)

# to discuss: maybe we could set json serializability to False always
# because templates could be large!
if not sorting.check_serializability("json"):
self._serializability["json"] = False
if parent_recording is not None:
Expand Down
3 changes: 3 additions & 0 deletions src/spikeinterface/generation/drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,9 @@ def __init__(

self.set_probe(drifting_templates.probe, in_place=True)

# templates are too large, we don't serialize them to JSON
self._serializability["json"] = False

self._kwargs = {
"sorting": sorting,
"drifting_templates": drifting_templates,
Expand Down

0 comments on commit f46f9c9

Please sign in to comment.