diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 3542879df4..2906692902 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -46,9 +46,10 @@ def to_dict(self): # Construct the object from a dictionary @classmethod def from_dict(cls, data): + sparsity = ChannelSparsity.from_dict(data["sparsity"]) if data["sparsity"] is not None else None return cls( templates_array=np.array(data["templates_array"]), - sparsity=ChannelSparsity(data["sparsity"]), # Assuming you can reconstruct a ChannelSparsity from a string + sparsity=sparsity, ) def to_json(self): diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index e9e3f60730..62673906ab 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -6,6 +6,7 @@ import pytest from spikeinterface.core.template import Templates +from spikeinterface.core.sparsity import ChannelSparsity @pytest.fixture @@ -41,6 +42,14 @@ def test_numpy_like_behavior(dense_templates): assert np.array_equal(templates[0, :], templates_array[0, :]) assert np.array_equal(templates[0, :, :], templates_array[0, :, :]) assert np.array_equal(templates[3:5, :, 2], templates_array[3:5, :, 2]) + # Test fancy indexing + indices = np.array([0, 1]) + assert np.array_equal(templates[indices], templates_array[indices]) + row_indices = np.array([0, 1]) + col_indices = np.array([2, 3]) + assert np.array_equal(templates[row_indices, col_indices], templates_array[row_indices, col_indices]) + mask = templates_array > 0.5 + assert np.array_equal(templates[mask], templates_array[mask]) # Test unary ufuncs assert np.array_equal(np.sqrt(templates), np.sqrt(templates_array)) @@ -50,7 +59,6 @@ def test_numpy_like_behavior(dense_templates): # Test binary ufuncs other_array = np.random.rand(*templates_array.shape) other_template = Templates(templates_array=other_array) - assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array)) assert np.array_equal(np.multiply(templates, other_template), np.multiply(templates_array, other_array))