Skip to content

Commit

Permalink
Merge pull request #2070 from magland/adjust-eps-for-whitening
Browse files Browse the repository at this point in the history
Adjust eps for whitening in case of very small magnitude data
  • Loading branch information
alejoe91 authored Oct 26, 2023
2 parents 2956c27 + d93ba0f commit 67869c5
Showing 2 changed files with 38 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/spikeinterface/preprocessing/tests/test_whiten.py
Original file line number Diff line number Diff line change
@@ -20,13 +20,13 @@ def test_whiten():

print(rec.get_channel_locations())
random_chunk_kwargs = {}
W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None, eps=1e-8)
W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None)
print(W)
print(M)

with pytest.raises(AssertionError):
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, eps=1e-8)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25, eps=1e-8)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25)
# W must be sparse
np.sum(W == 0) == 6

49 changes: 35 additions & 14 deletions src/spikeinterface/preprocessing/whiten.py
Original file line number Diff line number Diff line change
@@ -15,23 +15,27 @@ class WhitenRecording(BasePreprocessor):
----------
recording: RecordingExtractor
The recording extractor to be whitened.
dtype: None or dtype
dtype: None or dtype, default: None
If None, the parent dtype is kept.
For an integer dtype, an int_scale must also be given.
mode: 'global' / 'local'
mode: 'global' / 'local', default: 'global'
'global' use the entire covariance matrix to compute the W matrix
'local' use local covariance (by radius) to compute the W matrix
radius_um: None or float
radius_um: None or float, default: None
Used for mode = 'local' to get the neighborhood
apply_mean: bool
apply_mean: bool, default: False
Whether to subtract the mean matrix M before the dot product with W.
int_scale : None or float
int_scale : None or float, default: None
Apply a scaling factor to fit the integer range.
This is used when the dtype is an integer, so that the output is scaled.
For example, a value of `int_scale=200` will scale the traces value to a standard deviation of 200.
W : 2d np.array
Pre-computed whitening matrix, by default None
M : 1d np.array or None
eps : float or None, default: None
Small epsilon to regularize SVD.
If None, eps defaults to 1e-8. However, if the data is float type and scaled down to very small values,
eps is automatically set to a small fraction (1e-3) of the median of the squared data (but no lower than 1e-16).
W : 2d np.array, default: None
Pre-computed whitening matrix
M : 1d np.array or None, default: None
Pre-computed means.
M can be None when previously computed with apply_mean=False
**random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function
@@ -52,6 +56,7 @@ def __init__(
mode="global",
radius_um=100.0,
int_scale=None,
eps=None,
W=None,
M=None,
**random_chunk_kwargs,
@@ -68,7 +73,7 @@ def __init__(
M = np.asarray(M)
else:
W, M = compute_whitening_matrix(
recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=1e-8
recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps
)

BasePreprocessor.__init__(self, recording, dtype=dtype_)
@@ -122,7 +127,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
whiten = define_function_from_class(source_class=WhitenRecording, name="whiten")


def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=1e-8):
def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None):
"""
Compute whitening matrix
@@ -140,10 +145,11 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
Keyword arguments for get_random_data_chunks()
apply_mean : bool
If True, the mean is removed prior to computing the covariance
radius_um : float, optional
Used for mode = 'local' to get the neighborhood, by default None
eps : float, optional
Small epsilon to regularize SVD, by default 1e-8
radius_um : float, default: None
Used for mode = 'local' to get the neighborhood
eps : float, default: None
Small epsilon to regularize SVD. If None, the default is set to 1e-8, but if the data is float type and scaled
down to very small values, eps is automatically set to a small fraction (1e-3) of the median of the squared data.
Returns
-------
@@ -167,6 +173,21 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
cov = data.T @ data
cov = cov / data.shape[0]

# Here we determine eps used below to avoid division by zero.
# Typically we can assume that data is either unscaled integers or in units of
# uV, but this is not always the case. When data
# is float type and scaled down to very small values, then the
# default eps=1e-8 can be too large, resulting in incorrect
# whitening. We therefore check to see if the data is float
# type and we estimate a more reasonable eps in the case
# where the data is on a scale less than 1.
if eps is None:
eps = 1e-8
if data.dtype.kind == "f":
median_data_sqr = np.median(data**2) # use the square because cov (and hence S) scales as the square
if median_data_sqr < 1 and median_data_sqr > 0:
eps = max(1e-16, median_data_sqr * 1e-3) # use a small fraction of the median of the squared data

if mode == "global":
U, S, Ut = np.linalg.svd(cov, full_matrices=True)
W = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut

0 comments on commit 67869c5

Please sign in to comment.