diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 0848c1a176..40674a08f4 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -20,13 +20,13 @@ def test_whiten(): print(rec.get_channel_locations()) random_chunk_kwargs = {} - W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None, eps=1e-8) + W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) print(W) print(M) with pytest.raises(AssertionError): - W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, eps=1e-8) - W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25, eps=1e-8) + W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None) + W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25) # W must be sparse np.sum(W == 0) == 6 diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index cb2346ba68..ac80f58182 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -15,23 +15,27 @@ class WhitenRecording(BasePreprocessor): ---------- recording: RecordingExtractor The recording extractor to be whitened. - dtype: None or dtype + dtype: None or dtype, default: None If None the the parent dtype is kept. For integer dtype a int_scale must be also given. 
- mode: 'global' / 'local' + mode: 'global' / 'local', default: 'global' 'global' use the entire covariance matrix to compute the W matrix 'local' use local covariance (by radius) to compute the W matrix - radius_um: None or float + radius_um: None or float, default: None Used for mode = 'local' to get the neighborhood - apply_mean: bool + apply_mean: bool, default: False Substract or not the mean matrix M before the dot product with W. - int_scale : None or float + int_scale : None or float, default: None Apply a scaling factor to fit the integer range. This is used when the dtype is an integer, so that the output is scaled. For example, a value of `int_scale=200` will scale the traces value to a standard deviation of 200. - W : 2d np.array - Pre-computed whitening matrix, by default None - M : 1d np.array or None + eps : float or None, default: None + Small epsilon to regularize SVD. + If None, eps defaults to 1e-8, unless the data is float type and scaled down to very small values, + in which case eps is automatically set to a small fraction (1e-3) of the median of the squared data. + W : 2d np.array, default: None + Pre-computed whitening matrix + M : 1d np.array or None, default: None Pre-computed means. 
M can be None when previously computed with apply_mean=False **random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function @@ -52,6 +56,7 @@ def __init__( mode="global", radius_um=100.0, int_scale=None, + eps=None, W=None, M=None, **random_chunk_kwargs, @@ -68,7 +73,7 @@ def __init__( M = np.asarray(M) else: W, M = compute_whitening_matrix( - recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=1e-8 + recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps ) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -122,7 +127,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): whiten = define_function_from_class(source_class=WhitenRecording, name="whiten") -def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=1e-8): +def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None): """ Compute whitening matrix @@ -140,10 +145,11 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r Keyword arguments for get_random_data_chunks() apply_mean : bool If True, the mean is removed prior to computing the covariance - radius_um : float, optional - Used for mode = 'local' to get the neighborhood, by default None - eps : float, optional - Small epsilon to regularize SVD, by default 1e-8 + radius_um : float, default: None + Used for mode = 'local' to get the neighborhood + eps : float, default: None + Small epsilon to regularize SVD. If None, the default is set to 1e-8, but if the data is float type and scaled + down to very small values, eps is automatically set to a small fraction (1e-3) of the median of the squared data. Returns ------- @@ -167,6 +173,21 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r cov = data.T @ data cov = cov / data.shape[0] + # Here we determine eps used below to avoid division by zero. 
+ # Typically we can assume that data is either unscaled integers or in units of + # uV, but this is not always the case. When data + # is float type and scaled down to very small values, then the + # default eps=1e-8 can be too large, resulting in incorrect + # whitening. We therefore check to see if the data is float + # type and we estimate a more reasonable eps in the case + # where the data is on a scale less than 1. + if eps is None: + eps = 1e-8 + if data.dtype.kind == "f": + median_data_sqr = np.median(data**2) # use the square because cov (and hence S) scales as the square + if median_data_sqr < 1 and median_data_sqr > 0: + eps = max(1e-16, median_data_sqr * 1e-3) # use a small fraction of the median of the squared data + if mode == "global": U, S, Ut = np.linalg.svd(cov, full_matrices=True) W = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut