From 4a5ff585a45bfe34cdfaa07a99c816135bfefb10 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 15 Nov 2024 11:02:27 +0000 Subject: [PATCH] Rework to avoid introducing a 'update_traces' function. --- src/spikeinterface/core/numpyextractors.py | 18 ----- .../preprocessing/tests/test_whiten.py | 74 +++++++++---------- 2 files changed, 34 insertions(+), 58 deletions(-) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 406f1372b6..f4790817a8 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -38,7 +38,6 @@ class NumpyRecording(BaseRecording): """ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=None): - if isinstance(traces_list, list): all_elements_are_list = all(isinstance(e, list) for e in traces_list) if all_elements_are_list: @@ -104,23 +103,6 @@ def from_recording(source_recording, **job_kwargs): ) return recording - def update_traces(self, traces, segment_index=0): - """ - Set the `traces` on on the segment of index `segment_index`. - `traces` must be the same size (num_samples, num_channels) - and dtype as the recording. - """ - if traces.shape[0] != self.get_num_samples(segment_index=segment_index): - raise ValueError("The first dimension must be the same size as" "the number of samples.") - - if traces.shape[1] != self.get_num_channels(): - raise ValueError("The second dimension of the data be the same" "size as the number of channels.") - - if traces.dtype != self.dtype: - raise ValueError("The dtype of the data must match the recording dtype.") - - self._recording_segments[segment_index]._traces = traces - class NumpyRecordingSegment(BaseRecordingSegment): def __init__(self, traces, sampling_frequency, t_start): diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 0a594f97a3..7aa3bf6705 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -30,7 +30,7 @@ class TestWhiten: returned data is indeed white. """ - def get_float_test_data(self, num_segments, dtype, means=None): + def get_test_recording(self, num_segments, dtype, means=None): """ Generate a set of test data with known covariance matrix and mean. Test data is drawn from a multivariate Gaussian distribute with @@ -57,6 +57,17 @@ def get_float_test_data(self, num_segments, dtype, means=None): The `means` should be an array of length 3 (num samples) or `None`. If `None`, means will be zero. """ + sampling_frequency = 30000 + num_samples = int(10 * sampling_frequency) # 10 s recording + + means, cov_mat, data = self.get_test_data_with_known_distribution(num_samples, dtype, means) + + recording = NumpyRecording([data], sampling_frequency) + + return means, cov_mat, recording + + def get_test_data_with_known_distribution(self, num_samples, dtype, means=None): + """ """ num_channels = 3 if means is None: @@ -64,38 +75,19 @@ def get_float_test_data(self, num_segments, dtype, means=None): cov_mat = np.array([[1, 0.5, 0], [0.5, 1, -0.25], [0, -0.25, 1]]) - # Generate recording and multivariate Gaussian data to set - recording = self.get_empty_custom_recording(num_segments, num_channels, dtype) - - seg_1_data = np.random.multivariate_normal(means, cov_mat, recording.get_num_samples(segment_index=0)) + data = np.random.multivariate_normal(means, cov_mat, num_samples) # Set the dtype, if `int16`, first scale to +/- 1 then cast to int16 range. if dtype == np.float32: - seg_1_data = seg_1_data.astype(dtype) + data = data.astype(dtype) elif dtype == np.int16: - seg_1_data /= seg_1_data.max() - seg_1_data = np.round((seg_1_data) * 32767).astype(np.int16) + data /= data.max() + data = np.round((data) * 32767).astype(np.int16) else: raise ValueError("dtype must be float32 or int16") - # Set the data on the recording and return - recording.update_traces(seg_1_data) - assert np.array_equal(recording.get_traces(segment_index=0), seg_1_data), "segment 1 test setup did not work." - - return means, cov_mat, recording - - def get_empty_custom_recording(self, num_segments, num_channels, dtype): - - sampling_frequency = 30000 - num_samples = int(10 * sampling_frequency) - - traces_list = [np.zeros((num_samples, num_channels), dtype=dtype) for _ in range(num_segments)] - - return NumpyRecording( - traces_list, - sampling_frequency=30000, - ) + return means, cov_mat, data def cov_mat_from_whitening_mat(self, whitened_recording, eps): """ @@ -147,7 +139,7 @@ def test_compute_covariance_matrix(self, dtype): otherwise it can overflow. """ eps = 1e-16 - _, cov_mat, recording = self.get_float_test_data(num_segments=1, dtype=dtype) + _, cov_mat, recording = self.get_test_recording(num_segments=1, dtype=dtype) whitened_recording = whiten( recording, @@ -177,7 +169,7 @@ def test_non_default_eps(self): the cov mat if the correct eps is used. """ eps = 1 - _, cov_mat, recording = self.get_float_test_data(num_segments=1, dtype=np.float32) + _, cov_mat, recording = self.get_test_recording(num_segments=1, dtype=np.float32) whitened_recording = whiten( recording, @@ -200,17 +192,14 @@ def test_compute_covariance_matrix_2_segments(self): but the covariance matrix is scaled by 1 / N. """ eps = 1e-16 - _, cov_mat, recording = self.get_float_test_data(num_segments=2, dtype=np.float32) + sampling_frequency = 30000 + num_samples = 10 * sampling_frequency - all_zero_data = np.zeros( - (recording.get_num_samples(segment_index=0), recording.get_num_channels()), - dtype=np.float32, - ) + _, cov_mat, data = self.get_test_data_with_known_distribution(num_samples, np.float32) - recording.update_traces( - all_zero_data, - segment_index=1, - ) + traces_list = [data, np.zeros_like(data)] + + recording = NumpyRecording(traces_list, sampling_frequency) whitened_recording = whiten( recording, @@ -238,7 +227,7 @@ def test_apply_mean(self, apply_mean): means = np.array([10, 20, 30]) eps = 1e-16 - _, cov_mat, recording = self.get_float_test_data(num_segments=1, dtype=np.float32, means=means) + _, cov_mat, recording = self.get_test_recording(num_segments=1, dtype=np.float32, means=means) whitened_recording = whiten( recording, @@ -290,7 +279,7 @@ def test_whiten_regularisation_norm(self): whitening preprocessing is the same as the one computed from sklearn when regularise kwargs are given. """ - _, _, recording = self.get_float_test_data(num_segments=1, dtype=np.float32) + _, _, recording = self.get_test_recording(num_segments=1, dtype=np.float32) whitened_recording = whiten( recording, @@ -318,7 +307,7 @@ def test_local_vs_global_whiten(self): channels are considered for whitening. Test that whitening is correct for the first pair and last pair. """ - _, _, recording = self.get_float_test_data(num_segments=1, dtype=np.float32) + _, _, recording = self.get_test_recording(num_segments=1, dtype=np.float32) y_dist = 2 recording.set_channel_locations( @@ -369,7 +358,12 @@ def test_passed_W_and_M(self): be used for the actual whitening computation. """ num_chan = 4 - recording = self.get_empty_custom_recording(2, num_chan, dtype=np.float32) + num_samples = 10000 + + recording = NumpyRecording( + [np.zeros((num_samples, num_chan))] * 2, + sampling_frequency=30000, + ) test_W = np.random.normal(size=(num_chan, num_chan)) test_M = np.random.normal(size=num_chan)