From 815f6053fb438ae7f5eb04462ac31aad17200aac Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 28 Jun 2024 15:27:28 -0600 Subject: [PATCH] accept paths in jsonification --- src/spikeinterface/core/core_tools.py | 3 ++ .../core/tests/test_jsonification.py | 29 ++++++++++++++----- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 066ab58d8c..d4701343af 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -98,6 +98,9 @@ def default(self, obj): if isinstance(obj, BaseExtractor): return obj.to_dict() + if isinstance(obj, Path): + return str(obj) + # The base-class handles the assertion return super().default(obj) diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py index 4417ea342f..316dac3abc 100644 --- a/src/spikeinterface/core/tests/test_jsonification.py +++ b/src/spikeinterface/core/tests/test_jsonification.py @@ -7,11 +7,7 @@ from spikeinterface.core.core_tools import SIJsonEncoder from spikeinterface.core.generate import generate_recording, generate_sorting - -@pytest.fixture(scope="module") -def numpy_generated_recording(): - recording = generate_recording() - return recording +from pathlib import Path @pytest.fixture(scope="module") @@ -124,8 +120,25 @@ def test_numpy_dtype_alises_encoding(): json.dumps(np.float32, cls=SIJsonEncoder) -def test_recording_encoding(numpy_generated_recording): - recording = numpy_generated_recording +def test_path_encoding(tmp_path): + + temporary_path = tmp_path / "a_path_for_this_test" + + json.dumps(temporary_path, cls=SIJsonEncoder) + + +def test_path_as_annotation(tmp_path): + temporary_path = tmp_path / "a_path_for_this_test" + + recording = generate_recording() + recording.annotate(path=temporary_path) + + json.dumps(recording, cls=SIJsonEncoder) + + +def test_recording_encoding(): + recording = generate_recording() + json.dumps(recording, cls=SIJsonEncoder) @@ -200,4 +213,4 @@ def test_encoding_numpy_scalars_within_nested_extractors_dict(nested_extractor_d if __name__ == "__main__": nested_extractor = nested_extractor() - test_encoding_numpy_scalars_within_nested_extractors(nested_extractor_) + test_encoding_numpy_scalars_within_nested_extractors(nested_extractor)