From 7085e779b0b4d3f86cf444d06a08a0b4ecaa0d56 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 13:23:32 +0200 Subject: [PATCH] added checks --- src/spikeinterface/core/sparsity.py | 21 +++++++++++++++++-- .../core/tests/test_sparsity.py | 8 ++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 834b9bf8a8..3111132674 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -103,8 +103,8 @@ def __init__(self, mask, unit_ids, channel_ids): self.max_channel_representation = self.mask.sum(axis=1).max() def __repr__(self): - sparsity = 1 - np.mean(self.mask) - txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - sparsity, P(x=0): {sparsity:0.2f}" + density = np.mean(self.mask) + txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}" return txt @property @@ -147,6 +147,7 @@ def sparsify_waveforms(self, waveforms: np.ndarray, unit_index: int) -> np.ndarr Sparse waveforms with shape (num_units, num_samples, num_active_channels). """ + assert self.are_waveforms_dense(waveforms=waveforms), "Waveforms must be dense to sparsify them." unit_id = self.unit_ids[unit_index] non_zero_indices = self.unit_id_to_channel_indices[unit_id] num_sparse_channels = len(non_zero_indices) @@ -176,14 +177,30 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_index: int) -> np.ndarra The densified waveforms array of shape (num_units, num_samples, num_channels). """ + unit_id = self.unit_ids[unit_index] non_zero_indices = self.unit_id_to_channel_indices[unit_id] + + assert_msg = ( + "Waveforms must be sparse in this index to densify them. The sparsity for this unit index is " + f"{len(non_zero_indices)} but the waveform has sparsity (last dimension) of {waveforms.shape[-1]}." + ) + assert self.are_waveforms_sparse(waveforms=waveforms, unit_index=unit_index), assert_msg + densified_shape = waveforms.shape[:-1] + (self.num_channels,) densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype) densified_waveforms[..., non_zero_indices] = waveforms[...] return densified_waveforms + def are_waveforms_dense(self, waveforms: np.ndarray) -> bool: + return waveforms.shape[-1] == self.num_channels + + def are_waveforms_sparse(self, waveforms: np.ndarray, unit_index: int) -> bool: + unit_id = self.unit_ids[unit_index] + non_zero_indices = self.unit_id_to_channel_indices[unit_id] + return waveforms.shape[-1] == len(non_zero_indices) + @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): """ diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 8bc5574cbf..c3a1c378f7 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -75,12 +75,15 @@ def test_sparsify_waveforms(): for unit_index in range(num_units): waveforms_dense = rng.random(size=(num_units, num_samples, num_channels)) + # Test are_waveforms_dense + assert sparsity.are_waveforms_dense(waveforms_dense) + # Test sparsify waveforms_sparse = sparsity.sparsify_waveforms(waveforms_dense, unit_index=unit_index) num_sparse_channels = sparsity.mask[unit_index, :].sum() assert waveforms_sparse.shape == (num_units, num_samples, num_sparse_channels) - # Test round-trip (note this is loosy) + # Test round-trip (note that this is loosy) unit_id = unit_ids[unit_index] non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] waveforms_dense2 = sparsity.densify_waveforms(waveforms_sparse, unit_index=unit_index) @@ -119,6 +122,9 @@ def test_densify_waveforms(): num_sparse_channels = len(non_zero_indices) waveforms_sparse = rng.random(size=(num_units, num_samples, num_sparse_channels)) + # Test are waveforms sparse + assert sparsity.are_waveforms_sparse(waveforms_sparse, unit_index=unit_index) + # Test densify waveforms_dense = sparsity.densify_waveforms(waveforms_sparse, unit_index=unit_index) assert waveforms_dense.shape == (num_units, num_samples, num_channels)