Skip to content

Commit

Permalink
refine tests on whiten
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Jul 5, 2024
1 parent a837d5a commit 3d90bba
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
20 changes: 9 additions & 11 deletions src/spikeinterface/preprocessing/tests/test_whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

def test_whiten(create_cache_folder):
cache_folder = create_cache_folder
rec = generate_recording(num_channels=4)
rec = generate_recording(num_channels=4, seed=2205)

print(rec.get_channel_locations())
random_chunk_kwargs = {}
W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None)
print(W)
print(M)
W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None)
# print(W)
# print(M)

with pytest.raises(AssertionError):
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None)
Expand All @@ -41,13 +41,11 @@ def test_whiten(create_cache_folder):
assert rec4.get_dtype() == "int16"
assert rec4._kwargs["M"] is None

# test regularization
with pytest.raises(AssertionError):
W, M = compute_whitening_matrix(
rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, regularize=True
)
# W must be sparse
np.sum(W == 0) == 6
# test regularization : norm should be smaller
W2, M = compute_whitening_matrix(
rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True
)
assert np.linalg.norm(W1) > np.linalg.norm(W2)


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/preprocessing/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,12 @@ def compute_whitening_matrix(
cov = data.T @ data
cov = cov / data.shape[0]
else:
import sklearn.covariance as cov
import sklearn.covariance

method = regularize_kwargs.pop("method")
regularize_kwargs["assume_centered"] = True
estimator = eval(f"cov.{method}")(**regularize_kwargs)
estimator_class = getattr(sklearn.covariance, method)
estimator = estimator_class(**regularize_kwargs)
estimator.fit(data)
cov = estimator.covariance_

Expand Down

0 comments on commit 3d90bba

Please sign in to comment.