Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust eps for whitening in case of very small magnitude data #2070

Merged
11 commits merged on Oct 26, 2023
6 changes: 3 additions & 3 deletions src/spikeinterface/preprocessing/tests/test_whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def test_whiten():

print(rec.get_channel_locations())
random_chunk_kwargs = {}
W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None, eps=1e-8)
W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None)
print(W)
print(M)

with pytest.raises(AssertionError):
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, eps=1e-8)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25, eps=1e-8)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25)
# W must be sparse
np.sum(W == 0) == 6

Expand Down
20 changes: 16 additions & 4 deletions src/spikeinterface/preprocessing/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def __init__(
if M is not None:
M = np.asarray(M)
else:
W, M = compute_whitening_matrix(
recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=1e-8
)
W, M = compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um)

BasePreprocessor.__init__(self, recording, dtype=dtype_)

Expand Down Expand Up @@ -122,7 +120,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
whiten = define_function_from_class(source_class=WhitenRecording, name="whiten")


def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=1e-8):
def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None):
"""
Compute whitening matrix

Expand Down Expand Up @@ -167,6 +165,20 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
cov = data.T @ data
cov = cov / data.shape[0]

# Here we determine eps used below to avoid division by zero.
    # Typically we can assume that data is in units of
    # mV, but this is not always the case. When data
    # is float type and scaled down to very small values, then the
# default eps=1e-8 can be too large, resulting in incorrect
# whitening. We therefore check to see if the data is float
# type and we estimate a more reasonable eps in the case
# where the data is on a scale less than 1.
eps = 1e-8 # the default
if data.dtype.kind == "f":
median_data_sqr = np.median(data**2) # use the square because cov (and hence S) scales as the square
if median_data_sqr < 1 and median_data_sqr > 0:
eps = max(1e-16, median_data_sqr * 1e-3) # use a small fraction of the median of the squared data

if mode == "global":
U, S, Ut = np.linalg.svd(cov, full_matrices=True)
W = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut
Expand Down