diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index c48ce70147..3b8e9e0a72 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -84,8 +84,8 @@ def __init__( ) # save injected sorting if necessary self.injected_sorting = injected_sorting - # if not self.injected_sorting.check_if_json_serializable(): if not self.injected_sorting.check_serializablility("json"): + # TODO later : also use pickle assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder) @@ -181,8 +181,8 @@ def __init__( self.injected_sorting = injected_sorting # save injected sorting if necessary - # if not self.injected_sorting.check_if_json_serializable(): if not self.injected_sorting.check_serializablility("json"): + # TODO later : also use pickle assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 3a7075905e..09a8c8aed1 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -182,7 +182,6 @@ def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_cou def save_to_folder(self, save_folder): for sorting in self.object_list: assert ( - # sorting.check_if_json_serializable() sorting.check_serializablility("json") ), "MultiSortingComparison.save_to_folder() need json serializable sortings" @@ -245,7 +244,6 @@ def __init__( BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids) - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = True diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index d87bd617c4..63cf8e894f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -484,11 +484,16 @@ def check_if_dumpable(self): for value in kwargs.values(): # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors if isinstance(value, BaseExtractor): - return value.check_if_dumpable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_dumpable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_dumpable() for k, v in value.items()]) + if not value.check_if_dumpable(): + return False + elif isinstance(value, list): + for v in value: + if isinstance(v, BaseExtractor) and not v.check_if_dumpable(): + return False + elif isinstance(value, dict): + for v in value.values(): + if isinstance(v, BaseExtractor) and not v.check_if_dumpable(): + return False return self._is_dumpable def check_serializablility(self, type="json"): @@ -496,11 +501,16 @@ def check_serializablility(self, type="json"): for value in kwargs.values(): # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors if isinstance(value, BaseExtractor): - return value.check_serializablility(type=type) - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_serializablility(type=type) for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_serializablility(type=type) for k, v in value.items()]) + if not value.check_serializablility(type=type): + return False + elif isinstance(value, list): + for v in value: + if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): + return False + elif isinstance(value, dict): + for v in value.values(): + if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): + return False return self._serializablility[type] def check_if_json_serializable(self): @@ -513,21 +523,11 @@ def check_if_json_serializable(self): True if the object is json serializable, False otherwise. """ # we keep this for backward compatilibity or not ???? + # is this needed ??? I think no. return self.check_serializablility("json") - # kwargs = self._kwargs - # for value in kwargs.values(): - # # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - # if isinstance(value, BaseExtractor): - # return value.check_if_json_serializable() - # elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - # return all([v.check_if_json_serializable() for v in value]) - # elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - # return all([v.check_if_json_serializable() for k, v in value.items()]) - # return self._is_json_serializable - def check_if_pickle_serializable(self): - # is this needed + # is this needed ??? I think no. return self.check_serializablility("pickle") @staticmethod @@ -596,7 +596,6 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. """ - # assert self.check_if_json_serializable(), "The extractor is not json serializable" assert self.check_serializablility("json"), "The extractor is not json serializable" # Writing paths as relative_to requires recursively expanding the dict @@ -835,7 +834,6 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): # dump provenance provenance_file = folder / f"provenance.json" - # if self.check_if_json_serializable(): if self.check_serializablility("json"): self.dump(provenance_file) else: diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 706054c957..362b598b0b 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1056,6 +1056,8 @@ def __init__( dtype = parent_recording.dtype if parent_recording is not None else templates.dtype BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype) + # Important : self._serializablility is not change here because it will depend on the sorting parents itself. + n_units = len(sorting.unit_ids) assert len(templates) == n_units self.spike_vector = sorting.to_spike_vector() diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index f55b975ddb..5ef955a6eb 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -64,7 +64,6 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list" t_starts = [float(t_start) for t_start in t_starts] - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -129,9 +128,9 @@ def __init__(self, spikes, sampling_frequency, unit_ids): BaseSorting.__init__(self, sampling_frequency, unit_ids) self._is_dumpable = True - # self._is_json_serializable = False self._serializablility["json"] = False - self._serializablility["pickle"] = False + # theorically this should be False but for simplicity make generators simples we still need this. + self._serializablility["pickle"] = True if spikes.size == 0: nseg = 1 @@ -362,7 +361,7 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_ BaseSorting.__init__(self, sampling_frequency, unit_ids) self._is_dumpable = True - # self._is_json_serializable = False + self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -523,7 +522,6 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore ) self._is_dumpable = False - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = False diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py index 38fbef1547..a31edb0dd7 100644 --- a/src/spikeinterface/core/old_api_utils.py +++ b/src/spikeinterface/core/old_api_utils.py @@ -183,7 +183,6 @@ def __init__(self, oldapi_recording_extractor): # set _is_dumpable to False to use dumping mechanism of old extractor self._is_dumpable = False - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -271,7 +270,6 @@ def __init__(self, oldapi_sorting_extractor): self.add_sorting_segment(sorting_segment) self._is_dumpable = False - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = False diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index 77a5d7d9bf..b716f6b1dd 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -46,16 +46,14 @@ def test_check_if_dumpable(): assert not extractor.check_if_dumpable() -def test_check_if_json_serializable(): +def test_check_if_serializable(): test_extractor = generate_recording(seed=0, durations=[2]) # make a list of dumpable objects - # test_extractor._is_json_serializable = True test_extractor._serializablility["json"] = True extractors_json_serializable = make_nested_extractors(test_extractor) for extractor in extractors_json_serializable: print(extractor) - # assert extractor.check_if_json_serializable() assert extractor.check_serializablility("json") # make not dumpable @@ -64,10 +62,9 @@ def test_check_if_json_serializable(): extractors_not_json_serializable = make_nested_extractors(test_extractor) for extractor in extractors_not_json_serializable: print(extractor) - # assert not extractor.check_if_json_serializable() assert not extractor.check_serializablility("json") if __name__ == "__main__": test_check_if_dumpable() - test_check_if_json_serializable() + test_check_if_serializable() diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index f53b9cf18d..3972c9186c 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -517,6 +517,8 @@ def test_non_json_object(): num_units=5, ) + + print(recording.check_serializablility("pickle")) # recording is not save to keep it in memory sorting = sorting.save() diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 53852bf319..3de1429feb 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -280,17 +280,17 @@ def create( else: relative_to = None - # if recording.check_if_json_serializable(): if recording.check_serializablility("json"): recording.dump(folder / "recording.json", relative_to=relative_to) elif recording.check_serializablility("pickle"): # In this case we loose the relative_to!! - # TODO make sure that we do not dump to pickle a NumpyRecording!!!!! recording.dump(folder / "recording.pickle") - # if sorting.check_if_json_serializable(): if sorting.check_serializablility("json"): sorting.dump(folder / "sorting.json", relative_to=relative_to) + elif sorting.check_serializablility("pickle"): + # In this case we loose the relative_to!! + sorting.dump(folder / "sorting.pickle") else: warn( "Sorting object is not dumpable, which might result in downstream errors for " @@ -895,12 +895,16 @@ def save( (folder / "params.json").write_text(json.dumps(check_json(self._params), indent=4), encoding="utf8") if self.has_recording(): - # if self.recording.check_if_json_serializable(): if self.recording.check_serializablility("json"): self.recording.dump(folder / "recording.json", relative_to=relative_to) - # if self.sorting.check_if_json_serializable(): + elif self.recording.check_serializablility("pickle"): + self.recording.dump(folder / "recording.pickle") + + if self.sorting.check_serializablility("json"): self.sorting.dump(folder / "sorting.json", relative_to=relative_to) + elif self.sorting.check_serializablility("pickle"): + self.sorting.dump(folder / "sorting.pickle", relative_to=relative_to) else: warn( "Sorting object is not dumpable, which might result in downstream errors for " @@ -949,10 +953,10 @@ def save( # write metadata zarr_root.attrs["params"] = check_json(self._params) if self.has_recording(): - if self.recording.check_if_json_serializable(): + if self.recording.check_serializablility("json"): rec_dict = self.recording.to_dict(relative_to=relative_to, recursive=True) zarr_root.attrs["recording"] = check_json(rec_dict) - if self.sorting.check_if_json_serializable(): + if self.sorting.check_serializablility("json"): sort_dict = self.sorting.to_dict(relative_to=relative_to, recursive=True) zarr_root.attrs["sorting"] = check_json(sort_dict) else: diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 0054fb94d4..6ab1a9afce 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -333,7 +333,6 @@ def correct_motion( ) (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8") (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - # if recording.check_if_json_serializable(): if recording.check_serializablility("json"): recording.dump_to_json(folder / "recording.json")