Skip to content

Commit

Permalink
Add relative_to to dump_to_pickle (#2141)
Browse files Browse the repository at this point in the history
* Add relative_to to dump_to_pickle

* Add relative_to to pickle
  • Loading branch information
alejoe91 authored Oct 31, 2023
1 parent 886530a commit 24373d6
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 19 deletions.
31 changes: 25 additions & 6 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
import re
from typing import Any, Iterable, List, Optional, Sequence, Union
import importlib
import warnings
Expand Down Expand Up @@ -342,7 +343,7 @@ def to_dict(
kwargs = self._kwargs

if relative_to and not recursive:
raise ValueError("`relative_to` is only posible when `recursive=True`")
raise ValueError("`relative_to` is only possible when `recursive=True`")

if recursive:
to_dict_kwargs = dict(
Expand Down Expand Up @@ -571,7 +572,12 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No
else:
raise ValueError("Dump: file must .json or .pkl")

def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=None, folder_metadata=None) -> None:
def dump_to_json(
self,
file_path: Union[str, Path, None] = None,
relative_to: Union[str, Path, bool, None] = None,
folder_metadata: Union[str, Path, None] = None,
) -> None:
"""
Dump recording extractor to json file.
The extractor can be re-loaded with load_extractor(json_file)
Expand All @@ -584,7 +590,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non
If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder.
This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute.
folder_metadata: str, Path, or None
Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
Folder with files containing additional information (e.g. probe in BaseRecording) and properties
"""
assert self.check_serializablility("json"), "The extractor is not json serializable"

Expand All @@ -610,8 +616,9 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non
def dump_to_pickle(
self,
file_path: Union[str, Path, None] = None,
relative_to: Union[str, Path, bool, None] = None,
include_properties: bool = True,
folder_metadata=None,
folder_metadata: Union[str, Path, None] = None,
):
"""
Dump recording extractor to a pickle file.
Expand All @@ -621,19 +628,31 @@ def dump_to_pickle(
----------
file_path: str
Path of the pickle file
relative_to: str, Path, True or None
If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder.
This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute.
include_properties: bool
If True, all properties are dumped
folder_metadata: str, Path, or None
Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
"""
assert self.check_if_pickle_serializable(), "The extractor is not serializable to file with pickle"

# Writing paths as relative_to requires recursively expanding the dict
if relative_to:
relative_to = Path(file_path).parent if relative_to is True else Path(relative_to)
relative_to = relative_to.resolve().absolute()
# if relative_to is used, the dictionaru needs recursive expansion
recursive = True
else:
recursive = False

dump_dict = self.to_dict(
include_annotations=True,
include_properties=include_properties,
folder_metadata=folder_metadata,
relative_to=None,
recursive=False,
relative_to=relative_to,
recursive=recursive,
)
file_path = self._get_file_path(file_path, [".pkl", ".pickle"])

Expand Down
21 changes: 19 additions & 2 deletions src/spikeinterface/core/tests/test_baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
but check only for BaseRecording general methods.
"""
import json
import shutil
import pickle
from pathlib import Path
import pytest
import numpy as np
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_BaseRecording():
rec2 = BaseExtractor.from_dict(d, base_folder=cache_folder)
rec3 = load_extractor(d, base_folder=cache_folder)

# dump/load json
# dump/load json - relative to
rec.dump_to_json(cache_folder / "test_BaseRecording_rel.json", relative_to=cache_folder)
rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder)
rec3 = load_extractor(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder)
Expand All @@ -128,6 +128,23 @@ def test_BaseRecording():
"/" not in data["kwargs"]["file_paths"][0]
) # Relative to parent folder, so there shouldn't be any '/' in the path.

# dump/load pkl - relative to
rec.dump_to_pickle(cache_folder / "test_BaseRecording_rel.pkl", relative_to=cache_folder)
rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel.pkl", base_folder=cache_folder)
rec3 = load_extractor(cache_folder / "test_BaseRecording_rel.pkl", base_folder=cache_folder)

# dump/load relative=True
rec.dump_to_pickle(cache_folder / "test_BaseRecording_rel_true.pkl", relative_to=True)
rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel_true.pkl", base_folder=True)
rec3 = load_extractor(cache_folder / "test_BaseRecording_rel_true.pkl", base_folder=True)
check_recordings_equal(rec, rec2, return_scaled=False, check_annotations=True)
check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True)
with open(cache_folder / "test_BaseRecording_rel_true.pkl", "rb") as pkl_file:
data = pickle.load(pkl_file)
assert (
"/" not in data["kwargs"]["file_paths"][0]
) # Relative to parent folder, so there shouldn't be any '/' in the path.

# cache to binary
folder = cache_folder / "simple_recording"
rec.save(format="binary", folder=folder)
Expand Down
15 changes: 6 additions & 9 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def load_from_folder(
pass
elif (folder / "recording.pickle").exists():
try:
recording = load_extractor(folder / "recording.pickle")
recording = load_extractor(folder / "recording.pickle", base_folder=folder)
except:
pass
if recording is None:
Expand All @@ -177,7 +177,7 @@ def load_from_folder(
if (folder / "sorting.json").exists():
sorting = load_extractor(folder / "sorting.json", base_folder=folder)
elif (folder / "sorting.pickle").exists():
sorting = load_extractor(folder / "sorting.pickle")
sorting = load_extractor(folder / "sorting.pickle", base_folder=folder)
else:
raise FileNotFoundError("load_waveforms() impossible to find the sorting object (json or pickle)")

Expand Down Expand Up @@ -287,15 +287,12 @@ def create(
if recording.check_serializablility("json"):
recording.dump(folder / "recording.json", relative_to=relative_to)
elif recording.check_serializablility("pickle"):
# In this case we loose the relative_to!!
recording.dump(folder / "recording.pickle")
recording.dump(folder / "recording.pickle", relative_to=relative_to)

if sorting.check_serializablility("json"):
sorting.dump(folder / "sorting.json", relative_to=relative_to)
elif sorting.check_serializablility("pickle"):
# In this case we loose the relative_to!!
# TODO later the dump to pickle should dump the dictionary and so relative could be put back
sorting.dump(folder / "sorting.pickle")
sorting.dump(folder / "sorting.pickle", relative_to=relative_to)
else:
warn(
"Sorting object is not serializable to file, which might result in downstream errors for "
Expand Down Expand Up @@ -918,12 +915,12 @@ def save(
if self.recording.check_serializablility("json"):
self.recording.dump(folder / "recording.json", relative_to=relative_to)
elif self.recording.check_serializablility("pickle"):
self.recording.dump(folder / "recording.pickle")
self.recording.dump(folder / "recording.pickle", relative_to=relative_to)

if self.sorting.check_serializablility("json"):
self.sorting.dump(folder / "sorting.json", relative_to=relative_to)
elif self.sorting.check_serializablility("pickle"):
self.sorting.dump(folder / "sorting.pickle")
self.sorting.dump(folder / "sorting.pickle", relative_to=relative_to)
else:
warn(
"Sorting object is not serializable to file, which might result in downstream errors for "
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/sorters/basesorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,9 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo
if recording.check_serializablility("json"):
recording.dump(rec_file, relative_to=output_folder)
elif recording.check_serializablility("pickle"):
recording.dump(output_folder / "spikeinterface_recording.pickle")
recording.dump(output_folder / "spikeinterface_recording.pickle", relative_to=output_folder)
else:
# TODO: deprecate and finally remove this after 0.100
d = {"warning": "The recording is not serializable to json"}
rec_file.write_text(json.dumps(d, indent=4), encoding="utf8")

Expand Down Expand Up @@ -203,7 +204,7 @@ def load_recording_from_folder(cls, output_folder, with_warnings=False):
else:
recording = load_extractor(json_file, base_folder=output_folder)
elif pickle_file.exists():
recording = load_extractor(pickle_file)
recording = load_extractor(pickle_file, base_folder=output_folder)

return recording

Expand Down

0 comments on commit 24373d6

Please sign in to comment.