
Add Template class #1982

Merged: 34 commits, Nov 15, 2023
Changes from 22 commits
Commits (34)
513a344
add basic instance and numpy behavior
h-mayorquin Sep 12, 2023
c242446
add pickability
h-mayorquin Sep 12, 2023
4ca9ec6
add json test
h-mayorquin Sep 12, 2023
107bdf9
test fancy indices
h-mayorquin Sep 12, 2023
1585f11
Merge branch 'main' into add_template_data_class
h-mayorquin Sep 13, 2023
078e605
Merge branch 'main' into add_template_data_class
h-mayorquin Sep 13, 2023
383e7d8
Merge branch 'main' into add_template_data_class
h-mayorquin Sep 20, 2023
aae0ef4
Merge branch 'main' into add_template_data_class
h-mayorquin Sep 20, 2023
d2c6ec7
alessio and samuel requests
h-mayorquin Sep 20, 2023
961d269
remove slicing
h-mayorquin Sep 20, 2023
ad4185e
Merge remote-tracking branch 'refs/remotes/me/add_template_data_class…
h-mayorquin Sep 20, 2023
9ee3a1d
passing tests
h-mayorquin Sep 20, 2023
9d7c9ac
add densification and sparsification methods
h-mayorquin Sep 20, 2023
e4e5736
Merge branch 'main' into add_template_data_class
h-mayorquin Sep 25, 2023
6e1027b
adding tests for sparsity and density
h-mayorquin Sep 25, 2023
cb8107a
Merge remote-tracking branch 'refs/remotes/me/add_template_data_class…
h-mayorquin Sep 25, 2023
d05e67d
prohibit dense templates when passing sparsity mask
h-mayorquin Sep 28, 2023
cc8a523
add docstring
h-mayorquin Sep 28, 2023
13d7e9f
Merge branch 'main' into add_template_data_class
h-mayorquin Sep 28, 2023
73e9562
alessio remark about nafter definition
h-mayorquin Sep 28, 2023
52c333b
fix mistake
h-mayorquin Sep 28, 2023
3368b0a
Merge branch 'main' into add_template_data_class
h-mayorquin Sep 28, 2023
2fb79cc
Update src/spikeinterface/core/sparsity.py
h-mayorquin Oct 24, 2023
437695c
changes
h-mayorquin Oct 24, 2023
c2727ee
Merge remote-tracking branch 'refs/remotes/me/add_template_data_class…
h-mayorquin Oct 24, 2023
600f20f
modify docstring
h-mayorquin Oct 24, 2023
8b78386
Merge branch 'main' into add_template_data_class
h-mayorquin Oct 24, 2023
a1e6eae
remove tests for get_sparse_templates
h-mayorquin Oct 24, 2023
3c38d15
Merge remote-tracking branch 'refs/remotes/me/add_template_data_class…
h-mayorquin Oct 24, 2023
aa08f1b
Update src/spikeinterface/core/template.py
h-mayorquin Oct 24, 2023
ea2a8a0
Update src/spikeinterface/core/template.py
h-mayorquin Oct 24, 2023
afc04da
Merge branch 'main' into add_template_data_class
h-mayorquin Nov 2, 2023
4f8bd73
Update src/spikeinterface/core/template.py
h-mayorquin Nov 2, 2023
e52dd28
docstring compliance
h-mayorquin Nov 2, 2023
31 changes: 18 additions & 13 deletions src/spikeinterface/core/sparsity.py
@@ -150,11 +150,8 @@ def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray:
or a single sparsified waveform (template) with shape (num_samples, num_active_channels).
"""

assert_msg = (
"Waveforms must be dense to sparsify them. "
f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}"
)
assert self.are_waveforms_dense(waveforms=waveforms), assert_msg
if self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id):
return waveforms

non_zero_indices = self.unit_id_to_channel_indices[unit_id]
sparsified_waveforms = waveforms[..., non_zero_indices]
@@ -185,16 +182,20 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray:
"""

non_zero_indices = self.unit_id_to_channel_indices[unit_id]
num_active_channels = len(non_zero_indices)

assert_msg = (
"Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is "
f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels."
)
assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg
if not self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id):
error_message = (
"Waveforms do not seem to be in the sparsity shape of this unit_id. The number of active channels is "
f"{num_active_channels} but the waveform has non-zero values outside of those active channels: \n"
f"{waveforms[..., num_active_channels:]}"
)
raise ValueError(error_message)

densified_shape = waveforms.shape[:-1] + (self.num_channels,)
densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype)
densified_waveforms[..., non_zero_indices] = waveforms
densified_waveforms = np.zeros(shape=densified_shape, dtype=waveforms.dtype)
# Maps the active channels to their original indices
densified_waveforms[..., non_zero_indices] = waveforms[..., :num_active_channels]

return densified_waveforms

@@ -204,7 +205,11 @@ def are_waveforms_dense(self, waveforms: np.ndarray) -> bool:
def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> bool:
non_zero_indices = self.unit_id_to_channel_indices[unit_id]
num_active_channels = len(non_zero_indices)
return waveforms.shape[-1] == num_active_channels

# If any channel is non-zero outside of the active channels, then the waveforms are not sparse
excess_zeros = waveforms[..., num_active_channels:].sum()

return int(excess_zeros) == 0
Comment on lines +212 to +216

Member: Not sure I'm happy with this. It is computationally costly; for me, the shape check is enough.

Member: OK for now. If this becomes problematic later we'll fix it.
NOTE: this assumes that the extra channels are zero-padded!

Collaborator Author: Added a boolean to deactivate this.
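For reference, a minimal sketch of the two checks weighed in this thread; the helper name and the check_padding flag are illustrative and not taken from the PR:

import numpy as np

def waveforms_look_sparse(waveforms: np.ndarray, num_active_channels: int, check_padding: bool = True) -> bool:
    # Cheap check: the last dimension matches this unit's number of active channels.
    if waveforms.shape[-1] == num_active_channels:
        return True
    if not check_padding:
        return False
    # Costlier check: anything stored beyond the active channels must be zero padding.
    return bool(np.all(waveforms[..., num_active_channels:] == 0))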


@classmethod
def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids):
197 changes: 197 additions & 0 deletions src/spikeinterface/core/template.py
@@ -0,0 +1,197 @@
import numpy as np
import json
from dataclasses import dataclass, field, astuple
from .sparsity import ChannelSparsity


@dataclass(kw_only=True)
Collaborator: One other comment: based on my googling, kw_only is 3.10+ (I think 3.8 and 3.9 are supposed to be supported).
https://docs.python.org/3/library/dataclasses.html

Member: @h-mayorquin can you remove it then?

Collaborator Author: Yes. This is another shortcoming of only running tests against Python 3.10.
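For context, a short illustration of what the flag changes, using a hypothetical Example class and assuming Python 3.10+; on 3.8/3.9 the kw_only argument itself raises a TypeError, so the decorator has to stay a plain @dataclass there:

from dataclasses import dataclass

@dataclass(kw_only=True)
class Example:
    a: int
    b: int = 0

Example(a=1, b=2)  # OK: fields must be passed by keyword
# Example(1, 2)    # TypeError on 3.10+: positional arguments are rejected
# On Python 3.8/3.9, @dataclass(kw_only=True) itself fails with
# TypeError: dataclass() got an unexpected keyword argument 'kw_only'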

class Templates:
"""
A class to represent spike templates, which can be either dense or sparse.

Attributes
----------
templates_array : np.ndarray
Array containing the templates data.
sampling_frequency : float
Sampling frequency of the templates.
nbefore : int
Number of samples before the spike peak.
sparsity_mask : np.ndarray, optional
Binary array indicating the sparsity pattern of the templates.
If `None`, the templates are considered dense.
channel_ids : np.ndarray, optional
Array of channel IDs. If `None`, defaults to an array of increasing integers.
unit_ids : np.ndarray, optional
Array of unit IDs. If `None`, defaults to an array of increasing integers.
num_units : int
Number of units in the templates. Automatically determined from `templates_array`.
num_samples : int
Number of samples per template. Automatically determined from `templates_array`.
num_channels : int
Number of channels in the templates. Automatically determined from `templates_array` or `sparsity_mask`.
nafter : int
Number of samples after the spike peak. Calculated as `num_samples - nbefore`.
ms_before : float
Milliseconds before the spike peak. Calculated from `nbefore` and `sampling_frequency`.
ms_after : float
Milliseconds after the spike peak. Calculated from `nafter` and `sampling_frequency`.
sparsity : ChannelSparsity, optional
Object representing the sparsity pattern of the templates. Calculated from `sparsity_mask`.
If `None`, the templates are considered dense.
"""

templates_array: np.ndarray
sampling_frequency: float
nbefore: int

sparsity_mask: np.ndarray = None
channel_ids: np.ndarray = None
unit_ids: np.ndarray = None

num_units: int = field(init=False)
num_samples: int = field(init=False)
num_channels: int = field(init=False)

nafter: int = field(init=False)
ms_before: float = field(init=False)
ms_after: float = field(init=False)
sparsity: ChannelSparsity = field(init=False, default=None)

def __post_init__(self):
self.num_units, self.num_samples = self.templates_array.shape[:2]
if self.sparsity_mask is None:
self.num_channels = self.templates_array.shape[2]
else:
self.num_channels = self.sparsity_mask.shape[1]

# Time and frames domain information
self.nafter = self.num_samples - self.nbefore
self.ms_before = self.nbefore / self.sampling_frequency * 1000
self.ms_after = self.nafter / self.sampling_frequency * 1000

# Initialize sparsity object
if self.channel_ids is None:
self.channel_ids = np.arange(self.num_channels)
if self.unit_ids is None:
self.unit_ids = np.arange(self.num_units)
if self.sparsity_mask is not None:
self.sparsity = ChannelSparsity(
mask=self.sparsity_mask,
unit_ids=self.unit_ids,
channel_ids=self.channel_ids,
)

# Test that the templates are sparse if a sparsity mask is passed
if not self._are_passed_templates_sparse():
raise ValueError("Sparsity mask passed but the templates are not sparse")

def to_dict(self):
return {
"templates_array": self.templates_array.tolist(),
"sparsity_mask": None if self.sparsity_mask is None else self.sparsity_mask.tolist(),
"channel_ids": self.channel_ids.tolist(),
"unit_ids": self.unit_ids.tolist(),
"sampling_frequency": self.sampling_frequency,
"nbefore": self.nbefore,
}

@classmethod
def from_dict(cls, data):
return cls(
templates_array=np.array(data["templates_array"]),
sparsity_mask=None if data["sparsity_mask"] is None else np.array(data["sparsity_mask"]),
channel_ids=np.array(data["channel_ids"]),
unit_ids=np.array(data["unit_ids"]),
sampling_frequency=data["sampling_frequency"],
nbefore=data["nbefore"],
)

def get_dense_templates(self) -> np.ndarray:
# Assumes an object without a sparsity mask already has dense templates
if self.sparsity is None:
return self.templates_array

densified_shape = (self.num_units, self.num_samples, self.num_channels)
dense_waveforms = np.zeros(shape=densified_shape, dtype=self.templates_array.dtype)

for unit_index, unit_id in enumerate(self.unit_ids):
waveforms = self.templates_array[unit_index, ...]
dense_waveforms[unit_index, ...] = self.sparsity.densify_waveforms(waveforms=waveforms, unit_id=unit_id)

return dense_waveforms

def get_sparse_templates(self) -> np.ndarray:
# Objects without a sparsity mask have no sparsity and therefore can't return sparse templates
if self.sparsity is None:
raise ValueError("Can't return sparse templates without passing a sparsity mask")

max_num_active_channels = self.sparsity.max_num_active_channels
sparsified_shape = (self.num_units, self.num_samples, max_num_active_channels)
sparse_waveforms = np.zeros(shape=sparsified_shape, dtype=self.templates_array.dtype)
for unit_index, unit_id in enumerate(self.unit_ids):
waveforms = self.templates_array[unit_index, ...]
sparse_waveforms[unit_index, ...] = self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id)

return sparse_waveforms

def are_templates_sparse(self) -> bool:
return self.sparsity is not None

def _are_passed_templates_sparse(self) -> bool:
"""
Tests if the templates passed to the init constructor are sparse
"""
are_templates_sparse = True
for unit_index, unit_id in enumerate(self.unit_ids):
waveforms = self.templates_array[unit_index, ...]
are_templates_sparse = self.sparsity.are_waveforms_sparse(waveforms, unit_id=unit_id)
if not are_templates_sparse:
return False

return are_templates_sparse

def to_json(self):
return json.dumps(self.to_dict())

@classmethod
def from_json(cls, json_str):
return cls.from_dict(json.loads(json_str))

def __eq__(self, other):
"""
Necessary because dataclasses compare their fields for equality by default, which does not work
for numpy arrays. We therefore override __eq__ and compare each numpy array field
with np.array_equal instead.
"""
if not isinstance(other, Templates):
return False

# Convert the instances to tuples
self_tuple = astuple(self)
other_tuple = astuple(other)

# Compare each field
for s_field, o_field in zip(self_tuple, other_tuple):
if isinstance(s_field, np.ndarray):
if not np.array_equal(s_field, o_field):
return False

# Compare ChannelSparsity by its mask, unit_ids and channel_ids.
# Maybe ChannelSparsity should have its own __eq__ method
elif isinstance(s_field, ChannelSparsity):
if not isinstance(o_field, ChannelSparsity):
return False

# Compare ChannelSparsity by its mask, unit_ids and channel_ids
if not np.array_equal(s_field.mask, o_field.mask):
return False
if not np.array_equal(s_field.unit_ids, o_field.unit_ids):
return False
if not np.array_equal(s_field.channel_ids, o_field.channel_ids):
return False
else:
if s_field != o_field:
return False

return True
103 changes: 103 additions & 0 deletions src/spikeinterface/core/tests/test_template_class.py
@@ -0,0 +1,103 @@
import pytest
import numpy as np
import pickle
from spikeinterface.core.template import Templates
from spikeinterface.core.sparsity import ChannelSparsity


def generate_test_template(template_type):
num_units = 2
num_samples = 5
num_channels = 3
templates_shape = (num_units, num_samples, num_channels)
templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)

sampling_frequency = 30_000
nbefore = 2

if template_type == "dense":
return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore)
elif template_type == "sparse": # sparse with sparse templates
sparsity_mask = np.array([[True, False, True], [False, True, False]])
sparsity = ChannelSparsity(
mask=sparsity_mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels)
)

# Create sparse templates
sparse_templates_array = np.zeros(shape=(num_units, num_samples, sparsity.max_num_active_channels))
for unit_index in range(num_units):
template = templates_array[unit_index, ...]
sparse_template = sparsity.sparsify_waveforms(waveforms=template, unit_id=unit_index)
sparse_templates_array[unit_index, :, : sparse_template.shape[1]] = sparse_template

return Templates(
templates_array=sparse_templates_array,
sparsity_mask=sparsity_mask,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
)

elif template_type == "sparse_with_dense_templates": # sparse with dense templates
sparsity_mask = np.array([[True, False, True], [False, True, False]])

return Templates(
templates_array=templates_array,
sparsity_mask=sparsity_mask,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
)


@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_pickle_serialization(template_type, tmp_path):
template = generate_test_template(template_type)

# Dump to pickle
pkl_path = tmp_path / "templates.pkl"
with open(pkl_path, "wb") as f:
pickle.dump(template, f)

# Load from pickle
with open(pkl_path, "rb") as f:
template_reloaded = pickle.load(f)

assert template == template_reloaded


@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_json_serialization(template_type):
template = generate_test_template(template_type)

json_str = template.to_json()
template_reloaded_from_json = Templates.from_json(json_str)

assert template == template_reloaded_from_json


@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_get_dense_templates(template_type):
template = generate_test_template(template_type)
dense_templates = template.get_dense_templates()
assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels)


@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_get_sparse_templates(template_type):
template = generate_test_template(template_type)

if template_type == "dense":
with pytest.raises(ValueError):
sparse_templates = template.get_sparse_templates()
elif template_type == "sparse":
sparse_templates = template.get_sparse_templates()
assert sparse_templates.shape == (
template.num_units,
template.num_samples,
template.sparsity.max_num_active_channels,
)
assert template.are_templates_sparse()


def test_initialization_fail_with_dense_templates():
with pytest.raises(ValueError, match="Sparsity mask passed but the templates are not sparse"):
template = generate_test_template(template_type="sparse_with_dense_templates")
Comment on lines +84 to +86
Member: @h-mayorquin in this line, would it make sense to have a Templates.make_sparse(sparsity) method which returns a Templates object with sparse templates using the sparsity (not only the templates_array)?

Collaborator Author: Seems like overkill to me. The only thing that differs is the templates_array.

If you want to make it simple and unobtrusive for users, we should sparsify ourselves at __init__ instead of throwing an error.
If you want to keep it simple but still let users know, we should have a method to sparsify the template array and tell them to use that.

Member: Got it. So the current way to do that would be:

templates_dense_array = ...
sparsity = ...

templates_sparse_array = sparsity.sparsify_waveforms(templates_dense_array)

templates_sparse = Templates(templates_sparse_array, sparsity.mask)

Indeed, looks simple enough :)

Collaborator Author: This method is now part of the sparsity class.
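Putting the pieces from this diff together, a minimal end-to-end sketch of that workflow, mirroring the test helper above and using only methods visible in this PR (shapes and values are illustrative):

import numpy as np
from spikeinterface.core.template import Templates
from spikeinterface.core.sparsity import ChannelSparsity

num_units, num_samples, num_channels = 2, 5, 3
dense_templates = np.arange(num_units * num_samples * num_channels, dtype="float64").reshape(
    num_units, num_samples, num_channels
)

sparsity_mask = np.array([[True, False, True], [False, True, False]])
sparsity = ChannelSparsity(
    mask=sparsity_mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels)
)

# Sparsify unit by unit and zero-pad each unit up to the widest unit
sparse_array = np.zeros((num_units, num_samples, sparsity.max_num_active_channels))
for unit_index, unit_id in enumerate(np.arange(num_units)):
    sparse_waveforms = sparsity.sparsify_waveforms(
        waveforms=dense_templates[unit_index], unit_id=unit_id
    )
    sparse_array[unit_index, :, : sparse_waveforms.shape[1]] = sparse_waveforms

templates_sparse = Templates(
    templates_array=sparse_array,
    sparsity_mask=sparsity_mask,
    sampling_frequency=30_000.0,
    nbefore=2,
)
assert templates_sparse.are_templates_sparse()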