Skip to content

Commit

Permalink
alessio remark about nafter definition
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Sep 28, 2023
1 parent 13d7e9f commit 73e9562
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
3 changes: 1 addition & 2 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,8 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo

# If any channel is non-zero outside of the active channels, then the waveforms are not sparse
excess_zeros = waveforms[..., num_active_channels:].sum()
are_sparse = excess_zeros == 0

return are_sparse
return int(excess_zeros) == 0

@classmethod
def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids):
Expand Down
12 changes: 7 additions & 5 deletions src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def __post_init__(self):
self.num_channels = self.templates_array.shape[2]
else:
self.num_channels = self.sparsity_mask.shape[1]
self.nafter = self.num_samples - self.nbefore - 1

# Time and frames domain information
self.nafter = self.num_samples - self.nbefore
self.ms_before = self.nbefore / self.sampling_frequency * 1000
self.ms_after = self.nafter / self.sampling_frequency * 1000

Expand Down Expand Up @@ -110,8 +112,8 @@ def get_dense_templates(self) -> np.ndarray:
if self.sparsity is None:
return self.templates_array

dense_shape = (self.num_units, self.num_samples, self.num_channels)
dense_waveforms = np.zeros(dense=dense_shape, dtype=self.templates_array.dtype)
densified_shape = (self.num_units, self.num_samples, self.num_channels)
dense_waveforms = np.zeros(dense=densified_shape, dtype=self.templates_array.dtype)

for unit_index, unit_id in enumerate(self.unit_ids):
waveforms = self.templates_array[unit_index, ...]
Expand All @@ -125,8 +127,8 @@ def get_sparse_templates(self) -> np.ndarray:
raise ValueError("Can't return sparse templates without passing a sparsity mask")

max_num_active_channels = self.sparsity.max_num_active_channels
sparse_shape = (self.num_units, self.num_samples, max_num_active_channels)
sparse_waveforms = np.zeros(shape=sparse_shape, dtype=self.templates_array.dtype)
sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels)
sparse_waveforms = np.zeros(shape=sparisfied_shape, dtype=self.templates_array.dtype)
for unit_index, unit_id in enumerate(self.unit_ids):
waveforms = self.templates_array[unit_index, ...]
sparse_waveforms[unit_index, ...] = self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id)
Expand Down

0 comments on commit 73e9562

Please sign in to comment.