Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust eps for whitening in case of very small magnitude data #2070

Merged
merged 11 commits into from
Oct 26, 2023
6 changes: 3 additions & 3 deletions src/spikeinterface/preprocessing/tests/test_whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def test_whiten():

print(rec.get_channel_locations())
random_chunk_kwargs = {}
W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None, eps=1e-8)
W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None)
print(W)
print(M)

with pytest.raises(AssertionError):
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, eps=1e-8)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25, eps=1e-8)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25)
# W must be sparse
np.sum(W == 0) == 6

Expand Down
20 changes: 19 additions & 1 deletion src/spikeinterface/preprocessing/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ class WhitenRecording(BasePreprocessor):
Apply a scaling factor to fit the integer range.
This is used when the dtype is an integer, so that the output is scaled.
For example, a value of `int_scale=200` will scale the traces value to a standard deviation of 200.
eps : float or None, default: 1e-8
Small epsilon to regularize SVD.
If None, eps is estimated from the data. If the data is float type and scaled down to very small values,
then eps is automatically set to a small fraction of the median of the squared data.
W : 2d np.array
Pre-computed whitening matrix, by default None
M : 1d np.array or None
Expand All @@ -52,6 +56,7 @@ def __init__(
mode="global",
radius_um=100.0,
int_scale=None,
eps=1e-8,
W=None,
M=None,
**random_chunk_kwargs,
Expand All @@ -68,7 +73,7 @@ def __init__(
M = np.asarray(M)
else:
W, M = compute_whitening_matrix(
recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=1e-8
recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps
)

BasePreprocessor.__init__(self, recording, dtype=dtype_)
Expand Down Expand Up @@ -167,6 +172,19 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
cov = data.T @ data
cov = cov / data.shape[0]

# Here we determine eps used below to avoid division by zero.
# Typically we can assume that data is either unscaled integers or in units of
# uV, but this is not always the case. When data
# is float type and scaled down to very small values, then the
# default eps=1e-8 can be too large, resulting in incorrect
# whitening. We therefore check to see if the data is float
# type and we estimate a more reasonable eps in the case
# where the data is on a scale less than 1.
if data.dtype.kind == "f" or eps is None:
median_data_sqr = np.median(data**2) # use the square because cov (and hence S) scales as the square
if median_data_sqr < 1 and median_data_sqr > 0:
eps = max(1e-16, median_data_sqr * 1e-3) # use a small fraction of the median of the squared data

if mode == "global":
U, S, Ut = np.linalg.svd(cov, full_matrices=True)
W = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut
Expand Down