diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 54f6e0e903..ec3a3e91a9 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -15,27 +15,27 @@ class WhitenRecording(BasePreprocessor): ---------- recording: RecordingExtractor The recording extractor to be whitened. - dtype: None or dtype + dtype: None or dtype, default: None If None the the parent dtype is kept. For integer dtype a int_scale must be also given. - mode: 'global' / 'local' + mode: 'global' / 'local', default: 'global' 'global' use the entire covariance matrix to compute the W matrix 'local' use local covariance (by radius) to compute the W matrix - radius_um: None or float + radius_um: None or float, default: None Used for mode = 'local' to get the neighborhood - apply_mean: bool + apply_mean: bool, default: False Substract or not the mean matrix M before the dot product with W. - int_scale : None or float + int_scale : None or float, default: None Apply a scaling factor to fit the integer range. This is used when the dtype is an integer, so that the output is scaled. For example, a value of `int_scale=200` will scale the traces value to a standard deviation of 200. - eps : float, default 1e-8 + eps : float or None, default: None Small epsilon to regularize SVD. - If None, eps is estimated from the data. If the data is float type and scaled down to very small values, - then the eps is automatically set to a small fraction of the median of the squared data. - W : 2d np.array - Pre-computed whitening matrix, by default None - M : 1d np.array or None + If None, eps is default to 1e-8. If the data is float type and scaled down to very small values, + then the eps is automatically set to a small fraction (1e-3) of the median of the squared data. + W : 2d np.array, default: None + Pre-computed whitening matrix + M : 1d np.array or None, default: None Pre-computed means. M can be None when previously computed with apply_mean=False **random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function @@ -56,7 +56,7 @@ def __init__( mode="global", radius_um=100.0, int_scale=None, - eps=1e-8, + eps=None, W=None, M=None, **random_chunk_kwargs, @@ -127,7 +127,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): whiten = define_function_from_class(source_class=WhitenRecording, name="whiten") -def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=1e-8): +def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None): """ Compute whitening matrix @@ -145,10 +145,11 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r Keyword arguments for get_random_data_chunks() apply_mean : bool If True, the mean is removed prior to computing the covariance - radius_um : float, optional - Used for mode = 'local' to get the neighborhood, by default None - eps : float, optional - Small epsilon to regularize SVD, by default 1e-8 + radius_um : float, default: None + Used for mode = 'local' to get the neighborhood + eps : float, default: None + Small epsilon to regularize SVD. If None, the default is set to 1e-8, but if the data is float type and scaled + down to very small values, eps is automatically set to a small fraction (1e-3) of the median of the squared data. Returns ------- @@ -180,7 +181,9 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r # whitening. We therefore check to see if the data is float # type and we estimate a more reasonable eps in the case # where the data is on a scale less than 1. - if data.dtype.kind == "f" or eps is None: + if eps is None: + eps = 1e-8 + if data.dtype.kind == "f": median_data_sqr = np.median(data**2) # use the square because cov (and hence S) scales as the square if median_data_sqr < 1 and median_data_sqr > 0: eps = max(1e-16, median_data_sqr * 1e-3) # use a small fraction of the median of the squared data