Skip to content

Commit

Permalink
More checks and cleanup for check_if_serializable()
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Sep 20, 2023
1 parent 9a97e68 commit 3f4e182
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 48 deletions.
4 changes: 2 additions & 2 deletions src/spikeinterface/comparison/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def __init__(
)
# save injected sorting if necessary
self.injected_sorting = injected_sorting
# if not self.injected_sorting.check_if_json_serializable():
if not self.injected_sorting.check_serializablility("json"):
# TODO later : also use pickle
assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object"
self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)

Expand Down Expand Up @@ -181,8 +181,8 @@ def __init__(
self.injected_sorting = injected_sorting

# save injected sorting if necessary
# if not self.injected_sorting.check_if_json_serializable():
if not self.injected_sorting.check_serializablility("json"):
# TODO later : also use pickle
assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object"
self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)

Expand Down
2 changes: 0 additions & 2 deletions src/spikeinterface/comparison/multicomparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_cou
def save_to_folder(self, save_folder):
for sorting in self.object_list:
assert (
# sorting.check_if_json_serializable()
sorting.check_serializablility("json")
), "MultiSortingComparison.save_to_folder() need json serializable sortings"

Expand Down Expand Up @@ -245,7 +244,6 @@ def __init__(

BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids)

# self._is_json_serializable = False
self._serializablility["json"] = False
self._serializablility["pickle"] = True

Expand Down
46 changes: 22 additions & 24 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,23 +484,33 @@ def check_if_dumpable(self):
for value in kwargs.values():
# here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
if isinstance(value, BaseExtractor):
return value.check_if_dumpable()
elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor):
return all([v.check_if_dumpable() for v in value])
elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor):
return all([v.check_if_dumpable() for k, v in value.items()])
if not value.check_if_dumpable():
return False
elif isinstance(value, list):
for v in value:
if isinstance(v, BaseExtractor) and not v.check_if_dumpable():
return False
elif isinstance(value, dict):
for v in value.values():
if isinstance(v, BaseExtractor) and not v.check_if_dumpable():
return False
return self._is_dumpable

def check_serializablility(self, type="json"):
kwargs = self._kwargs
for value in kwargs.values():
# here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
if isinstance(value, BaseExtractor):
return value.check_serializablility(type=type)
elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor):
return all([v.check_serializablility(type=type) for v in value])
elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor):
return all([v.check_serializablility(type=type) for k, v in value.items()])
if not value.check_serializablility(type=type):
return False
elif isinstance(value, list):
for v in value:
if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type):
return False
elif isinstance(value, dict):
for v in value.values():
if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type):
return False
return self._serializablility[type]

def check_if_json_serializable(self):
Expand All @@ -513,21 +523,11 @@ def check_if_json_serializable(self):
True if the object is json serializable, False otherwise.
"""
# we keep this for backward compatilibity or not ????
# is this needed ??? I think no.
return self.check_serializablility("json")

# kwargs = self._kwargs
# for value in kwargs.values():
# # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
# if isinstance(value, BaseExtractor):
# return value.check_if_json_serializable()
# elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor):
# return all([v.check_if_json_serializable() for v in value])
# elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor):
# return all([v.check_if_json_serializable() for k, v in value.items()])
# return self._is_json_serializable

def check_if_pickle_serializable(self):
# is this needed
# is this needed ??? I think no.
return self.check_serializablility("pickle")

@staticmethod
Expand Down Expand Up @@ -596,7 +596,6 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non
folder_metadata: str, Path, or None
Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
"""
# assert self.check_if_json_serializable(), "The extractor is not json serializable"
assert self.check_serializablility("json"), "The extractor is not json serializable"

# Writing paths as relative_to requires recursively expanding the dict
Expand Down Expand Up @@ -835,7 +834,6 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs):

# dump provenance
provenance_file = folder / f"provenance.json"
# if self.check_if_json_serializable():
if self.check_serializablility("json"):
self.dump(provenance_file)
else:
Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,8 @@ def __init__(
dtype = parent_recording.dtype if parent_recording is not None else templates.dtype
BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype)

# Important : self._serializablility is not change here because it will depend on the sorting parents itself.

n_units = len(sorting.unit_ids)
assert len(templates) == n_units
self.spike_vector = sorting.to_spike_vector()
Expand Down
8 changes: 3 additions & 5 deletions src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list"
t_starts = [float(t_start) for t_start in t_starts]

# self._is_json_serializable = False
self._serializablility["json"] = False
self._serializablility["pickle"] = False

Expand Down Expand Up @@ -129,9 +128,9 @@ def __init__(self, spikes, sampling_frequency, unit_ids):
BaseSorting.__init__(self, sampling_frequency, unit_ids)

self._is_dumpable = True
# self._is_json_serializable = False
self._serializablility["json"] = False
self._serializablility["pickle"] = False
# theorically this should be False but for simplicity make generators simples we still need this.
self._serializablility["pickle"] = True

if spikes.size == 0:
nseg = 1
Expand Down Expand Up @@ -362,7 +361,7 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_

