Skip to content

Commit

Permalink
added checks
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Sep 20, 2023
1 parent 3f1a043 commit 7085e77
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
21 changes: 19 additions & 2 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def __init__(self, mask, unit_ids, channel_ids):
self.max_channel_representation = self.mask.sum(axis=1).max()

def __repr__(self):
sparsity = 1 - np.mean(self.mask)
txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - sparsity, P(x=0): {sparsity:0.2f}"
density = np.mean(self.mask)
txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}"
return txt

@property
Expand Down Expand Up @@ -147,6 +147,7 @@ def sparsify_waveforms(self, waveforms: np.ndarray, unit_index: int) -> np.ndarr
Sparse waveforms with shape (num_units, num_samples, num_active_channels).
"""

assert self.are_waveforms_dense(waveforms=waveforms), "Waveforms must be dense to sparsify them."
unit_id = self.unit_ids[unit_index]
non_zero_indices = self.unit_id_to_channel_indices[unit_id]
num_sparse_channels = len(non_zero_indices)
Expand Down Expand Up @@ -176,14 +177,30 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_index: int) -> np.ndarra
The densified waveforms array of shape (num_units, num_samples, num_channels).
"""

unit_id = self.unit_ids[unit_index]
non_zero_indices = self.unit_id_to_channel_indices[unit_id]

assert_msg = (
"Waveforms must be sparse in this index to densify them. The sparsity for this unit index is "
f"{len(non_zero_indices)} but the waveform has sparsity (last dimension) of {waveforms.shape[-1]}."
)
assert self.are_waveforms_sparse(waveforms=waveforms, unit_index=unit_index), assert_msg

densified_shape = waveforms.shape[:-1] + (self.num_channels,)
densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype)
densified_waveforms[..., non_zero_indices] = waveforms[...]

return densified_waveforms

def are_waveforms_dense(self, waveforms: np.ndarray) -> bool:
return waveforms.shape[-1] == self.num_channels

def are_waveforms_sparse(self, waveforms: np.ndarray, unit_index: int) -> bool:
unit_id = self.unit_ids[unit_index]
non_zero_indices = self.unit_id_to_channel_indices[unit_id]
return waveforms.shape[-1] == len(non_zero_indices)

@classmethod
def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids):
"""
Expand Down
8 changes: 7 additions & 1 deletion src/spikeinterface/core/tests/test_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,15 @@ def test_sparsify_waveforms():
for unit_index in range(num_units):
waveforms_dense = rng.random(size=(num_units, num_samples, num_channels))

# Test are_waveforms_dense
assert sparsity.are_waveforms_dense(waveforms_dense)

# Test sparsify
waveforms_sparse = sparsity.sparsify_waveforms(waveforms_dense, unit_index=unit_index)
num_sparse_channels = sparsity.mask[unit_index, :].sum()
assert waveforms_sparse.shape == (num_units, num_samples, num_sparse_channels)

# Test round-trip (note this is loosy)
# Test round-trip (note that this is loosy)
unit_id = unit_ids[unit_index]
non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id]
waveforms_dense2 = sparsity.densify_waveforms(waveforms_sparse, unit_index=unit_index)
Expand Down Expand Up @@ -119,6 +122,9 @@ def test_densify_waveforms():
num_sparse_channels = len(non_zero_indices)
waveforms_sparse = rng.random(size=(num_units, num_samples, num_sparse_channels))

# Test are waveforms sparse
assert sparsity.are_waveforms_sparse(waveforms_sparse, unit_index=unit_index)

# Test densify
waveforms_dense = sparsity.densify_waveforms(waveforms_sparse, unit_index=unit_index)
assert waveforms_dense.shape == (num_units, num_samples, num_channels)
Expand Down

0 comments on commit 7085e77

Please sign in to comment.