-
Notifications
You must be signed in to change notification settings - Fork 191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Regularize whitening #2744
Regularize whitening #2744
Changes from 26 commits
9bda991
dc52a09
4edbfcb
62c6f82
ae7835c
8028ef8
e726cdd
e627fee
54c5723
627f280
62bc321
23d221d
d66882e
a3474d0
a329a80
5b664b6
842aaca
d693d13
d83d0c1
b2e28f7
3b8eabc
853886c
5ccd5dd
83d2087
6e6e033
a0b96e8
8231a0c
8c1bddc
39ad18b
13e0634
35b2299
68dcaa5
81fa6bb
787a89c
343d5d2
4de2bdc
b4b76e7
ed61efd
34615cd
d8706ba
dbbe2ff
be7683f
bad4152
a837d5a
3d90bba
88af862
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,6 +49,14 @@ def test_whiten(): | |
assert rec4.get_dtype() == "int16" | ||
assert rec4._kwargs["M"] is None | ||
|
||
# test regularization | ||
with pytest.raises(AssertionError): | ||
W, M = compute_whitening_matrix( | ||
rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, regularize=True | ||
) | ||
# W must be sparse | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. which W are you testing here ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm just computing W with the default regularization method |
||
np.sum(W == 0) == 6 | ||
|
||
|
||
if __name__ == "__main__": | ||
test_whiten() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
|
||
from ..core import get_random_data_chunks, get_channel_distances | ||
from .filter import fix_dtype | ||
from ..core.globals import get_global_job_kwargs | ||
|
||
|
||
class WhitenRecording(BasePreprocessor): | ||
|
@@ -40,6 +41,12 @@ class WhitenRecording(BasePreprocessor): | |
M : 1d np.array or None, default: None | ||
Pre-computed means. | ||
M can be None when previously computed with apply_mean=False | ||
regularize: bool, default False | ||
Boolean to decide if we want to regularize the covariance matrix, using the GraphicalLassoCV() method | ||
of sklearn | ||
regularize_kwargs: None or dict | ||
Dictionary of the parameters that could be provided to the GraphicalLassoCV() method of sklearn, if | ||
the covariance matrix needs to be regularized | ||
yger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
**random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function | ||
|
||
Returns | ||
|
@@ -55,6 +62,8 @@ def __init__( | |
recording, | ||
dtype=None, | ||
apply_mean=False, | ||
regularize=False, | ||
regularize_kwargs={}, | ||
yger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
mode="global", | ||
radius_um=100.0, | ||
int_scale=None, | ||
|
@@ -75,7 +84,14 @@ def __init__( | |
M = np.asarray(M) | ||
else: | ||
W, M = compute_whitening_matrix( | ||
recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps | ||
recording, | ||
mode, | ||
random_chunk_kwargs, | ||
apply_mean, | ||
radius_um=radius_um, | ||
eps=eps, | ||
regularize=regularize, | ||
regularize_kwargs=regularize_kwargs, | ||
) | ||
|
||
BasePreprocessor.__init__(self, recording, dtype=dtype_) | ||
|
@@ -90,6 +106,8 @@ def __init__( | |
mode=mode, | ||
radius_um=radius_um, | ||
apply_mean=apply_mean, | ||
regularize=regularize, | ||
regularize_kwargs=regularize_kwargs, | ||
int_scale=float(int_scale) if int_scale is not None else None, | ||
M=M.tolist() if M is not None else None, | ||
W=W.tolist(), | ||
|
@@ -129,7 +147,9 @@ def get_traces(self, start_frame, end_frame, channel_indices): | |
whiten = define_function_from_class(source_class=WhitenRecording, name="whiten") | ||
|
||
|
||
def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None): | ||
def compute_whitening_matrix( | ||
recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False, regularize_kwargs=None | ||
): | ||
""" | ||
Compute whitening matrix | ||
|
||
|
@@ -152,7 +172,12 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r | |
eps : float or None, default: None | ||
Small epsilon to regularize SVD. If None, the default is set to 1e-8, but if the data is float type and scaled | ||
down to very small values, eps is automatically set to a small fraction (1e-3) of the median of the squared data. | ||
|
||
regularize: bool, default False | ||
Boolean to decide if we want to regularize the covariance matrix, using the GraphicalLassoCV() method | ||
of sklearn | ||
regularize_kwargs: None or dict | ||
Dictionary of the parameters that could be provided to the GraphicalLassoCV() method of sklearn, if | ||
the covariance matrix needs to be regularized | ||
yger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Returns | ||
------- | ||
W : 2D array | ||
|
@@ -162,7 +187,7 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r | |
|
||
""" | ||
random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) | ||
random_data = random_data.astype("float32") | ||
random_data = random_data | ||
yger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if apply_mean: | ||
M = np.mean(random_data, axis=0) | ||
|
@@ -172,8 +197,26 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r | |
M = None | ||
data = random_data | ||
|
||
cov = data.T @ data | ||
cov = cov / data.shape[0] | ||
if not regularize: | ||
cov = data.T @ data | ||
cov = cov / data.shape[0] | ||
else: | ||
import sklearn.covariance | ||
|
||
if regularize_kwargs is None: | ||
regularize_kwargs = {} | ||
regularize_kwargs["assume_centered"] = True | ||
job_kwargs = get_global_job_kwargs() | ||
if "n_jobs" in job_kwargs and "n_jobs" not in regularize_kwargs: | ||
n_jobs = job_kwargs["n_jobs"] | ||
if isinstance(n_jobs, float) and 0 < n_jobs <= 1: | ||
import os | ||
|
||
n_jobs = int(n_jobs * os.cpu_count()) | ||
yger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
regularize_kwargs["n_jobs"] = n_jobs | ||
estimator = sklearn.covariance.GraphicalLassoCV(**regularize_kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe we could have a regularize method that can use lasso but maybe others no ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could, but I would say let's start with one... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes but we could have already the regularized_method in the dict no ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've tried something... We can pass method: "GraphicalLassoCV" or something else, and the appropriate node is created |
||
estimator.fit(data) | ||
cov = estimator.covariance_ | ||
|
||
# Here we determine eps used below to avoid division by zero. | ||
# Typically we can assume that data is either unscaled integers or in units of | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not see in the code wich assert is raised for this.
We should have better error than assert in this no ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've used what was already done, open to suggestions