BaseSorting.__init__(self, sampling_frequency, unit_ids)
self._is_dumpable = True
# self._is_json_serializable = False

self._serializablility["json"] = False
self._serializablility["pickle"] = False

Expand Down Expand Up @@ -523,7 +522,6 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore
)

self._is_dumpable = False
# self._is_json_serializable = False
self._serializablility["json"] = False
self._serializablility["pickle"] = False

Expand Down
2 changes: 0 additions & 2 deletions src/spikeinterface/core/old_api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def __init__(self, oldapi_recording_extractor):

# set _is_dumpable to False to use dumping mechanism of old extractor
self._is_dumpable = False
# self._is_json_serializable = False
self._serializablility["json"] = False
self._serializablility["pickle"] = False

Expand Down Expand Up @@ -271,7 +270,6 @@ def __init__(self, oldapi_sorting_extractor):
self.add_sorting_segment(sorting_segment)

self._is_dumpable = False
# self._is_json_serializable = False
self._serializablility["json"] = False
self._serializablility["pickle"] = False

Expand Down
7 changes: 2 additions & 5 deletions src/spikeinterface/core/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,14 @@ def test_check_if_dumpable():
assert not extractor.check_if_dumpable()


def test_check_if_json_serializable():
def test_check_if_serializable():
test_extractor = generate_recording(seed=0, durations=[2])

# make a list of dumpable objects
# test_extractor._is_json_serializable = True
test_extractor._serializablility["json"] = True
extractors_json_serializable = make_nested_extractors(test_extractor)
for extractor in extractors_json_serializable:
print(extractor)
# assert extractor.check_if_json_serializable()
assert extractor.check_serializablility("json")

# make not dumpable
Expand All @@ -64,10 +62,9 @@ def test_check_if_json_serializable():
extractors_not_json_serializable = make_nested_extractors(test_extractor)
for extractor in extractors_not_json_serializable:
print(extractor)
# assert not extractor.check_if_json_serializable()
assert not extractor.check_serializablility("json")


if __name__ == "__main__":
test_check_if_dumpable()
test_check_if_json_serializable()
test_check_if_serializable()
2 changes: 2 additions & 0 deletions src/spikeinterface/core/tests/test_waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,8 @@ def test_non_json_object():
num_units=5,
)


print(recording.check_serializablility("pickle"))
# recording is not save to keep it in memory
sorting = sorting.save()

Expand Down
18 changes: 11 additions & 7 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,17 +280,17 @@ def create(
else:
relative_to = None

# if recording.check_if_json_serializable():
if recording.check_serializablility("json"):
recording.dump(folder / "recording.json", relative_to=relative_to)
elif recording.check_serializablility("pickle"):
# In this case we loose the relative_to!!
# TODO make sure that we do not dump to pickle a NumpyRecording!!!!!
recording.dump(folder / "recording.pickle")

# if sorting.check_if_json_serializable():
if sorting.check_serializablility("json"):
sorting.dump(folder / "sorting.json", relative_to=relative_to)
elif sorting.check_serializablility("pickle"):
# In this case we loose the relative_to!!
sorting.dump(folder / "sorting.pickle")
else:
warn(
"Sorting object is not dumpable, which might result in downstream errors for "
Expand Down Expand Up @@ -895,12 +895,16 @@ def save(
(folder / "params.json").write_text(json.dumps(check_json(self._params), indent=4), encoding="utf8")

if self.has_recording():
# if self.recording.check_if_json_serializable():
if self.recording.check_serializablility("json"):
self.recording.dump(folder / "recording.json", relative_to=relative_to)
# if self.sorting.check_if_json_serializable():
elif self.recording.check_serializablility("pickle"):
self.recording.dump(folder / "recording.pickle")


if self.sorting.check_serializablility("json"):
self.sorting.dump(folder / "sorting.json", relative_to=relative_to)
elif self.sorting.check_serializablility("pickle"):
self.sorting.dump(folder / "sorting.pickle", relative_to=relative_to)
else:
warn(
"Sorting object is not dumpable, which might result in downstream errors for "
Expand Down Expand Up @@ -949,10 +953,10 @@ def save(
# write metadata
zarr_root.attrs["params"] = check_json(self._params)
if self.has_recording():
if self.recording.check_if_json_serializable():
if self.recording.check_serializablility("json"):
rec_dict = self.recording.to_dict(relative_to=relative_to, recursive=True)
zarr_root.attrs["recording"] = check_json(rec_dict)
if self.sorting.check_if_json_serializable():
if self.sorting.check_serializablility("json"):
sort_dict = self.sorting.to_dict(relative_to=relative_to, recursive=True)
zarr_root.attrs["sorting"] = check_json(sort_dict)
else:
Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/preprocessing/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,6 @@ def correct_motion(
)
(folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8")
(folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8")
# if recording.check_if_json_serializable():
if recording.check_serializablility("json"):
recording.dump_to_json(folder / "recording.json")

Expand Down

0 comments on commit 3f4e182

Please sign in to comment.