diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index cb2346ba68..c8eece2623 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -68,7 +68,7 @@ def __init__( M = np.asarray(M) else: W, M = compute_whitening_matrix( - recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=1e-8 + recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um ) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -122,7 +122,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): whiten = define_function_from_class(source_class=WhitenRecording, name="whiten") -def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=1e-8): +def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None): """ Compute whitening matrix @@ -167,6 +167,20 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r cov = data.T @ data cov = cov / data.shape[0] + # Here we determine eps used below to avoid division by zero. + # Typically we can assume that data is in units of + # microvolts, but this is not always the case. When data + # is float type and scaled down to very small values, then the + # default eps=1e-6 can be too large, resulting in incorrect + # whitening. We therefore check to see if the data is float + # type and we estimate a more reasonable eps in the case + # where the data is on a scale less than 1. + eps = 1e-6 # the default + if data.dtype.kind == "f": + median_data_sqr = np.median(data ** 2) # use the square because cov (and hence S) scales as the square + if median_data_sqr < 1 and median_data_sqr > 0: + eps = max(1e-16, median_data_sqr * 1e-3) # use a small fraction of the median of the squared data + if mode == "global": U, S, Ut = np.linalg.svd(cov, full_matrices=True) W = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut