Skip to content

Commit

Permalink
adding some typing (#2031)
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin authored Sep 21, 2023
1 parent 419e3cd commit df0504c
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np

from .recording_tools import get_channel_distances, get_noise_levels
Expand Down Expand Up @@ -125,7 +127,7 @@ def unit_id_to_channel_indices(self):
self._unit_id_to_channel_indices[unit_id] = channel_inds
return self._unit_id_to_channel_indices

def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray:
def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray:
"""
Sparsify the waveforms according to a unit_id corresponding sparsity.
Expand Down Expand Up @@ -159,7 +161,7 @@ def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray:

return sparsified_waveforms

def densify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray:
def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray:
"""
Densify sparse waveforms that were sparisified according to a unit's channel sparsity.
Expand Down Expand Up @@ -199,7 +201,7 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray:
def are_waveforms_dense(self, waveforms: np.ndarray) -> bool:
return waveforms.shape[-1] == self.num_channels

def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str) -> bool:
def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> bool:
non_zero_indices = self.unit_id_to_channel_indices[unit_id]
num_active_channels = len(non_zero_indices)
return waveforms.shape[-1] == num_active_channels
Expand Down

0 comments on commit df0504c

Please sign in to comment.