Skip to content

Commit

Permalink
test fancy indices
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Sep 12, 2023
1 parent 4ca9ec6 commit 107bdf9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ def to_dict(self):
# Construct the object from a dictionary
@classmethod
def from_dict(cls, data):
sparsity = ChannelSparsity.from_dict(data["sparsity"]) if data["sparsity"] is not None else None
return cls(
templates_array=np.array(data["templates_array"]),
sparsity=ChannelSparsity(data["sparsity"]), # Assuming you can reconstruct a ChannelSparsity from a string
sparsity=sparsity,
)

def to_json(self):
Expand Down
10 changes: 9 additions & 1 deletion src/spikeinterface/core/tests/test_template_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from spikeinterface.core.template import Templates
from spikeinterface.core.sparsity import ChannelSparsity


@pytest.fixture
Expand Down Expand Up @@ -41,6 +42,14 @@ def test_numpy_like_behavior(dense_templates):
assert np.array_equal(templates[0, :], templates_array[0, :])
assert np.array_equal(templates[0, :, :], templates_array[0, :, :])
assert np.array_equal(templates[3:5, :, 2], templates_array[3:5, :, 2])
# Test fancy indexing
indices = np.array([0, 1])
assert np.array_equal(templates[indices], templates_array[indices])
row_indices = np.array([0, 1])
col_indices = np.array([2, 3])
assert np.array_equal(templates[row_indices, col_indices], templates_array[row_indices, col_indices])
mask = templates_array > 0.5
assert np.array_equal(templates[mask], templates_array[mask])

# Test unary ufuncs
assert np.array_equal(np.sqrt(templates), np.sqrt(templates_array))
Expand All @@ -50,7 +59,6 @@ def test_numpy_like_behavior(dense_templates):
# Test binary ufuncs
other_array = np.random.rand(*templates_array.shape)
other_template = Templates(templates_array=other_array)

assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array))
assert np.array_equal(np.multiply(templates, other_template), np.multiply(templates_array, other_array))

Expand Down

0 comments on commit 107bdf9

Please sign in to comment.