Skip to content

Commit

Permalink
Rework to avoid introducing an 'update_traces' function.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Nov 15, 2024
1 parent e7ac452 commit 4a5ff58
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 58 deletions.
18 changes: 0 additions & 18 deletions src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ class NumpyRecording(BaseRecording):
"""

def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=None):

if isinstance(traces_list, list):
all_elements_are_list = all(isinstance(e, list) for e in traces_list)
if all_elements_are_list:
Expand Down Expand Up @@ -104,23 +103,6 @@ def from_recording(source_recording, **job_kwargs):
)
return recording

def update_traces(self, traces, segment_index=0):
    """
    Set the `traces` on the segment of index `segment_index`.

    The array is stored by reference on the segment (no copy is made).

    Parameters
    ----------
    traces : np.ndarray
        Replacement data. Must have the same shape
        (num_samples, num_channels) and dtype as the recording.
    segment_index : int, default: 0
        Index of the segment whose traces are replaced.

    Raises
    ------
    ValueError
        If the number of samples, the number of channels or the dtype
        of `traces` does not match the recording.
    """
    if traces.shape[0] != self.get_num_samples(segment_index=segment_index):
        raise ValueError("The first dimension must be the same size as the number of samples.")

    if traces.shape[1] != self.get_num_channels():
        raise ValueError("The second dimension of the data must be the same size as the number of channels.")

    if traces.dtype != self.dtype:
        raise ValueError("The dtype of the data must match the recording dtype.")

    # All checks passed: swap the array held by the target segment.
    self._recording_segments[segment_index]._traces = traces


class NumpyRecordingSegment(BaseRecordingSegment):
def __init__(self, traces, sampling_frequency, t_start):
Expand Down
74 changes: 34 additions & 40 deletions src/spikeinterface/preprocessing/tests/test_whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class TestWhiten:
returned data is indeed white.
"""

def get_float_test_data(self, num_segments, dtype, means=None):
def get_test_recording(self, num_segments, dtype, means=None):
"""
Generate a set of test data with known covariance matrix and mean.
Test data is drawn from a multivariate Gaussian distribution with
Expand All @@ -57,45 +57,37 @@ def get_float_test_data(self, num_segments, dtype, means=None):
The `means` should be an array of length 3 (num channels)
or `None`. If `None`, means will be zero.
"""
sampling_frequency = 30000
num_samples = int(10 * sampling_frequency) # 10 s recording

means, cov_mat, data = self.get_test_data_with_known_distribution(num_samples, dtype, means)

recording = NumpyRecording([data], sampling_frequency)

return means, cov_mat, recording

def get_test_data_with_known_distribution(self, num_samples, dtype, means=None):
""" """
num_channels = 3

if means is None:
means = np.zeros(num_channels)

cov_mat = np.array([[1, 0.5, 0], [0.5, 1, -0.25], [0, -0.25, 1]])

# Generate recording and multivariate Gaussian data to set
recording = self.get_empty_custom_recording(num_segments, num_channels, dtype)

seg_1_data = np.random.multivariate_normal(means, cov_mat, recording.get_num_samples(segment_index=0))
data = np.random.multivariate_normal(means, cov_mat, num_samples)

# Set the dtype, if `int16`, first scale to +/- 1 then cast to int16 range.
if dtype == np.float32:
seg_1_data = seg_1_data.astype(dtype)
data = data.astype(dtype)

elif dtype == np.int16:
seg_1_data /= seg_1_data.max()
seg_1_data = np.round((seg_1_data) * 32767).astype(np.int16)
data /= data.max()
data = np.round((data) * 32767).astype(np.int16)
else:
raise ValueError("dtype must be float32 or int16")

# Set the data on the recording and return
recording.update_traces(seg_1_data)
assert np.array_equal(recording.get_traces(segment_index=0), seg_1_data), "segment 1 test setup did not work."

return means, cov_mat, recording

