From c242446086bca416766912fab079e17db100bd03 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 12 Sep 2023 12:24:27 +0200 Subject: [PATCH] add pickability --- .../core/tests/test_template_class.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index defad77e00..b864421824 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -1,5 +1,5 @@ from pathlib import Path - +import pickle import numpy as np import pytest @@ -58,3 +58,23 @@ def test_numpy_like_behavior(): # Test ufuncs that return non-ndarray results assert np.all(np.greater(templates, -1)) assert not np.any(np.less(templates, 0)) + + +def test_pickle(): + num_units = 2 + num_samples = 4 + num_channels = 3 + templates_shape = (num_units, num_samples, num_channels) + templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) + + templates = Templates(templates_array=templates_array) + + # Serialize and deserialize the object + serialized = pickle.dumps(templates) + deserialized = pickle.loads(serialized) + + assert np.array_equal(templates.templates_array, deserialized.templates_array) + assert templates.sparsity == deserialized.sparsity + assert templates.num_units == deserialized.num_units + assert templates.num_samples == deserialized.num_samples + assert templates.num_channels == deserialized.num_channels