From 9703af1c94174d8f0159ed97b91b2c2d9fcd970a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 5 Jul 2024 14:30:23 +0200 Subject: [PATCH] Regularize whitening (#2744) Regularize whitening --- .../preprocessing/tests/test_whiten.py | 12 +++-- src/spikeinterface/preprocessing/whiten.py | 48 ++++++++++++++++--- .../sorters/internal/spyking_circus2.py | 2 +- 3 files changed, 51 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index c3d1544869..04b731de4f 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -8,13 +8,13 @@ def test_whiten(create_cache_folder): cache_folder = create_cache_folder - rec = generate_recording(num_channels=4) + rec = generate_recording(num_channels=4, seed=2205) print(rec.get_channel_locations()) random_chunk_kwargs = {} - W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) - print(W) - print(M) + W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) + # print(W) + # print(M) with pytest.raises(AssertionError): W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None) @@ -41,6 +41,10 @@ def test_whiten(create_cache_folder): assert rec4.get_dtype() == "int16" assert rec4._kwargs["M"] is None + # test regularization : norm should be smaller + W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True) + assert np.linalg.norm(W1) > np.linalg.norm(W2) + if __name__ == "__main__": test_whiten() diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 874d4304e3..96cf5e028f 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -7,6 +7,7 @@ from ..core import get_random_data_chunks, get_channel_distances from .filter import fix_dtype +from ..core.globals import get_global_job_kwargs class WhitenRecording(BasePreprocessor): @@ -40,6 +41,12 @@ class WhitenRecording(BasePreprocessor): M : 1d np.array or None, default: None Pre-computed means. M can be None when previously computed with apply_mean=False + regularize : bool, default: False + Boolean to decide if we want to regularize the covariance matrix, using a chosen method + of sklearn, specified in regularize_kwargs. Default is GraphicalLassoCV + regularize_kwargs : {'method' : 'GraphicalLassoCV'} + Dictionary of the parameters that could be provided to the method of sklearn, if + the covariance matrix needs to be regularized. **random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function Returns @@ -55,6 +62,8 @@ def __init__( recording, dtype=None, apply_mean=False, + regularize=False, + regularize_kwargs=None, mode="global", radius_um=100.0, int_scale=None, @@ -75,7 +84,14 @@ def __init__( M = np.asarray(M) else: W, M = compute_whitening_matrix( - recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps + recording, + mode, + random_chunk_kwargs, + apply_mean, + radius_um=radius_um, + eps=eps, + regularize=regularize, + regularize_kwargs=regularize_kwargs, ) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -90,6 +106,8 @@ def __init__( mode=mode, radius_um=radius_um, apply_mean=apply_mean, + regularize=regularize, + regularize_kwargs=regularize_kwargs, int_scale=float(int_scale) if int_scale is not None else None, M=M.tolist() if M is not None else None, W=W.tolist(), @@ -129,7 +147,9 @@ def get_traces(self, start_frame, end_frame, channel_indices): whiten = define_function_from_class(source_class=WhitenRecording, name="whiten") -def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None): +def compute_whitening_matrix( + recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False, regularize_kwargs=None +): """ Compute whitening matrix @@ -152,7 +172,12 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r eps : float or None, default: None Small epsilon to regularize SVD. If None, the default is set to 1e-8, but if the data is float type and scaled down to very small values, eps is automatically set to a small fraction (1e-3) of the median of the squared data. - + regularize : bool, default: False + Boolean to decide if we want to regularize the covariance matrix, using a chosen method + of sklearn, specified in regularize_kwargs. Default is GraphicalLassoCV + regularize_kwargs : {'method' : 'GraphicalLassoCV'} + Dictionary of the parameters that could be provided to the method of sklearn, if + the covariance matrix needs to be regularized. Returns ------- W : 2D array @@ -162,7 +187,8 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) - random_data = random_data.astype("float32") + + regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"} if apply_mean: M = np.mean(random_data, axis=0) @@ -172,8 +198,18 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r M = None data = random_data - cov = data.T @ data - cov = cov / data.shape[0] + if not regularize: + cov = data.T @ data + cov = cov / data.shape[0] + else: + import sklearn.covariance + + method = regularize_kwargs.pop("method") + regularize_kwargs["assume_centered"] = True + estimator_class = getattr(sklearn.covariance, method) + estimator = estimator_class(**regularize_kwargs) + estimator.fit(data) + cov = estimator.covariance_ # Here we determine eps used below to avoid division by zero. # Typically we can assume that data is either unscaled integers or in units of diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 45cc93d0b6..be75877f02 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -147,7 +147,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We need to whiten before the template matching step, to boost the results # TODO add , regularize=True chen ready - recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32") + recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True) noise_levels = get_noise_levels(recording_w, return_scaled=False)