From 513a344f1fb8ef942467fd05056f1e6e4a8fa541 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 12 Sep 2023 12:20:52 +0200 Subject: [PATCH 01/21] add basic instance and numpy behavior --- .../core/tests/test_template_class.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/spikeinterface/core/tests/test_template_class.py diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py new file mode 100644 index 0000000000..defad77e00 --- /dev/null +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -0,0 +1,60 @@ +from pathlib import Path + +import numpy as np +import pytest + +from spikeinterface.core.template import Templates + + +def test_dense_template_instance(): + num_units = 2 + num_samples = 4 + num_channels = 3 + templates_shape = (num_units, num_samples, num_channels) + templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) + + templates = Templates(templates_array=templates_array) + + assert np.array_equal(templates.templates_array, templates_array) + assert templates.sparsity is None + assert templates.num_units == num_units + assert templates.num_samples == num_samples + assert templates.num_channels == num_channels + + +def test_numpy_like_behavior(): + num_units = 2 + num_samples = 4 + num_channels = 3 + templates_shape = (num_units, num_samples, num_channels) + templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) + + templates = Templates(templates_array=templates_array) + + # Test that slicing works as in numpy + assert np.array_equal(templates[:], templates_array[:]) + assert np.array_equal(templates[0], templates_array[0]) + assert np.array_equal(templates[0, :], templates_array[0, :]) + assert np.array_equal(templates[0, :, :], templates_array[0, :, :]) + assert np.array_equal(templates[3:5, :, 2], templates_array[3:5, :, 2]) + + # Test unary ufuncs + assert np.array_equal(np.sqrt(templates), np.sqrt(templates_array)) + assert np.array_equal(np.abs(templates), np.abs(templates_array)) + assert np.array_equal(np.mean(templates, axis=0), np.mean(templates_array, axis=0)) + + # Test binary ufuncs + other_array = np.random.rand(*templates_shape) + other_template = Templates(templates_array=other_array) + + assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array)) + assert np.array_equal(np.multiply(templates, other_template), np.multiply(templates_array, other_array)) + + # Test chaining of operations + chained_result = np.mean(np.multiply(templates, other_template), axis=0) + chained_expected = np.mean(np.multiply(templates_array, other_array), axis=0) + assert np.array_equal(chained_result, chained_expected) + + # Test ufuncs that return non-ndarray results + assert np.all(np.greater(templates, -1)) + assert not np.any(np.less(templates, 0)) From c242446086bca416766912fab079e17db100bd03 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 12 Sep 2023 12:24:27 +0200 Subject: [PATCH 02/21] add pickability --- .../core/tests/test_template_class.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index defad77e00..b864421824 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -1,5 +1,5 @@ from pathlib import Path - +import pickle import numpy as np import pytest @@ -58,3 +58,23 @@ def test_numpy_like_behavior(): # Test ufuncs that return non-ndarray results assert np.all(np.greater(templates, -1)) assert not np.any(np.less(templates, 0)) + + +def test_pickle(): + num_units = 2 + num_samples = 4 + num_channels = 3 + templates_shape = (num_units, num_samples, num_channels) + templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) + + templates = Templates(templates_array=templates_array) + + # Serialize and deserialize the object + serialized = pickle.dumps(templates) + deserialized = pickle.loads(serialized) + + assert np.array_equal(templates.templates_array, deserialized.templates_array) + assert templates.sparsity == deserialized.sparsity + assert templates.num_units == deserialized.num_units + assert templates.num_samples == deserialized.num_samples + assert templates.num_channels == deserialized.num_channels From 4ca9ec6d504adb83561ce8bb57189ff3e9ec7a8e Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 12 Sep 2023 12:40:45 +0200 Subject: [PATCH 03/21] add json test --- src/spikeinterface/core/template.py | 59 +++++++++++++++++++ .../core/tests/test_template_class.py | 54 ++++++++++------- 2 files changed, 93 insertions(+), 20 deletions(-) create mode 100644 src/spikeinterface/core/template.py diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py new file mode 100644 index 0000000000..3542879df4 --- /dev/null +++ b/src/spikeinterface/core/template.py @@ -0,0 +1,59 @@ +import json +from dataclasses import dataclass, field + +import numpy as np + +from spikeinterface.core.sparsity import ChannelSparsity + + +@dataclass +class Templates: + templates_array: np.ndarray + sparsity: ChannelSparsity = None + num_units: int = field(init=False) + num_samples: int = field(init=False) + num_channels: int = field(init=False) + + def __post_init__(self): + self.num_units, self.num_samples, self.num_channels = self.templates_array.shape + + # Implementing the slicing/indexing behavior as numpy + def __getitem__(self, index): + return self.templates_array[index] + + def __array__(self): + return self.templates_array + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray: + # Replace any Templates instances with their ndarray representation + inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs) + + # Apply the ufunc on the transformed inputs + result = getattr(ufunc, method)(*inputs, **kwargs) + + return result + + def to_dict(self): + sparsity = self.sparsity.to_dict() if self.sparsity is not None else None + return { + "templates_array": self.templates_array.tolist(), + "sparsity": sparsity, + "num_units": self.num_units, + "num_samples": self.num_samples, + "num_channels": self.num_channels, + } + + # Construct the object from a dictionary + @classmethod + def from_dict(cls, data): + return cls( + templates_array=np.array(data["templates_array"]), + sparsity=ChannelSparsity(data["sparsity"]), # Assuming you can reconstruct a ChannelSparsity from a string + ) + + def to_json(self): + return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_str): + return cls.from_dict(json.loads(json_str)) diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index b864421824..e9e3f60730 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -1,19 +1,28 @@ from pathlib import Path import pickle +import json + import numpy as np import pytest from spikeinterface.core.template import Templates -def test_dense_template_instance(): +@pytest.fixture +def dense_templates(): num_units = 2 num_samples = 4 num_channels = 3 templates_shape = (num_units, num_samples, num_channels) templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) - templates = Templates(templates_array=templates_array) + return Templates(templates_array=templates_array) + + +def test_dense_template_instance(dense_templates): + templates = dense_templates + templates_array = templates.templates_array + num_units, num_samples, num_channels = templates_array.shape assert np.array_equal(templates.templates_array, templates_array) assert templates.sparsity is None @@ -22,14 +31,9 @@ def test_dense_template_instance(): assert templates.num_channels == num_channels -def test_numpy_like_behavior(): - num_units = 2 - num_samples = 4 - num_channels = 3 - templates_shape = (num_units, num_samples, num_channels) - templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) - - templates = Templates(templates_array=templates_array) +def test_numpy_like_behavior(dense_templates): + templates = dense_templates + templates_array = templates.templates_array # Test that slicing works as in numpy assert np.array_equal(templates[:], templates_array[:]) @@ -44,7 +48,7 @@ def test_numpy_like_behavior(): assert np.array_equal(np.mean(templates, axis=0), np.mean(templates_array, axis=0)) # Test binary ufuncs - other_array = np.random.rand(*templates_shape) + other_array = np.random.rand(*templates_array.shape) other_template = Templates(templates_array=other_array) assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array)) @@ -60,19 +64,29 @@ def test_numpy_like_behavior(): assert not np.any(np.less(templates, 0)) -def test_pickle(): - num_units = 2 - num_samples = 4 - num_channels = 3 - templates_shape = (num_units, num_samples, num_channels) - templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) - - templates = Templates(templates_array=templates_array) +def test_pickle(dense_templates): + templates = dense_templates # Serialize and deserialize the object serialized = pickle.dumps(templates) - deserialized = pickle.loads(serialized) + deserialized_templates = pickle.loads(serialized) + + assert np.array_equal(templates.templates_array, deserialized_templates.templates_array) + assert templates.sparsity == deserialized_templates.sparsity + assert templates.num_units == deserialized_templates.num_units + assert templates.num_samples == deserialized_templates.num_samples + assert templates.num_channels == deserialized_templates.num_channels + + +def test_jsonification(dense_templates): + templates = dense_templates + # Serialize to JSON string + serialized = templates.to_json() + + # Deserialize back to object + deserialized = Templates.from_json(serialized) + # Check if deserialized object matches original assert np.array_equal(templates.templates_array, deserialized.templates_array) assert templates.sparsity == deserialized.sparsity assert templates.num_units == deserialized.num_units From 107bdf96ec26368dfa0666c2e7ac9bd5dd596563 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 12 Sep 2023 17:20:31 +0200 Subject: [PATCH 04/21] test fancy indices --- src/spikeinterface/core/template.py | 3 ++- src/spikeinterface/core/tests/test_template_class.py | 10 +++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 3542879df4..2906692902 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -46,9 +46,10 @@ def to_dict(self): # Construct the object from a dictionary @classmethod def from_dict(cls, data): + sparsity = ChannelSparsity.from_dict(data["sparsity"]) if data["sparsity"] is not None else None return cls( templates_array=np.array(data["templates_array"]), - sparsity=ChannelSparsity(data["sparsity"]), # Assuming you can reconstruct a ChannelSparsity from a string + sparsity=sparsity, ) def to_json(self): diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index e9e3f60730..62673906ab 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -6,6 +6,7 @@ import pytest from spikeinterface.core.template import Templates +from spikeinterface.core.sparsity import ChannelSparsity @pytest.fixture @@ -41,6 +42,14 @@ def test_numpy_like_behavior(dense_templates): assert np.array_equal(templates[0, :], templates_array[0, :]) assert np.array_equal(templates[0, :, :], templates_array[0, :, :]) assert np.array_equal(templates[3:5, :, 2], templates_array[3:5, :, 2]) + # Test fancy indexing + indices = np.array([0, 1]) + assert np.array_equal(templates[indices], templates_array[indices]) + row_indices = np.array([0, 1]) + col_indices = np.array([2, 3]) + assert np.array_equal(templates[row_indices, col_indices], templates_array[row_indices, col_indices]) + mask = templates_array > 0.5 + assert np.array_equal(templates[mask], templates_array[mask]) # Test unary ufuncs assert np.array_equal(np.sqrt(templates), np.sqrt(templates_array)) @@ -50,7 +59,6 @@ def test_numpy_like_behavior(dense_templates): # Test binary ufuncs other_array = np.random.rand(*templates_array.shape) other_template = Templates(templates_array=other_array) - assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array)) assert np.array_equal(np.multiply(templates, other_template), np.multiply(templates_array, other_array)) From d2c6ec727c01c47ec9635a5279eb3d432ce370ce Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 22:39:01 +0200 Subject: [PATCH 05/21] alessio and samuel requests --- src/spikeinterface/core/template.py | 90 ++++++--- .../core/tests/test_template_class.py | 178 +++++++++--------- 2 files changed, 144 insertions(+), 124 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 2906692902..4b923db2df 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -1,55 +1,67 @@ +import numpy as np import json from dataclasses import dataclass, field +from .sparsity import ChannelSparsity -import numpy as np - -from spikeinterface.core.sparsity import ChannelSparsity - -@dataclass +@dataclass(kw_only=True) class Templates: templates_array: np.ndarray - sparsity: ChannelSparsity = None + sampling_frequency: float + nbefore: int + + sparsity_mask: np.ndarray = None + channel_ids: np.ndarray = None + unit_ids: np.ndarray = None + num_units: int = field(init=False) num_samples: int = field(init=False) num_channels: int = field(init=False) - def __post_init__(self): - self.num_units, self.num_samples, self.num_channels = self.templates_array.shape - - # Implementing the slicing/indexing behavior as numpy - def __getitem__(self, index): - return self.templates_array[index] - - def __array__(self): - return self.templates_array - - def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray: - # Replace any Templates instances with their ndarray representation - inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs) + nafter: int = field(init=False) + ms_before: float = field(init=False) + ms_after: float = field(init=False) + sparsity: ChannelSparsity = field(init=False) - # Apply the ufunc on the transformed inputs - result = getattr(ufunc, method)(*inputs, **kwargs) - - return result + def __post_init__(self): + self.num_units, self.num_samples = self.templates_array.shape[:2] + if self.sparsity_mask is None: + self.num_channels = self.templates_array.shape[2] + else: + self.num_channels = self.sparsity_mask.shape[1] + self.nafter = self.num_samples - self.nbefore - 1 + self.ms_before = self.nbefore / self.sampling_frequency * 1000 + self.ms_after = self.nafter / self.sampling_frequency * 1000 + if self.channel_ids is None: + self.channel_ids = np.arange(self.num_channels) + if self.unit_ids is None: + self.unit_ids = np.arange(self.num_units) + if self.sparsity_mask is not None: + self.sparsity = ChannelSparsity( + mask=self.sparsity_mask, + unit_ids=self.unit_ids, + channel_ids=self.channel_ids, + ) def to_dict(self): - sparsity = self.sparsity.to_dict() if self.sparsity is not None else None return { "templates_array": self.templates_array.tolist(), - "sparsity": sparsity, - "num_units": self.num_units, - "num_samples": self.num_samples, - "num_channels": self.num_channels, + "sparsity_mask": None if self.sparsity_mask is None else self.sparsity_mask.tolist(), + "channel_ids": self.channel_ids.tolist(), + "unit_ids": self.unit_ids.tolist(), + "sampling_frequency": self.sampling_frequency, + "nbefore": self.nbefore, } - # Construct the object from a dictionary @classmethod def from_dict(cls, data): - sparsity = ChannelSparsity.from_dict(data["sparsity"]) if data["sparsity"] is not None else None return cls( templates_array=np.array(data["templates_array"]), - sparsity=sparsity, + sparsity_mask=None if data["sparsity_mask"] is None else np.array(data["sparsity_mask"]), + channel_ids=np.array(data["channel_ids"]), + unit_ids=np.array(data["unit_ids"]), + sampling_frequency=data["sampling_frequency"], + nbefore=data["nbefore"], ) def to_json(self): @@ -58,3 +70,19 @@ def to_json(self): @classmethod def from_json(cls, json_str): return cls.from_dict(json.loads(json_str)) + + # Implementing the slicing/indexing behavior as numpy + def __getitem__(self, index): + return self.templates_array[index] + + def __array__(self): + return self.templates_array + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray: + # Replace any Templates instances with their ndarray representation + inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs) + + # Apply the ufunc on the transformed inputs + result = getattr(ufunc, method)(*inputs, **kwargs) + + return result diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index 62673906ab..5fc997c6bf 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -1,102 +1,94 @@ -from pathlib import Path -import pickle -import json - -import numpy as np import pytest - +import numpy as np +import pickle from spikeinterface.core.template import Templates -from spikeinterface.core.sparsity import ChannelSparsity -@pytest.fixture -def dense_templates(): +@pytest.mark.parametrize("template_obj", ["dense", "sparse"]) +def get_template_object(template_obj): num_units = 2 - num_samples = 4 + num_samples = 5 num_channels = 3 templates_shape = (num_units, num_samples, num_channels) templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) - return Templates(templates_array=templates_array) - - -def test_dense_template_instance(dense_templates): - templates = dense_templates - templates_array = templates.templates_array - num_units, num_samples, num_channels = templates_array.shape - - assert np.array_equal(templates.templates_array, templates_array) - assert templates.sparsity is None - assert templates.num_units == num_units - assert templates.num_samples == num_samples - assert templates.num_channels == num_channels - - -def test_numpy_like_behavior(dense_templates): - templates = dense_templates - templates_array = templates.templates_array - - # Test that slicing works as in numpy - assert np.array_equal(templates[:], templates_array[:]) - assert np.array_equal(templates[0], templates_array[0]) - assert np.array_equal(templates[0, :], templates_array[0, :]) - assert np.array_equal(templates[0, :, :], templates_array[0, :, :]) - assert np.array_equal(templates[3:5, :, 2], templates_array[3:5, :, 2]) - # Test fancy indexing - indices = np.array([0, 1]) - assert np.array_equal(templates[indices], templates_array[indices]) - row_indices = np.array([0, 1]) - col_indices = np.array([2, 3]) - assert np.array_equal(templates[row_indices, col_indices], templates_array[row_indices, col_indices]) - mask = templates_array > 0.5 - assert np.array_equal(templates[mask], templates_array[mask]) - - # Test unary ufuncs - assert np.array_equal(np.sqrt(templates), np.sqrt(templates_array)) - assert np.array_equal(np.abs(templates), np.abs(templates_array)) - assert np.array_equal(np.mean(templates, axis=0), np.mean(templates_array, axis=0)) - - # Test binary ufuncs - other_array = np.random.rand(*templates_array.shape) - other_template = Templates(templates_array=other_array) - assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array)) - assert np.array_equal(np.multiply(templates, other_template), np.multiply(templates_array, other_array)) - - # Test chaining of operations - chained_result = np.mean(np.multiply(templates, other_template), axis=0) - chained_expected = np.mean(np.multiply(templates_array, other_array), axis=0) - assert np.array_equal(chained_result, chained_expected) - - # Test ufuncs that return non-ndarray results - assert np.all(np.greater(templates, -1)) - assert not np.any(np.less(templates, 0)) - - -def test_pickle(dense_templates): - templates = dense_templates - - # Serialize and deserialize the object - serialized = pickle.dumps(templates) - deserialized_templates = pickle.loads(serialized) - - assert np.array_equal(templates.templates_array, deserialized_templates.templates_array) - assert templates.sparsity == deserialized_templates.sparsity - assert templates.num_units == deserialized_templates.num_units - assert templates.num_samples == deserialized_templates.num_samples - assert templates.num_channels == deserialized_templates.num_channels - - -def test_jsonification(dense_templates): - templates = dense_templates - # Serialize to JSON string - serialized = templates.to_json() - - # Deserialize back to object - deserialized = Templates.from_json(serialized) - - # Check if deserialized object matches original - assert np.array_equal(templates.templates_array, deserialized.templates_array) - assert templates.sparsity == deserialized.sparsity - assert templates.num_units == deserialized.num_units - assert templates.num_samples == deserialized.num_samples - assert templates.num_channels == deserialized.num_channels + sampling_frequency = 30_000 + nbefore = 2 + + if template_obj == "dense": + return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore) + else: # sparse + sparsity_mask = np.array([[True, False, True], [False, True, False]]) + return Templates( + templates_array=templates_array, + sparsity_mask=sparsity_mask, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + ) + + +@pytest.mark.parametrize("template_obj", ["dense", "sparse"]) +def test_pickle_serialization(template_obj, tmp_path): + obj = get_template_object(template_obj) + + # Dump to pickle + pkl_path = tmp_path / "templates.pkl" + with open(pkl_path, "wb") as f: + pickle.dump(obj, f) + + # Load from pickle + with open(pkl_path, "rb") as f: + loaded_obj = pickle.load(f) + + assert np.array_equal(obj.templates_array, loaded_obj.templates_array) + + +@pytest.mark.parametrize("template_obj", ["dense", "sparse"]) +def test_json_serialization(template_obj): + obj = get_template_object(template_obj) + + json_str = obj.to_json() + loaded_obj_from_json = Templates.from_json(json_str) + + assert np.array_equal(obj.templates_array, loaded_obj_from_json.templates_array) + + +# @pytest.fixture +# def dense_templates(): +# num_units = 2 +# num_samples = 4 +# num_channels = 3 +# templates_shape = (num_units, num_samples, num_channels) +# templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) + +# return Templates(templates_array=templates_array) + + +# def test_pickle(dense_templates): +# templates = dense_templates + +# # Serialize and deserialize the object +# serialized = pickle.dumps(templates) +# deserialized_templates = pickle.loads(serialized) + +# assert np.array_equal(templates.templates_array, deserialized_templates.templates_array) +# assert templates.sparsity == deserialized_templates.sparsity +# assert templates.num_units == deserialized_templates.num_units +# assert templates.num_samples == deserialized_templates.num_samples +# assert templates.num_channels == deserialized_templates.num_channels + + +# def test_jsonification(dense_templates): +# templates = dense_templates +# # Serialize to JSON string +# serialized = templates.to_json() + +# # Deserialize back to object +# deserialized = Templates.from_json(serialized) + +# # Check if deserialized object matches original +# assert np.array_equal(templates.templates_array, deserialized.templates_array) +# assert templates.sparsity == deserialized.sparsity +# assert templates.num_units == deserialized.num_units +# assert templates.num_samples == deserialized.num_samples +# assert templates.num_channels == deserialized.num_channels From 961d26979cf339b8799d9f045f20f477589171c4 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 22:39:38 +0200 Subject: [PATCH 06/21] remove slicing --- src/spikeinterface/core/template.py | 20 --------- .../core/tests/test_template_class.py | 41 ------------------- 2 files changed, 61 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 4b923db2df..6dbfb881f6 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -66,23 +66,3 @@ def from_dict(cls, data): def to_json(self): return json.dumps(self.to_dict()) - - @classmethod - def from_json(cls, json_str): - return cls.from_dict(json.loads(json_str)) - - # Implementing the slicing/indexing behavior as numpy - def __getitem__(self, index): - return self.templates_array[index] - - def __array__(self): - return self.templates_array - - def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray: - # Replace any Templates instances with their ndarray representation - inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs) - - # Apply the ufunc on the transformed inputs - result = getattr(ufunc, method)(*inputs, **kwargs) - - return result diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index 5fc997c6bf..2b6b4c9744 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -51,44 +51,3 @@ def test_json_serialization(template_obj): loaded_obj_from_json = Templates.from_json(json_str) assert np.array_equal(obj.templates_array, loaded_obj_from_json.templates_array) - - -# @pytest.fixture -# def dense_templates(): -# num_units = 2 -# num_samples = 4 -# num_channels = 3 -# templates_shape = (num_units, num_samples, num_channels) -# templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) - -# return Templates(templates_array=templates_array) - - -# def test_pickle(dense_templates): -# templates = dense_templates - -# # Serialize and deserialize the object -# serialized = pickle.dumps(templates) -# deserialized_templates = pickle.loads(serialized) - -# assert np.array_equal(templates.templates_array, deserialized_templates.templates_array) -# assert templates.sparsity == deserialized_templates.sparsity -# assert templates.num_units == deserialized_templates.num_units -# assert templates.num_samples == deserialized_templates.num_samples -# assert templates.num_channels == deserialized_templates.num_channels - - -# def test_jsonification(dense_templates): -# templates = dense_templates -# # Serialize to JSON string -# serialized = templates.to_json() - -# # Deserialize back to object -# deserialized = Templates.from_json(serialized) - -# # Check if deserialized object matches original -# assert np.array_equal(templates.templates_array, deserialized.templates_array) -# assert templates.sparsity == deserialized.sparsity -# assert templates.num_units == deserialized.num_units -# assert templates.num_samples == deserialized.num_samples -# assert templates.num_channels == deserialized.num_channels From 9ee3a1de6741b156f2a56ca5f7455dd2ebf3b768 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 23:15:08 +0200 Subject: [PATCH 07/21] passing tests --- src/spikeinterface/core/template.py | 41 +++++++++++++- .../core/tests/test_template_class.py | 55 ++++++++++++++----- 2 files changed, 79 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 6dbfb881f6..8926281dfe 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -1,6 +1,6 @@ import numpy as np import json -from dataclasses import dataclass, field +from dataclasses import dataclass, field, astuple from .sparsity import ChannelSparsity @@ -21,7 +21,7 @@ class Templates: nafter: int = field(init=False) ms_before: float = field(init=False) ms_after: float = field(init=False) - sparsity: ChannelSparsity = field(init=False) + sparsity: ChannelSparsity = field(init=False, default=None) def __post_init__(self): self.num_units, self.num_samples = self.templates_array.shape[:2] @@ -66,3 +66,40 @@ def from_dict(cls, data): def to_json(self): return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_str): + return cls.from_dict(json.loads(json_str)) + + def __eq__(self, other): + """Necessary to compare arrays""" + if not isinstance(other, Templates): + return False + + # Convert the instances to tuples + self_tuple = astuple(self) + other_tuple = astuple(other) + + # Compare each field + for s_field, o_field in zip(self_tuple, other_tuple): + if isinstance(s_field, np.ndarray): + if not np.array_equal(s_field, o_field): + return False + + elif isinstance(s_field, ChannelSparsity): + if not isinstance(o_field, ChannelSparsity): + return False + + # (maybe ChannelSparsity should have its own __eq__ method) + # Compare ChannelSparsity by its mask, unit_ids and channel_ids + if not np.array_equal(s_field.mask, o_field.mask): + return False + if not np.array_equal(s_field.unit_ids, o_field.unit_ids): + return False + if not np.array_equal(s_field.channel_ids, o_field.channel_ids): + return False + else: + if s_field != o_field: + return False + + return True diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index 2b6b4c9744..b395f82d49 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -4,8 +4,33 @@ from spikeinterface.core.template import Templates -@pytest.mark.parametrize("template_obj", ["dense", "sparse"]) -def get_template_object(template_obj): +def compare_instances(obj1, obj2): + if not isinstance(obj1, Templates) or not isinstance(obj2, Templates): + raise ValueError("Both objects must be instances of the Templates class") + + for attr, value1 in obj1.__dict__.items(): + value2 = getattr(obj2, attr, None) + + # Comparing numpy arrays + if isinstance(value1, np.ndarray): + if not np.array_equal(value1, value2): + print(f"Attribute '{attr}' is not equal!") + print(f"Value from obj1:\n{value1}") + print(f"Value from obj2:\n{value2}") + return False + # Comparing other types + elif value1 != value2: + print(f"Attribute '{attr}' is not equal!") + print(f"Value from obj1: {value1}") + print(f"Value from obj2: {value2}") + return False + + print("All attributes are equal!") + return True + + +@pytest.mark.parametrize("template_object", ["dense", "sparse"]) +def generate_template_fixture(template_object): num_units = 2 num_samples = 5 num_channels = 3 @@ -15,7 +40,7 @@ def get_template_object(template_obj): sampling_frequency = 30_000 nbefore = 2 - if template_obj == "dense": + if template_object == "dense": return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore) else: # sparse sparsity_mask = np.array([[True, False, True], [False, True, False]]) @@ -27,27 +52,27 @@ def get_template_object(template_obj): ) -@pytest.mark.parametrize("template_obj", ["dense", "sparse"]) -def test_pickle_serialization(template_obj, tmp_path): - obj = get_template_object(template_obj) +@pytest.mark.parametrize("template_object", ["dense", "sparse"]) +def test_pickle_serialization(template_object, tmp_path): + template = generate_template_fixture(template_object) # Dump to pickle pkl_path = tmp_path / "templates.pkl" with open(pkl_path, "wb") as f: - pickle.dump(obj, f) + pickle.dump(template, f) # Load from pickle with open(pkl_path, "rb") as f: - loaded_obj = pickle.load(f) + template_reloaded = pickle.load(f) - assert np.array_equal(obj.templates_array, loaded_obj.templates_array) + assert template == template_reloaded -@pytest.mark.parametrize("template_obj", ["dense", "sparse"]) -def test_json_serialization(template_obj): - obj = get_template_object(template_obj) +@pytest.mark.parametrize("template_object", ["dense", "sparse"]) +def test_json_serialization(template_object): + template = generate_template_fixture(template_object) - json_str = obj.to_json() - loaded_obj_from_json = Templates.from_json(json_str) + json_str = template.to_json() + template_reloaded_from_json = Templates.from_json(json_str) - assert np.array_equal(obj.templates_array, loaded_obj_from_json.templates_array) + assert template == template_reloaded_from_json From 9d7c9ac134151cb688ec4316c9bc2ff4af9f62ae Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 23:21:13 +0200 Subject: [PATCH 08/21] add densification and sparsification methods --- src/spikeinterface/core/template.py | 12 +++++++++ .../core/tests/test_template_class.py | 25 ------------------- 2 files changed, 12 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 8926281dfe..70c7d90527 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -64,6 +64,18 @@ def from_dict(cls, data): nbefore=data["nbefore"], ) + def get_dense_templates(self) -> np.ndarray: + if self.sparsity is None: + return self.templates_array + else: + self.sparsity.to_dense(self.templates_array) + + def get_sparse_templates(self) -> np.ndarray: + if self.sparsity is None: + raise ValueError("Can't return sparse templates without passing a sparsity mask") + else: + self.sparsity.to_sparse(self.templates_array) + def to_json(self): return json.dumps(self.to_dict()) diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index b395f82d49..f92e636d93 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -4,31 +4,6 @@ from spikeinterface.core.template import Templates -def compare_instances(obj1, obj2): - if not isinstance(obj1, Templates) or not isinstance(obj2, Templates): - raise ValueError("Both objects must be instances of the Templates class") - - for attr, value1 in obj1.__dict__.items(): - value2 = getattr(obj2, attr, None) - - # Comparing numpy arrays - if isinstance(value1, np.ndarray): - if not np.array_equal(value1, value2): - print(f"Attribute '{attr}' is not equal!") - print(f"Value from obj1:\n{value1}") - print(f"Value from obj2:\n{value2}") - return False - # Comparing other types - elif value1 != value2: - print(f"Attribute '{attr}' is not equal!") - print(f"Value from obj1: {value1}") - print(f"Value from obj2: {value2}") - return False - - print("All attributes are equal!") - return True - - @pytest.mark.parametrize("template_object", ["dense", "sparse"]) def generate_template_fixture(template_object): num_units = 2 From 6e1027bb75abc9ae1977ec0a6b87c7f951588576 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 25 Sep 2023 10:51:11 +0200 Subject: [PATCH 09/21] adding tests for sparsity and density --- src/spikeinterface/core/sparsity.py | 3 +- src/spikeinterface/core/template.py | 53 ++++++++++++-- .../core/tests/test_template_class.py | 71 ++++++++++++++++--- 3 files changed, 109 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 455edcfc80..70b412d487 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -191,7 +191,7 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray: assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg densified_shape = waveforms.shape[:-1] + (self.num_channels,) - densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype) + densified_waveforms = np.zeros(shape=densified_shape, dtype=waveforms.dtype) densified_waveforms[..., non_zero_indices] = waveforms return densified_waveforms @@ -202,6 +202,7 @@ def are_waveforms_dense(self, waveforms: np.ndarray) -> bool: def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str) -> bool: non_zero_indices = self.unit_id_to_channel_indices[unit_id] num_active_channels = len(non_zero_indices) + return waveforms.shape[-1] == num_active_channels @classmethod diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 70c7d90527..54070053eb 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -65,16 +65,52 @@ def from_dict(cls, data): ) def get_dense_templates(self) -> np.ndarray: + # Assumes and object without a sparsity mask already has dense templates if self.sparsity is None: return self.templates_array - else: - self.sparsity.to_dense(self.templates_array) + + dense_waveforms = np.zeros(shape=(self.num_units, self.num_samples, self.num_channels)) + for unit_index, unit_id in enumerate(self.unit_ids): + num_active_channels = self.sparsity.mask[unit_index].sum() + waveforms = self.templates_array[unit_index, :, :num_active_channels] + dense_waveforms[unit_index, ...] = self.sparsity.densify_waveforms(waveforms=waveforms, unit_id=unit_id) + + return dense_waveforms def get_sparse_templates(self) -> np.ndarray: + # Objects without sparsity mask don't have sparsity and therefore can't return sparse templates if self.sparsity is None: raise ValueError("Can't return sparse templates without passing a sparsity mask") - else: - self.sparsity.to_sparse(self.templates_array) + + # Waveforms are already sparse + if not self.sparsity.are_waveforms_dense(self.templates_array): + return self.templates_array + + max_num_active_channels = self.sparsity.max_num_active_channels + sparse_waveforms = np.zeros(shape=(self.num_units, self.num_samples, max_num_active_channels)) + for unit_index, unit_id in enumerate(self.unit_ids): + waveforms = self.templates_array[unit_index, ...] + sparse_waveforms[unit_index, ...] = self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id) + + return sparse_waveforms + + def are_templates_sparse(self) -> bool: + if self.sparsity is None: + return False + + if self.templates_array.shape[-1] == self.num_channels: + return False + + unit_is_sparse = True + for unit_index, unit_id in enumerate(self.unit_ids): + non_zero_indices = self.sparsity.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) + waveforms = self.templates_array[unit_index, :, :num_active_channels] + unit_is_sparse = self.sparsity.are_waveforms_sparse(waveforms, unit_id=unit_id) + if not unit_is_sparse: + return False + + return unit_is_sparse def to_json(self): return json.dumps(self.to_dict()) @@ -84,7 +120,11 @@ def from_json(cls, json_str): return cls.from_dict(json.loads(json_str)) def __eq__(self, other): - """Necessary to compare arrays""" + """ + Necessary to compare templates because they naturally compare objects by equality of their fields + which is not possible for numpy arrays so we override the __eq__ method to compare each numpy arrays + with np.array_equal + """ if not isinstance(other, Templates): return False @@ -97,12 +137,11 @@ def __eq__(self, other): if isinstance(s_field, np.ndarray): if not np.array_equal(s_field, o_field): return False - + # Compare ChannelSparsity by its mask, unit_ids and channel_ids. Maybe ChannelSparsity should have its own __eq__ method elif isinstance(s_field, ChannelSparsity): if not isinstance(o_field, ChannelSparsity): return False - # (maybe ChannelSparsity should have its own __eq__ method) # Compare ChannelSparsity by its mask, unit_ids and channel_ids if not np.array_equal(s_field.mask, o_field.mask): return False diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index f92e636d93..cf0dffe532 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -2,10 +2,10 @@ import numpy as np import pickle from spikeinterface.core.template import Templates +from spikeinterface.core.sparsity import ChannelSparsity -@pytest.mark.parametrize("template_object", ["dense", "sparse"]) -def generate_template_fixture(template_object): +def generate_test_template(template_type): num_units = 2 num_samples = 5 num_channels = 3 @@ -15,10 +15,29 @@ def generate_template_fixture(template_object): sampling_frequency = 30_000 nbefore = 2 - if template_object == "dense": + if template_type == "dense": return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore) - else: # sparse + elif template_type == "sparse": # sparse with sparse templates sparsity_mask = np.array([[True, False, True], [False, True, False]]) + sparsity = ChannelSparsity( + mask=sparsity_mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels) + ) + + sparse_templates_array = np.zeros(shape=(num_units, num_samples, sparsity.max_num_active_channels)) + for unit_index in range(num_units): + template = templates_array[unit_index, ...] + sparse_template = sparsity.sparsify_waveforms(waveforms=template, unit_id=unit_index) + sparse_templates_array[unit_index, :, : sparse_template.shape[1]] = sparse_template + + return Templates( + templates_array=sparse_templates_array, + sparsity_mask=sparsity_mask, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + ) + elif template_type == "sparse_with_dense_templates": # sparse with dense templates + sparsity_mask = np.array([[True, False, True], [False, True, False]]) + return Templates( templates_array=templates_array, sparsity_mask=sparsity_mask, @@ -27,9 +46,9 @@ def generate_template_fixture(template_object): ) -@pytest.mark.parametrize("template_object", ["dense", "sparse"]) -def test_pickle_serialization(template_object, tmp_path): - template = generate_template_fixture(template_object) +@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +def test_pickle_serialization(template_type, tmp_path): + template = generate_test_template(template_type) # Dump to pickle pkl_path = tmp_path / "templates.pkl" @@ -43,11 +62,43 @@ def test_pickle_serialization(template_object, tmp_path): assert template == template_reloaded -@pytest.mark.parametrize("template_object", ["dense", "sparse"]) -def test_json_serialization(template_object): - template = generate_template_fixture(template_object) +@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +def test_json_serialization(template_type): + template = generate_test_template(template_type) json_str = template.to_json() template_reloaded_from_json = Templates.from_json(json_str) assert template == template_reloaded_from_json + + +@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +def test_get_dense_templates(template_type): + template = generate_test_template(template_type) + dense_templates = template.get_dense_templates() + assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels) + + +@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +def test_get_sparse_templates(template_type): + template = generate_test_template(template_type) + + if template_type == "dense": + with pytest.raises(ValueError): + sparse_templates = template.get_sparse_templates() + elif template_type == "sparse": + sparse_templates = template.get_sparse_templates() + assert sparse_templates.shape == ( + template.num_units, + template.num_samples, + template.sparsity.max_num_active_channels, + ) + assert template.are_templates_sparse() + elif template_type == "sparse_with_dense_templates": + sparse_templates = template.get_sparse_templates() + assert sparse_templates.shape == ( + template.num_units, + template.num_samples, + template.sparsity.max_num_active_channels, + ) + assert not template.are_templates_sparse() From d05e67da479716c7aac619dc0e4d3080a0a6d4a5 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 28 Sep 2023 11:24:45 +0200 Subject: [PATCH 10/21] prohibit dense templates when passing sparsity mask --- src/spikeinterface/core/sparsity.py | 29 ++++++----- src/spikeinterface/core/template.py | 50 ++++++++++--------- .../core/tests/test_template_class.py | 23 ++++----- 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 5bc2e51e8a..f2da16b757 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -150,11 +150,8 @@ def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nd or a single sparsified waveform (template) with shape (num_samples, num_active_channels). """ - assert_msg = ( - "Waveforms must be dense to sparsify them. " - f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}" - ) - assert self.are_waveforms_dense(waveforms=waveforms), assert_msg + if self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id): + return waveforms non_zero_indices = self.unit_id_to_channel_indices[unit_id] sparsified_waveforms = waveforms[..., non_zero_indices] @@ -185,16 +182,20 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nda """ non_zero_indices = self.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) - assert_msg = ( - "Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is " - f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels." - ) - assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg + if not self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id): + error_message = ( + "Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is " + f"{num_active_channels} but the waveform has non zero values outsies of those active channels: \n" + f"{waveforms[..., num_active_channels:]}" + ) + raise ValueError(error_message) densified_shape = waveforms.shape[:-1] + (self.num_channels,) densified_waveforms = np.zeros(shape=densified_shape, dtype=waveforms.dtype) - densified_waveforms[..., non_zero_indices] = waveforms + # Maps the active channels to their original indices + densified_waveforms[..., non_zero_indices] = waveforms[..., :num_active_channels] return densified_waveforms @@ -205,7 +206,11 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo non_zero_indices = self.unit_id_to_channel_indices[unit_id] num_active_channels = len(non_zero_indices) - return waveforms.shape[-1] == num_active_channels + # If any channel is non-zero outside of the active channels, then the waveforms are not sparse + excess_zeros = waveforms[..., num_active_channels:].sum() + are_sparse = excess_zeros == 0 + + return are_sparse @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 54070053eb..c0d4869d5e 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -32,6 +32,8 @@ def __post_init__(self): self.nafter = self.num_samples - self.nbefore - 1 self.ms_before = self.nbefore / self.sampling_frequency * 1000 self.ms_after = self.nafter / self.sampling_frequency * 1000 + + # Initialize sparsity object if self.channel_ids is None: self.channel_ids = np.arange(self.num_channels) if self.unit_ids is None: @@ -43,6 +45,10 @@ def __post_init__(self): channel_ids=self.channel_ids, ) + # Test that the templates are sparse if a sparsity mask is passed + if not self._are_passed_templates_sparse(): + raise ValueError("Sparsity mask passed but the templates are not sparse") + def to_dict(self): return { "templates_array": self.templates_array.tolist(), @@ -69,10 +75,11 @@ def get_dense_templates(self) -> np.ndarray: if self.sparsity is None: return self.templates_array - dense_waveforms = np.zeros(shape=(self.num_units, self.num_samples, self.num_channels)) + dense_shape = (self.num_units, self.num_samples, self.num_channels) + dense_waveforms = np.zeros(dense=dense_shape, dtype=self.templates_array.dtype) + for unit_index, unit_id in enumerate(self.unit_ids): - num_active_channels = self.sparsity.mask[unit_index].sum() - waveforms = self.templates_array[unit_index, :, :num_active_channels] + waveforms = self.templates_array[unit_index, ...] dense_waveforms[unit_index, ...] = self.sparsity.densify_waveforms(waveforms=waveforms, unit_id=unit_id) return dense_waveforms @@ -82,12 +89,9 @@ def get_sparse_templates(self) -> np.ndarray: if self.sparsity is None: raise ValueError("Can't return sparse templates without passing a sparsity mask") - # Waveforms are already sparse - if not self.sparsity.are_waveforms_dense(self.templates_array): - return self.templates_array - max_num_active_channels = self.sparsity.max_num_active_channels - sparse_waveforms = np.zeros(shape=(self.num_units, self.num_samples, max_num_active_channels)) + sparse_shape = (self.num_units, self.num_samples, max_num_active_channels) + sparse_waveforms = np.zeros(shape=sparse_shape, dtype=self.templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): waveforms = self.templates_array[unit_index, ...] sparse_waveforms[unit_index, ...] = self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id) @@ -95,22 +99,20 @@ def get_sparse_templates(self) -> np.ndarray: return sparse_waveforms def are_templates_sparse(self) -> bool: - if self.sparsity is None: - return False - - if self.templates_array.shape[-1] == self.num_channels: - return False + return self.sparsity is not None - unit_is_sparse = True + def _are_passed_templates_sparse(self) -> bool: + """ + Tests if the templates passed to the init constructor are sparse + """ + are_templates_sparse = True for unit_index, unit_id in enumerate(self.unit_ids): - non_zero_indices = self.sparsity.unit_id_to_channel_indices[unit_id] - num_active_channels = len(non_zero_indices) - waveforms = self.templates_array[unit_index, :, :num_active_channels] - unit_is_sparse = self.sparsity.are_waveforms_sparse(waveforms, unit_id=unit_id) - if not unit_is_sparse: + waveforms = self.templates_array[unit_index, ...] + are_templates_sparse = self.sparsity.are_waveforms_sparse(waveforms, unit_id=unit_id) + if not are_templates_sparse: return False - return unit_is_sparse + return are_templates_sparse def to_json(self): return json.dumps(self.to_dict()) @@ -122,8 +124,8 @@ def from_json(cls, json_str): def __eq__(self, other): """ Necessary to compare templates because they naturally compare objects by equality of their fields - which is not possible for numpy arrays so we override the __eq__ method to compare each numpy arrays - with np.array_equal + which is not possible for numpy arrays. Therefore, we override the __eq__ method to compare each numpy arrays + using np.array_equal instead """ if not isinstance(other, Templates): return False @@ -137,7 +139,9 @@ def __eq__(self, other): if isinstance(s_field, np.ndarray): if not np.array_equal(s_field, o_field): return False - # Compare ChannelSparsity by its mask, unit_ids and channel_ids. Maybe ChannelSparsity should have its own __eq__ method + + # Compare ChannelSparsity by its mask, unit_ids and channel_ids. + # Maybe ChannelSparsity should have its own __eq__ method elif isinstance(s_field, ChannelSparsity): if not isinstance(o_field, ChannelSparsity): return False diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index cf0dffe532..b1244ab0d1 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -23,6 +23,7 @@ def generate_test_template(template_type): mask=sparsity_mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels) ) + # Create sparse templates sparse_templates_array = np.zeros(shape=(num_units, num_samples, sparsity.max_num_active_channels)) for unit_index in range(num_units): template = templates_array[unit_index, ...] @@ -35,6 +36,7 @@ def generate_test_template(template_type): sampling_frequency=sampling_frequency, nbefore=nbefore, ) + elif template_type == "sparse_with_dense_templates": # sparse with dense templates sparsity_mask = np.array([[True, False, True], [False, True, False]]) @@ -46,7 +48,7 @@ def generate_test_template(template_type): ) -@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) def test_pickle_serialization(template_type, tmp_path): template = generate_test_template(template_type) @@ -62,7 +64,7 @@ def test_pickle_serialization(template_type, tmp_path): assert template == template_reloaded -@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) def test_json_serialization(template_type): template = generate_test_template(template_type) @@ -72,14 +74,14 @@ def test_json_serialization(template_type): assert template == template_reloaded_from_json -@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) def test_get_dense_templates(template_type): template = generate_test_template(template_type) dense_templates = template.get_dense_templates() assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels) -@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) def test_get_sparse_templates(template_type): template = generate_test_template(template_type) @@ -94,11 +96,8 @@ def test_get_sparse_templates(template_type): template.sparsity.max_num_active_channels, ) assert template.are_templates_sparse() - elif template_type == "sparse_with_dense_templates": - sparse_templates = template.get_sparse_templates() - assert sparse_templates.shape == ( - template.num_units, - template.num_samples, - template.sparsity.max_num_active_channels, - ) - assert not template.are_templates_sparse() + + +def test_initialization_fail_with_dense_templates(): + with pytest.raises(ValueError, match="Sparsity mask passed but the templates are not sparse"): + template = generate_test_template(template_type="sparse_with_dense_templates") From cc8a5236e00072985cac399e078877c597b87ecd Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 28 Sep 2023 11:27:42 +0200 Subject: [PATCH 11/21] add docstring --- src/spikeinterface/core/template.py | 35 +++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index c0d4869d5e..dc6e0a5070 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -6,6 +6,41 @@ @dataclass(kw_only=True) class Templates: + """ + A class to represent spike templates, which can be either dense or sparse. + + Attributes + ---------- + templates_array : np.ndarray + Array containing the templates data. + sampling_frequency : float + Sampling frequency of the templates. + nbefore : int + Number of samples before the spike peak. + sparsity_mask : np.ndarray, optional + Binary array indicating the sparsity pattern of the templates. + If `None`, the templates are considered dense. + channel_ids : np.ndarray, optional + Array of channel IDs. If `None`, defaults to an array of increasing integers. + unit_ids : np.ndarray, optional + Array of unit IDs. If `None`, defaults to an array of increasing integers. + num_units : int + Number of units in the templates. Automatically determined from `templates_array`. + num_samples : int + Number of samples per template. Automatically determined from `templates_array`. + num_channels : int + Number of channels in the templates. Automatically determined from `templates_array` or `sparsity_mask`. + nafter : int + Number of samples after the spike peak. Calculated as `num_samples - nbefore - 1`. + ms_before : float + Milliseconds before the spike peak. Calculated from `nbefore` and `sampling_frequency`. + ms_after : float + Milliseconds after the spike peak. Calculated from `nafter` and `sampling_frequency`. + sparsity : ChannelSparsity, optional + Object representing the sparsity pattern of the templates. Calculated from `sparsity_mask`. + If `None`, the templates are considered dense. + """ + templates_array: np.ndarray sampling_frequency: float nbefore: int From 73e95627b9015d20addd78e7b59f697ce5acb335 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 28 Sep 2023 11:34:13 +0200 Subject: [PATCH 12/21] alessio remark about nafter definition --- src/spikeinterface/core/sparsity.py | 3 +-- src/spikeinterface/core/template.py | 12 +++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index f2da16b757..1593b6c9e4 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -208,9 +208,8 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo # If any channel is non-zero outside of the active channels, then the waveforms are not sparse excess_zeros = waveforms[..., num_active_channels:].sum() - are_sparse = excess_zeros == 0 - return are_sparse + return int(excess_zeros) == 0 @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index dc6e0a5070..bc4f7bae80 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -64,7 +64,9 @@ def __post_init__(self): self.num_channels = self.templates_array.shape[2] else: self.num_channels = self.sparsity_mask.shape[1] - self.nafter = self.num_samples - self.nbefore - 1 + + # Time and frames domain information + self.nafter = self.num_samples - self.nbefore self.ms_before = self.nbefore / self.sampling_frequency * 1000 self.ms_after = self.nafter / self.sampling_frequency * 1000 @@ -110,8 +112,8 @@ def get_dense_templates(self) -> np.ndarray: if self.sparsity is None: return self.templates_array - dense_shape = (self.num_units, self.num_samples, self.num_channels) - dense_waveforms = np.zeros(dense=dense_shape, dtype=self.templates_array.dtype) + densified_shape = (self.num_units, self.num_samples, self.num_channels) + dense_waveforms = np.zeros(dense=densified_shape, dtype=self.templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): waveforms = self.templates_array[unit_index, ...] @@ -125,8 +127,8 @@ def get_sparse_templates(self) -> np.ndarray: raise ValueError("Can't return sparse templates without passing a sparsity mask") max_num_active_channels = self.sparsity.max_num_active_channels - sparse_shape = (self.num_units, self.num_samples, max_num_active_channels) - sparse_waveforms = np.zeros(shape=sparse_shape, dtype=self.templates_array.dtype) + sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels) + sparse_waveforms = np.zeros(shape=sparisfied_shape, dtype=self.templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): waveforms = self.templates_array[unit_index, ...] sparse_waveforms[unit_index, ...] = self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id) From 52c333b1a5bdcc5975b2c67c2341f02e80ca74de Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 28 Sep 2023 11:50:25 +0200 Subject: [PATCH 13/21] fix mistake --- src/spikeinterface/core/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index bc4f7bae80..e8c0f83f50 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -113,7 +113,7 @@ def get_dense_templates(self) -> np.ndarray: return self.templates_array densified_shape = (self.num_units, self.num_samples, self.num_channels) - dense_waveforms = np.zeros(dense=densified_shape, dtype=self.templates_array.dtype) + dense_waveforms = np.zeros(shape=densified_shape, dtype=self.templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): waveforms = self.templates_array[unit_index, ...] From 2fb79ccf7fe3ad6a1efe5675171a82e54894aedd Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 09:31:01 +0200 Subject: [PATCH 14/21] Update src/spikeinterface/core/sparsity.py Co-authored-by: Alessio Buccino --- src/spikeinterface/core/sparsity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 1593b6c9e4..0a8c165ba5 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -186,8 +186,8 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nda if not self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id): error_message = ( - "Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is " - f"{num_active_channels} but the waveform has non zero values outsies of those active channels: \n" + "Waveforms do not seem to be in the sparsity shape for this unit_id. The number of active channels is " + f"{num_active_channels}, but the waveform has non-zero values outsies of those active channels: \n" f"{waveforms[..., num_active_channels:]}" ) raise ValueError(error_message) From 437695c8b5c2f04331b0eac3cc7c876697b0709a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 10:01:47 +0200 Subject: [PATCH 15/21] changes --- src/spikeinterface/core/sparsity.py | 10 ++++ src/spikeinterface/core/template.py | 71 +++++++++++++---------------- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 1593b6c9e4..990687ca04 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -211,6 +211,16 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo return int(excess_zeros) == 0 + def sparisfy_templates(self, templates_array: np.ndarray) -> np.ndarray: + max_num_active_channels = self.max_num_active_channels + sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels) + sparse_templates = np.zeros(shape=sparisfied_shape, dtype=templates_array.dtype) + for unit_index, unit_id in enumerate(self.unit_ids): + template = templates_array[unit_index, ...] + sparse_templates[unit_index, ...] = self.sparsify_waveforms(waveforms=template, unit_id=unit_id) + + return sparse_templates + @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): """ diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index e8c0f83f50..ed71b6d2ea 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -4,7 +4,7 @@ from .sparsity import ChannelSparsity -@dataclass(kw_only=True) +@dataclass class Templates: """ A class to represent spike templates, which can be either dense or sparse. @@ -18,7 +18,7 @@ class Templates: nbefore : int Number of samples before the spike peak. sparsity_mask : np.ndarray, optional - Binary array indicating the sparsity pattern of the templates. + Boolean array indicating the sparsity pattern of the templates. If `None`, the templates are considered dense. channel_ids : np.ndarray, optional Array of channel IDs. If `None`, defaults to an array of increasing integers. @@ -49,6 +49,8 @@ class Templates: channel_ids: np.ndarray = None unit_ids: np.ndarray = None + check_template_array_and_sparsity_mask_are_consistentency: bool = True + num_units: int = field(init=False) num_samples: int = field(init=False) num_channels: int = field(init=False) @@ -83,29 +85,9 @@ def __post_init__(self): ) # Test that the templates are sparse if a sparsity mask is passed - if not self._are_passed_templates_sparse(): - raise ValueError("Sparsity mask passed but the templates are not sparse") - - def to_dict(self): - return { - "templates_array": self.templates_array.tolist(), - "sparsity_mask": None if self.sparsity_mask is None else self.sparsity_mask.tolist(), - "channel_ids": self.channel_ids.tolist(), - "unit_ids": self.unit_ids.tolist(), - "sampling_frequency": self.sampling_frequency, - "nbefore": self.nbefore, - } - - @classmethod - def from_dict(cls, data): - return cls( - templates_array=np.array(data["templates_array"]), - sparsity_mask=None if data["sparsity_mask"] is None else np.array(data["sparsity_mask"]), - channel_ids=np.array(data["channel_ids"]), - unit_ids=np.array(data["unit_ids"]), - sampling_frequency=data["sampling_frequency"], - nbefore=data["nbefore"], - ) + if self.check_template_array_and_sparsity_mask_are_consistentency: + if not self._are_passed_templates_sparse(): + raise ValueError("Sparsity mask passed but the templates are not sparse") def get_dense_templates(self) -> np.ndarray: # Assumes and object without a sparsity mask already has dense templates @@ -121,20 +103,6 @@ def get_dense_templates(self) -> np.ndarray: return dense_waveforms - def get_sparse_templates(self) -> np.ndarray: - # Objects without sparsity mask don't have sparsity and therefore can't return sparse templates - if self.sparsity is None: - raise ValueError("Can't return sparse templates without passing a sparsity mask") - - max_num_active_channels = self.sparsity.max_num_active_channels - sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels) - sparse_waveforms = np.zeros(shape=sparisfied_shape, dtype=self.templates_array.dtype) - for unit_index, unit_id in enumerate(self.unit_ids): - waveforms = self.templates_array[unit_index, ...] - sparse_waveforms[unit_index, ...] = self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id) - - return sparse_waveforms - def are_templates_sparse(self) -> bool: return self.sparsity is not None @@ -151,8 +119,31 @@ def _are_passed_templates_sparse(self) -> bool: return are_templates_sparse + def to_dict(self): + return { + "templates_array": self.templates_array, + "sparsity_mask": None if self.sparsity_mask is None else self.sparsity_mask, + "channel_ids": self.channel_ids, + "unit_ids": self.unit_ids, + "sampling_frequency": self.sampling_frequency, + "nbefore": self.nbefore, + } + + @classmethod + def from_dict(cls, data): + return cls( + templates_array=np.asarray(data["templates_array"]), + sparsity_mask=None if data["sparsity_mask"] is None else np.asarray(data["sparsity_mask"]), + channel_ids=np.asarray(data["channel_ids"]), + unit_ids=np.asarray(data["unit_ids"]), + sampling_frequency=data["sampling_frequency"], + nbefore=data["nbefore"], + ) + def to_json(self): - return json.dumps(self.to_dict()) + from spikeinterface.core.core_tools import SIJsonEncoder + + return json.dumps(self.to_dict(), cls=SIJsonEncoder) @classmethod def from_json(cls, json_str): From 600f20f9b7465f7f398097207d036e8ee4ff8d92 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 10:29:30 +0200 Subject: [PATCH 16/21] modify docstring --- src/spikeinterface/core/template.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index ed71b6d2ea..909d47acfc 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -9,7 +9,7 @@ class Templates: """ A class to represent spike templates, which can be either dense or sparse. - Attributes + It is constructed with the following parameters: ---------- templates_array : np.ndarray Array containing the templates data. @@ -17,13 +17,21 @@ class Templates: Sampling frequency of the templates. nbefore : int Number of samples before the spike peak. - sparsity_mask : np.ndarray, optional + sparsity_mask : np.ndarray, optional (default=None) Boolean array indicating the sparsity pattern of the templates. If `None`, the templates are considered dense. - channel_ids : np.ndarray, optional + channel_ids : np.ndarray, optional (default=None) Array of channel IDs. If `None`, defaults to an array of increasing integers. - unit_ids : np.ndarray, optional + unit_ids : np.ndarray, optional (default=None) Array of unit IDs. If `None`, defaults to an array of increasing integers. + check_for_consistent_sparsity : bool, optional (default=True) + When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the + structure fo the sparsity_masl. + + The following attributes are avaialble after construction: + + Attributes + ---------- num_units : int Number of units in the templates. Automatically determined from `templates_array`. num_samples : int @@ -49,7 +57,7 @@ class Templates: channel_ids: np.ndarray = None unit_ids: np.ndarray = None - check_template_array_and_sparsity_mask_are_consistentency: bool = True + check_for_consistent_sparsity: bool = True num_units: int = field(init=False) num_samples: int = field(init=False) @@ -85,7 +93,7 @@ def __post_init__(self): ) # Test that the templates are sparse if a sparsity mask is passed - if self.check_template_array_and_sparsity_mask_are_consistentency: + if self.check_for_consistent_sparsity: if not self._are_passed_templates_sparse(): raise ValueError("Sparsity mask passed but the templates are not sparse") From a1e6eaec457a55c6043bedcc1349176aa4e57f0c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 10:34:21 +0200 Subject: [PATCH 17/21] remove tests for get_sparse_templates --- .../core/tests/test_template_class.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index b1244ab0d1..40bb3f2b34 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -81,23 +81,6 @@ def test_get_dense_templates(template_type): assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels) -@pytest.mark.parametrize("template_type", ["dense", "sparse"]) -def test_get_sparse_templates(template_type): - template = generate_test_template(template_type) - - if template_type == "dense": - with pytest.raises(ValueError): - sparse_templates = template.get_sparse_templates() - elif template_type == "sparse": - sparse_templates = template.get_sparse_templates() - assert sparse_templates.shape == ( - template.num_units, - template.num_samples, - template.sparsity.max_num_active_channels, - ) - assert template.are_templates_sparse() - - def test_initialization_fail_with_dense_templates(): with pytest.raises(ValueError, match="Sparsity mask passed but the templates are not sparse"): template = generate_test_template(template_type="sparse_with_dense_templates") From aa08f1ba85aea5eec5835de04618061aa9df244b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 12:18:09 +0200 Subject: [PATCH 18/21] Update src/spikeinterface/core/template.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 909d47acfc..8ebd0a75f5 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -28,7 +28,7 @@ class Templates: When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the structure fo the sparsity_masl. - The following attributes are avaialble after construction: + The following attributes are available after construction: Attributes ---------- From ea2a8a03c43b3ca444e02e50b837b1d6b7a51bd9 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 12:18:32 +0200 Subject: [PATCH 19/21] Update src/spikeinterface/core/template.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 8ebd0a75f5..e6556a68f7 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -9,7 +9,7 @@ class Templates: """ A class to represent spike templates, which can be either dense or sparse. - It is constructed with the following parameters: + Parameters ---------- templates_array : np.ndarray Array containing the templates data. From 4f8bd73fd08798671fd252ad9a23ea40e12ef7e3 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 2 Nov 2023 11:08:11 +0100 Subject: [PATCH 20/21] Update src/spikeinterface/core/template.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index e6556a68f7..8beb6b46b1 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -17,7 +17,7 @@ class Templates: Sampling frequency of the templates. nbefore : int Number of samples before the spike peak. - sparsity_mask : np.ndarray, optional (default=None) + sparsity_mask : np.ndarray or None, default: None Boolean array indicating the sparsity pattern of the templates. If `None`, the templates are considered dense. channel_ids : np.ndarray, optional (default=None) From e52dd28d7fbc512e98bdbb28f4db32930b1cd73f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 2 Nov 2023 11:09:00 +0100 Subject: [PATCH 21/21] docstring compliance --- src/spikeinterface/core/template.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 8beb6b46b1..e6372c7082 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -20,11 +20,11 @@ class Templates: sparsity_mask : np.ndarray or None, default: None Boolean array indicating the sparsity pattern of the templates. If `None`, the templates are considered dense. - channel_ids : np.ndarray, optional (default=None) + channel_ids : np.ndarray, optional default: None Array of channel IDs. If `None`, defaults to an array of increasing integers. - unit_ids : np.ndarray, optional (default=None) + unit_ids : np.ndarray, optional default: None Array of unit IDs. If `None`, defaults to an array of increasing integers. - check_for_consistent_sparsity : bool, optional (default=True) + check_for_consistent_sparsity : bool, optional default: None When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the structure fo the sparsity_masl.