Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Regularize whitening #2744

Merged
merged 46 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from 45 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
9bda991
WIP
yger Apr 10, 2024
dc52a09
WIP
yger Apr 10, 2024
4edbfcb
Merge branch 'SpikeInterface:main' into regularize_whitening
yger Apr 12, 2024
62c6f82
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 12, 2024
ae7835c
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 19, 2024
8028ef8
Merge branch 'SpikeInterface:main' into regularize_whitening
yger Apr 23, 2024
e726cdd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
e627fee
Merge branch 'SpikeInterface:main' into regularize_whitening
yger May 21, 2024
54c5723
Merge branch 'main' into regularize_whitening
yger May 23, 2024
627f280
Not centering the whitening
yger May 23, 2024
62bc321
typo
yger May 23, 2024
23d221d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
d66882e
typo
yger May 23, 2024
a3474d0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
a329a80
WIP
yger May 24, 2024
5b664b6
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger May 24, 2024
842aaca
Merge branch 'SpikeInterface:main' into regularize_whitening
yger May 24, 2024
d693d13
Merge branch 'SpikeInterface:main' into regularize_whitening
yger May 25, 2024
d83d0c1
Adding tests and documentation, support for parallelism while whitening
yger May 28, 2024
b2e28f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 28, 2024
3b8eabc
WIP
yger May 29, 2024
853886c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2024
5ccd5dd
Merge branch 'main' of github.com:yger/spikeinterface into regularize…
yger May 31, 2024
83d2087
Merge branch 'SpikeInterface:main' into regularize_whitening
yger May 31, 2024
6e6e033
Merge branch 'SpikeInterface:main' into regularize_whitening
yger Jun 2, 2024
a0b96e8
Merge branch 'main' into regularize_whitening
yger Jun 4, 2024
8231a0c
Update src/spikeinterface/preprocessing/whiten.py
yger Jun 5, 2024
8c1bddc
Update src/spikeinterface/preprocessing/whiten.py
yger Jun 5, 2024
39ad18b
Update src/spikeinterface/preprocessing/whiten.py
yger Jun 5, 2024
13e0634
Update src/spikeinterface/preprocessing/whiten.py
yger Jun 5, 2024
35b2299
Fixes for regularization
yger Jun 5, 2024
68dcaa5
Fixes
yger Jun 5, 2024
81fa6bb
Merge branch 'main' into regularize_whitening
yger Jun 5, 2024
787a89c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 5, 2024
343d5d2
WIP
yger Jun 5, 2024
4de2bdc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 5, 2024
b4b76e7
WIP
yger Jun 5, 2024
ed61efd
WIP
yger Jun 5, 2024
34615cd
Merge branch 'SpikeInterface:main' into regularize_whitening
yger Jun 7, 2024
d8706ba
Merge branch 'main' into regularize_whitening
yger Jun 12, 2024
dbbe2ff
Merge branch 'SpikeInterface:main' into regularize_whitening
yger Jun 19, 2024
be7683f
Merge branch 'main' into regularize_whitening
yger Jun 25, 2024
bad4152
Merge branch 'SpikeInterface:main' into regularize_whitening
yger Jun 29, 2024
a837d5a
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Jul 5, 2024
3d90bba
refine tests on whiten
samuelgarcia Jul 5, 2024
88af862
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/spikeinterface/preprocessing/tests/test_whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

def test_whiten(create_cache_folder):
cache_folder = create_cache_folder
rec = generate_recording(num_channels=4)
rec = generate_recording(num_channels=4, seed=2205)

print(rec.get_channel_locations())
random_chunk_kwargs = {}
W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None)
print(W)
print(M)
W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None)
# print(W)
# print(M)

with pytest.raises(AssertionError):
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None)
Expand All @@ -41,6 +41,12 @@ def test_whiten(create_cache_folder):
assert rec4.get_dtype() == "int16"
assert rec4._kwargs["M"] is None

# test regularization : norm should be smaller
W2, M = compute_whitening_matrix(
rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True
)
assert np.linalg.norm(W1) > np.linalg.norm(W2)


if __name__ == "__main__":
test_whiten()
48 changes: 42 additions & 6 deletions src/spikeinterface/preprocessing/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..core import get_random_data_chunks, get_channel_distances
from .filter import fix_dtype
from ..core.globals import get_global_job_kwargs


class WhitenRecording(BasePreprocessor):
Expand Down Expand Up @@ -40,6 +41,12 @@ class WhitenRecording(BasePreprocessor):
M : 1d np.array or None, default: None
Pre-computed means.
M can be None when previously computed with apply_mean=False
regularize : bool, default: False
Boolean to decide if we want to regularize the covariance matrix, using a chosen method
of sklearn, specified in regularize_kwargs. Default is GraphicalLassoCV
regularize_kwargs : {'method' : 'GraphicalLassoCV'}
Dictionary of the parameters that could be provided to the method of sklearn, if
the covariance matrix needs to be regularized.
**random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function

Returns
Expand All @@ -55,6 +62,8 @@ def __init__(
recording,
dtype=None,
apply_mean=False,
regularize=False,
regularize_kwargs=None,
mode="global",
radius_um=100.0,
int_scale=None,
Expand All @@ -75,7 +84,14 @@ def __init__(
M = np.asarray(M)
else:
W, M = compute_whitening_matrix(
recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps
recording,
mode,
random_chunk_kwargs,
apply_mean,
radius_um=radius_um,
eps=eps,
regularize=regularize,
regularize_kwargs=regularize_kwargs,
)

BasePreprocessor.__init__(self, recording, dtype=dtype_)
Expand All @@ -90,6 +106,8 @@ def __init__(
mode=mode,
radius_um=radius_um,
apply_mean=apply_mean,
regularize=regularize,
regularize_kwargs=regularize_kwargs,
int_scale=float(int_scale) if int_scale is not None else None,
M=M.tolist() if M is not None else None,
W=W.tolist(),
Expand Down Expand Up @@ -129,7 +147,9 @@ def get_traces(self, start_frame, end_frame, channel_indices):
whiten = define_function_from_class(source_class=WhitenRecording, name="whiten")


def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None):
def compute_whitening_matrix(
recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False, regularize_kwargs=None
):
"""
Compute whitening matrix

Expand All @@ -152,7 +172,12 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
eps : float or None, default: None
Small epsilon to regularize SVD. If None, the default is set to 1e-8, but if the data is float type and scaled
down to very small values, eps is automatically set to a small fraction (1e-3) of the median of the squared data.

regularize : bool, default: False
Boolean to decide if we want to regularize the covariance matrix, using a chosen method
of sklearn, specified in regularize_kwargs. Default is GraphicalLassoCV
regularize_kwargs : {'method' : 'GraphicalLassoCV'}
Dictionary of the parameters that could be provided to the method of sklearn, if
the covariance matrix needs to be regularized.
Returns
-------
W : 2D array
Expand All @@ -162,7 +187,8 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r

"""
random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs)
random_data = random_data.astype("float32")

regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"}

if apply_mean:
M = np.mean(random_data, axis=0)
Expand All @@ -172,8 +198,18 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
M = None
data = random_data

cov = data.T @ data
cov = cov / data.shape[0]
if not regularize:
cov = data.T @ data
cov = cov / data.shape[0]
else:
import sklearn.covariance

method = regularize_kwargs.pop("method")
regularize_kwargs["assume_centered"] = True
estimator_class = getattr(sklearn.covariance, method)
estimator = estimator_class(**regularize_kwargs)
estimator.fit(data)
cov = estimator.covariance_

# Here we determine eps used below to avoid division by zero.
# Typically we can assume that data is either unscaled integers or in units of
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

## We need to whiten before the template matching step, to boost the results
# TODO add , regularize=True chen ready
recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32")
recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True)

noise_levels = get_noise_levels(recording_w, return_scaled=False)

Expand Down
Loading