def get_empty_custom_recording(self, num_segments, num_channels, dtype):
    """
    Build a `NumpyRecording` whose segments are all-zero arrays.

    Each of the `num_segments` segments has shape
    (num_samples, num_channels) and the given `dtype`, where
    `num_samples` corresponds to 10 s at a 30 kHz sampling frequency.
    """
    sampling_frequency = 30000
    num_samples = int(10 * sampling_frequency)

    traces_list = [np.zeros((num_samples, num_channels), dtype=dtype) for _ in range(num_segments)]

    # Reuse the local so the sample count and the recording's
    # sampling frequency cannot drift apart.
    return NumpyRecording(
        traces_list,
        sampling_frequency=sampling_frequency,
    )
return means, cov_mat, data

def cov_mat_from_whitening_mat(self, whitened_recording, eps):
"""
Expand Down Expand Up @@ -147,7 +139,7 @@ def test_compute_covariance_matrix(self, dtype):
otherwise it can overflow.
"""
eps = 1e-16
_, cov_mat, recording = self.get_float_test_data(num_segments=1, dtype=dtype)
_, cov_mat, recording = self.get_test_recording(num_segments=1, dtype=dtype)

whitened_recording = whiten(
recording,
Expand Down Expand Up @@ -177,7 +169,7 @@ def test_non_default_eps(self):
the cov mat if the correct eps is used.
"""
eps = 1
_, cov_mat, recording = self.get_float_test_data(num_segments=1, dtype=np.float32)
_, cov_mat, recording = self.get_test_recording(num_segments=1, dtype=np.float32)

whitened_recording = whiten(
recording,
Expand All @@ -200,17 +192,14 @@ def test_compute_covariance_matrix_2_segments(self):
but the covariance matrix is scaled by 1 / N.
"""
eps = 1e-16
_, cov_mat, recording = self.get_float_test_data(num_segments=2, dtype=np.float32)
sampling_frequency = 30000
num_samples = 10 * sampling_frequency

all_zero_data = np.zeros(
(recording.get_num_samples(segment_index=0), recording.get_num_channels()),
dtype=np.float32,
)
_, cov_mat, data = self.get_test_data_with_known_distribution(num_samples, np.float32)

recording.update_traces(
all_zero_data,
segment_index=1,
)
traces_list = [data, np.zeros_like(data)]

recording = NumpyRecording(traces_list, sampling_frequency)

whitened_recording = whiten(
recording,
Expand Down Expand Up @@ -238,7 +227,7 @@ def test_apply_mean(self, apply_mean):
means = np.array([10, 20, 30])

eps = 1e-16
_, cov_mat, recording = self.get_float_test_data(num_segments=1, dtype=np.float32, means=means)
_, cov_mat, recording = self.get_test_recording(num_segments=1, dtype=np.float32, means=means)

whitened_recording = whiten(
recording,
Expand Down Expand Up @@ -290,7 +279,7 @@ def test_whiten_regularisation_norm(self):
whitening preprocessing is the same as the one
computed from sklearn when regularise kwargs are given.
"""
_, _, recording = self.get_float_test_data(num_segments=1, dtype=np.float32)
_, _, recording = self.get_test_recording(num_segments=1, dtype=np.float32)

whitened_recording = whiten(
recording,
Expand Down Expand Up @@ -318,7 +307,7 @@ def test_local_vs_global_whiten(self):
channels are considered for whitening. Test that whitening
is correct for the first pair and last pair.
"""
_, _, recording = self.get_float_test_data(num_segments=1, dtype=np.float32)
_, _, recording = self.get_test_recording(num_segments=1, dtype=np.float32)

y_dist = 2
recording.set_channel_locations(
Expand Down Expand Up @@ -369,7 +358,12 @@ def test_passed_W_and_M(self):
be used for the actual whitening computation.
"""
num_chan = 4
recording = self.get_empty_custom_recording(2, num_chan, dtype=np.float32)
num_samples = 10000

recording = NumpyRecording(
[np.zeros((num_samples, num_chan))] * 2,
sampling_frequency=30000,
)

test_W = np.random.normal(size=(num_chan, num_chan))
test_M = np.random.normal(size=num_chan)
Expand Down

0 comments on commit 4a5ff58

Please sign in to comment.