Skip to content

Commit

Permalink
Regularize whitening (#2744)
Browse files Browse the repository at this point in the history
Regularize whitening
  • Loading branch information
yger authored Jul 5, 2024
1 parent 18c1674 commit 9703af1
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 11 deletions.
12 changes: 8 additions & 4 deletions src/spikeinterface/preprocessing/tests/test_whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

def test_whiten(create_cache_folder):
cache_folder = create_cache_folder
rec = generate_recording(num_channels=4)
rec = generate_recording(num_channels=4, seed=2205)

print(rec.get_channel_locations())
random_chunk_kwargs = {}
W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None)
print(W)
print(M)
W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None)
# print(W)
# print(M)

with pytest.raises(AssertionError):
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None)
Expand All @@ -41,6 +41,10 @@ def test_whiten(create_cache_folder):
assert rec4.get_dtype() == "int16"
assert rec4._kwargs["M"] is None

# test regularization : norm should be smaller
W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True)
assert np.linalg.norm(W1) > np.linalg.norm(W2)


if __name__ == "__main__":
test_whiten()
48 changes: 42 additions & 6 deletions src/spikeinterface/preprocessing/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..core import get_random_data_chunks, get_channel_distances
from .filter import fix_dtype
from ..core.globals import get_global_job_kwargs


class WhitenRecording(BasePreprocessor):
Expand Down Expand Up @@ -40,6 +41,12 @@ class WhitenRecording(BasePreprocessor):
M : 1d np.array or None, default: None
Pre-computed means.
M can be None when previously computed with apply_mean=False
regularize : bool, default: False
If True, the covariance matrix is estimated with a regularized estimator from
sklearn.covariance, selected through regularize_kwargs, instead of the plain
empirical covariance. The default estimator is GraphicalLassoCV.
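regularize_kwargs : dict or None, default: None
Only used when regularize=True. The "method" key gives the name of the
sklearn.covariance estimator class to use; all other entries are passed as keyword
arguments to that estimator. If None, {"method": "GraphicalLassoCV"} is used.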
**random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function
Returns
Expand All @@ -55,6 +62,8 @@ def __init__(
recording,
dtype=None,
apply_mean=False,
regularize=False,
regularize_kwargs=None,
mode="global",
radius_um=100.0,
int_scale=None,
Expand All @@ -75,7 +84,14 @@ def __init__(
M = np.asarray(M)
else:
W, M = compute_whitening_matrix(
recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps
recording,
mode,
random_chunk_kwargs,
apply_mean,
radius_um=radius_um,
eps=eps,
regularize=regularize,
regularize_kwargs=regularize_kwargs,
)

BasePreprocessor.__init__(self, recording, dtype=dtype_)
Expand All @@ -90,6 +106,8 @@ def __init__(
mode=mode,
radius_um=radius_um,
apply_mean=apply_mean,
regularize=regularize,
regularize_kwargs=regularize_kwargs,
int_scale=float(int_scale) if int_scale is not None else None,
M=M.tolist() if M is not None else None,
W=W.tolist(),
Expand Down Expand Up @@ -129,7 +147,9 @@ def get_traces(self, start_frame, end_frame, channel_indices):
whiten = define_function_from_class(source_class=WhitenRecording, name="whiten")


def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None):
def compute_whitening_matrix(
recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False, regularize_kwargs=None
):
"""
Compute whitening matrix
Expand All @@ -152,7 +172,12 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
eps : float or None, default: None
Small epsilon to regularize SVD. If None, the default is set to 1e-8, but if the data is float type and scaled
down to very small values, eps is automatically set to a small fraction (1e-3) of the median of the squared data.
regularize : bool, default: False
If True, the covariance matrix is estimated with a regularized estimator from
sklearn.covariance, selected through regularize_kwargs, instead of the plain
empirical covariance. The default estimator is GraphicalLassoCV.
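regularize_kwargs : dict or None, default: None
Only used when regularize=True. The "method" key gives the name of the
sklearn.covariance estimator class to use; all other entries are passed as keyword
arguments to that estimator, with assume_centered always set to True.
If None, {"method": "GraphicalLassoCV"} is used.
Note that a dictionary passed here is modified in place ("method" is removed and
"assume_centered" is added), so pass a copy if it needs to be reused.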
Returns
-------
W : 2D array
Expand All @@ -162,7 +187,8 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
"""
random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs)
random_data = random_data.astype("float32")

regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"}

if apply_mean:
M = np.mean(random_data, axis=0)
Expand All @@ -172,8 +198,18 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
M = None
data = random_data

cov = data.T @ data
cov = cov / data.shape[0]
if not regularize:
cov = data.T @ data
cov = cov / data.shape[0]
else:
import sklearn.covariance

method = regularize_kwargs.pop("method")
regularize_kwargs["assume_centered"] = True
estimator_class = getattr(sklearn.covariance, method)
estimator = estimator_class(**regularize_kwargs)
estimator.fit(data)
cov = estimator.covariance_

# Here we determine eps used below to avoid division by zero.
# Typically we can assume that data is either unscaled integers or in units of
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

## We need to whiten before the template matching step, to boost the results
# TODO add , regularize=True when ready
recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32")
recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True)

noise_levels = get_noise_levels(recording_w, return_scaled=False)

Expand Down

0 comments on commit 9703af1

Please sign in to comment.