Skip to content

Commit

Permalink
add json test
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Sep 12, 2023
1 parent c242446 commit 4ca9ec6
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 20 deletions.
59 changes: 59 additions & 0 deletions src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import json
from dataclasses import dataclass, field

import numpy as np

from spikeinterface.core.sparsity import ChannelSparsity


@dataclass
class Templates:
templates_array: np.ndarray
sparsity: ChannelSparsity = None
num_units: int = field(init=False)
num_samples: int = field(init=False)
num_channels: int = field(init=False)

def __post_init__(self):
self.num_units, self.num_samples, self.num_channels = self.templates_array.shape

# Implementing the slicing/indexing behavior as numpy
def __getitem__(self, index):
return self.templates_array[index]

def __array__(self):
return self.templates_array

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray:
# Replace any Templates instances with their ndarray representation
inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs)

# Apply the ufunc on the transformed inputs
result = getattr(ufunc, method)(*inputs, **kwargs)

return result

def to_dict(self):
sparsity = self.sparsity.to_dict() if self.sparsity is not None else None
return {
"templates_array": self.templates_array.tolist(),
"sparsity": sparsity,
"num_units": self.num_units,
"num_samples": self.num_samples,
"num_channels": self.num_channels,
}

# Construct the object from a dictionary
@classmethod
def from_dict(cls, data):
return cls(
templates_array=np.array(data["templates_array"]),
sparsity=ChannelSparsity(data["sparsity"]), # Assuming you can reconstruct a ChannelSparsity from a string
)

def to_json(self):
return json.dumps(self.to_dict())

@classmethod
def from_json(cls, json_str):
return cls.from_dict(json.loads(json_str))
54 changes: 34 additions & 20 deletions src/spikeinterface/core/tests/test_template_class.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
from pathlib import Path
import pickle
import json

import numpy as np
import pytest

from spikeinterface.core.template import Templates


def test_dense_template_instance():
@pytest.fixture
def dense_templates():
num_units = 2
num_samples = 4
num_channels = 3
templates_shape = (num_units, num_samples, num_channels)
templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)

templates = Templates(templates_array=templates_array)
return Templates(templates_array=templates_array)


def test_dense_template_instance(dense_templates):
templates = dense_templates
templates_array = templates.templates_array
num_units, num_samples, num_channels = templates_array.shape

assert np.array_equal(templates.templates_array, templates_array)
assert templates.sparsity is None
Expand All @@ -22,14 +31,9 @@ def test_dense_template_instance():
assert templates.num_channels == num_channels


def test_numpy_like_behavior():
num_units = 2
num_samples = 4
num_channels = 3
templates_shape = (num_units, num_samples, num_channels)
templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)

templates = Templates(templates_array=templates_array)
def test_numpy_like_behavior(dense_templates):
templates = dense_templates
templates_array = templates.templates_array

# Test that slicing works as in numpy
assert np.array_equal(templates[:], templates_array[:])
Expand All @@ -44,7 +48,7 @@ def test_numpy_like_behavior():
assert np.array_equal(np.mean(templates, axis=0), np.mean(templates_array, axis=0))

# Test binary ufuncs
other_array = np.random.rand(*templates_shape)
other_array = np.random.rand(*templates_array.shape)
other_template = Templates(templates_array=other_array)

assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array))
Expand All @@ -60,19 +64,29 @@ def test_numpy_like_behavior():
assert not np.any(np.less(templates, 0))


def test_pickle():
num_units = 2
num_samples = 4
num_channels = 3
templates_shape = (num_units, num_samples, num_channels)
templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)

templates = Templates(templates_array=templates_array)
def test_pickle(dense_templates):
templates = dense_templates

# Serialize and deserialize the object
serialized = pickle.dumps(templates)
deserialized = pickle.loads(serialized)
deserialized_templates = pickle.loads(serialized)

assert np.array_equal(templates.templates_array, deserialized_templates.templates_array)
assert templates.sparsity == deserialized_templates.sparsity
assert templates.num_units == deserialized_templates.num_units
assert templates.num_samples == deserialized_templates.num_samples
assert templates.num_channels == deserialized_templates.num_channels


def test_jsonification(dense_templates):
templates = dense_templates
# Serialize to JSON string
serialized = templates.to_json()

# Deserialize back to object
deserialized = Templates.from_json(serialized)

# Check if deserialized object matches original
assert np.array_equal(templates.templates_array, deserialized.templates_array)
assert templates.sparsity == deserialized.sparsity
assert templates.num_units == deserialized.num_units
Expand Down

0 comments on commit 4ca9ec6

Please sign in to comment.