From 24373d65fed14f62201d9a7680c7d54982bb6c7d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 31 Oct 2023 18:31:14 +0100 Subject: [PATCH] Add relative_to to dump_to_pickle (#2141) * Add relative_to to dump_to_pickle * Add relative_to to pickle --- src/spikeinterface/core/base.py | 31 +++++++++++++++---- .../core/tests/test_baserecording.py | 21 +++++++++++-- src/spikeinterface/core/waveform_extractor.py | 15 ++++----- src/spikeinterface/sorters/basesorter.py | 5 +-- 4 files changed, 53 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index f188ce7aa6..b51bace55f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -1,4 +1,5 @@ from pathlib import Path +import re from typing import Any, Iterable, List, Optional, Sequence, Union import importlib import warnings @@ -342,7 +343,7 @@ def to_dict( kwargs = self._kwargs if relative_to and not recursive: - raise ValueError("`relative_to` is only posible when `recursive=True`") + raise ValueError("`relative_to` is only possible when `recursive=True`") if recursive: to_dict_kwargs = dict( @@ -571,7 +572,12 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No else: raise ValueError("Dump: file must .json or .pkl") - def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=None, folder_metadata=None) -> None: + def dump_to_json( + self, + file_path: Union[str, Path, None] = None, + relative_to: Union[str, Path, bool, None] = None, + folder_metadata: Union[str, Path, None] = None, + ) -> None: """ Dump recording extractor to json file. The extractor can be re-loaded with load_extractor(json_file) @@ -584,7 +590,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder. This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. folder_metadata: str, Path, or None - Folder with files containing additional information (e.g. probe in BaseRecording) and properties. + Folder with files containing additional information (e.g. probe in BaseRecording) and properties """ assert self.check_serializablility("json"), "The extractor is not json serializable" @@ -610,8 +616,9 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non def dump_to_pickle( self, file_path: Union[str, Path, None] = None, + relative_to: Union[str, Path, bool, None] = None, include_properties: bool = True, - folder_metadata=None, + folder_metadata: Union[str, Path, None] = None, ): """ Dump recording extractor to a pickle file. @@ -621,6 +628,9 @@ def dump_to_pickle( ---------- file_path: str Path of the pickle file + relative_to: str, Path, True or None + If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder. + This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. include_properties: bool If True, all properties are dumped folder_metadata: str, Path, or None @@ -628,12 +638,21 @@ def dump_to_pickle( """ assert self.check_if_pickle_serializable(), "The extractor is not serializable to file with pickle" + # Writing paths as relative_to requires recursively expanding the dict + if relative_to: + relative_to = Path(file_path).parent if relative_to is True else Path(relative_to) + relative_to = relative_to.resolve().absolute() + # if relative_to is used, the dictionaru needs recursive expansion + recursive = True + else: + recursive = False + dump_dict = self.to_dict( include_annotations=True, include_properties=include_properties, folder_metadata=folder_metadata, - relative_to=None, - recursive=False, + relative_to=relative_to, + recursive=recursive, ) file_path = self._get_file_path(file_path, [".pkl", ".pickle"]) diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 38987a58e5..4326cd15aa 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -3,7 +3,7 @@ but check only for BaseRecording general methods. """ import json -import shutil +import pickle from pathlib import Path import pytest import numpy as np @@ -111,7 +111,7 @@ def test_BaseRecording(): rec2 = BaseExtractor.from_dict(d, base_folder=cache_folder) rec3 = load_extractor(d, base_folder=cache_folder) - # dump/load json + # dump/load json - relative to rec.dump_to_json(cache_folder / "test_BaseRecording_rel.json", relative_to=cache_folder) rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder) rec3 = load_extractor(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder) @@ -128,6 +128,23 @@ def test_BaseRecording(): "/" not in data["kwargs"]["file_paths"][0] ) # Relative to parent folder, so there shouldn't be any '/' in the path. + # dump/load pkl - relative to + rec.dump_to_pickle(cache_folder / "test_BaseRecording_rel.pkl", relative_to=cache_folder) + rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel.pkl", base_folder=cache_folder) + rec3 = load_extractor(cache_folder / "test_BaseRecording_rel.pkl", base_folder=cache_folder) + + # dump/load relative=True + rec.dump_to_pickle(cache_folder / "test_BaseRecording_rel_true.pkl", relative_to=True) + rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel_true.pkl", base_folder=True) + rec3 = load_extractor(cache_folder / "test_BaseRecording_rel_true.pkl", base_folder=True) + check_recordings_equal(rec, rec2, return_scaled=False, check_annotations=True) + check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True) + with open(cache_folder / "test_BaseRecording_rel_true.pkl", "rb") as pkl_file: + data = pickle.load(pkl_file) + assert ( + "/" not in data["kwargs"]["file_paths"][0] + ) # Relative to parent folder, so there shouldn't be any '/' in the path. + # cache to binary folder = cache_folder / "simple_recording" rec.save(format="binary", folder=folder) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 1ce15c3f72..a2b58daa24 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -166,7 +166,7 @@ def load_from_folder( pass elif (folder / "recording.pickle").exists(): try: - recording = load_extractor(folder / "recording.pickle") + recording = load_extractor(folder / "recording.pickle", base_folder=folder) except: pass if recording is None: @@ -177,7 +177,7 @@ def load_from_folder( if (folder / "sorting.json").exists(): sorting = load_extractor(folder / "sorting.json", base_folder=folder) elif (folder / "sorting.pickle").exists(): - sorting = load_extractor(folder / "sorting.pickle") + sorting = load_extractor(folder / "sorting.pickle", base_folder=folder) else: raise FileNotFoundError("load_waveforms() impossible to find the sorting object (json or pickle)") @@ -287,15 +287,12 @@ def create( if recording.check_serializablility("json"): recording.dump(folder / "recording.json", relative_to=relative_to) elif recording.check_serializablility("pickle"): - # In this case we loose the relative_to!! - recording.dump(folder / "recording.pickle") + recording.dump(folder / "recording.pickle", relative_to=relative_to) if sorting.check_serializablility("json"): sorting.dump(folder / "sorting.json", relative_to=relative_to) elif sorting.check_serializablility("pickle"): - # In this case we loose the relative_to!! - # TODO later the dump to pickle should dump the dictionary and so relative could be put back - sorting.dump(folder / "sorting.pickle") + sorting.dump(folder / "sorting.pickle", relative_to=relative_to) else: warn( "Sorting object is not serializable to file, which might result in downstream errors for " @@ -918,12 +915,12 @@ def save( if self.recording.check_serializablility("json"): self.recording.dump(folder / "recording.json", relative_to=relative_to) elif self.recording.check_serializablility("pickle"): - self.recording.dump(folder / "recording.pickle") + self.recording.dump(folder / "recording.pickle", relative_to=relative_to) if self.sorting.check_serializablility("json"): self.sorting.dump(folder / "sorting.json", relative_to=relative_to) elif self.sorting.check_serializablility("pickle"): - self.sorting.dump(folder / "sorting.pickle") + self.sorting.dump(folder / "sorting.pickle", relative_to=relative_to) else: warn( "Sorting object is not serializable to file, which might result in downstream errors for " diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 139f15bf12..894918cbc4 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -140,8 +140,9 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo if recording.check_serializablility("json"): recording.dump(rec_file, relative_to=output_folder) elif recording.check_serializablility("pickle"): - recording.dump(output_folder / "spikeinterface_recording.pickle") + recording.dump(output_folder / "spikeinterface_recording.pickle", relative_to=output_folder) else: + # TODO: deprecate and finally remove this after 0.100 d = {"warning": "The recording is not serializable to json"} rec_file.write_text(json.dumps(d, indent=4), encoding="utf8") @@ -203,7 +204,7 @@ def load_recording_from_folder(cls, output_folder, with_warnings=False): else: recording = load_extractor(json_file, base_folder=output_folder) elif pickle_file.exists(): - recording = load_extractor(pickle_file) + recording = load_extractor(pickle_file, base_folder=output_folder) return recording