Skip to content

Commit

Permalink
adjust eps for whitening in case of very small magnitude data
Browse files Browse the repository at this point in the history
  • Loading branch information
magland committed Oct 4, 2023
1 parent ca48f6b commit 6c57b5c
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/spikeinterface/preprocessing/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
M = np.asarray(M)
else:
W, M = compute_whitening_matrix(
recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=1e-8
recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um
)

BasePreprocessor.__init__(self, recording, dtype=dtype_)
Expand Down Expand Up @@ -122,7 +122,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
whiten = define_function_from_class(source_class=WhitenRecording, name="whiten")


def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=1e-8):
def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None):
"""
Compute whitening matrix
Expand Down Expand Up @@ -167,6 +167,20 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
cov = data.T @ data
cov = cov / data.shape[0]

# Determine the eps used below to avoid division by zero.
# Data is typically in units of microvolts, but this is not
# always the case. When the data is float type and has been
# scaled down to very small values, the default eps=1e-6 can be
# too large, resulting in incorrect whitening. We therefore check
# whether the data is float type and, when the median of the
# squared data is positive and less than 1, estimate a more
# reasonable eps from that median.
eps = 1e-6 # the default
if data.dtype.kind == "f":
median_data_sqr = np.median(data ** 2) # use the square because cov (and hence S) scales as the square
if median_data_sqr < 1 and median_data_sqr > 0:
eps = max(1e-16, median_data_sqr * 1e-3) # use a small fraction of the median of the squared data

if mode == "global":
U, S, Ut = np.linalg.svd(cov, full_matrices=True)
W = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut
Expand Down

0 comments on commit 6c57b5c

Please sign in to comment.