This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Commit

Fix tests

alejoe91 committed Aug 4, 2020
1 parent fcb9fba commit de4da78
Showing 3 changed files with 20 additions and 12 deletions.
21 changes: 16 additions & 5 deletions spiketoolkit/preprocessing/filterrecording.py
@@ -67,8 +67,7 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
end0 = end_frame - ich * self._chunk_size
else:
end0 = self._chunk_size
- chan_idx = [self.get_channel_ids().index(chan) for chan in channel_ids]
- filtered_chunk[:, pos:pos+end0-start0] = filtered_chunk0[chan_idx, start0:end0]
+ filtered_chunk[:, pos:pos+end0-start0] = filtered_chunk0[:, start0:end0]
pos += (end0-start0)
else:
filtered_chunk = self.filter_chunk(start_frame=start_frame, end_frame=end_frame, channel_ids=channel_ids)
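In the chunked path of get_traces, each filtered chunk returned by _get_filtered_chunk now already contains only the requested channels (see the next hunk), so the per-chunk channel re-indexing is dropped and the chunk is copied into the output buffer as-is. A minimal sketch of that stitching logic with hypothetical names; chunks is assumed to be a dict mapping chunk index to a filtered (n_channels, chunk_size) array:

import numpy as np

def stitch_chunks(chunks, start_frame, end_frame, chunk_size):
    n_channels = chunks[start_frame // chunk_size].shape[0]
    out = np.zeros((n_channels, end_frame - start_frame))
    pos = 0
    ich1 = start_frame // chunk_size
    ich2 = (end_frame - 1) // chunk_size
    for ich in range(ich1, ich2 + 1):
        # window of the requested frames that falls inside chunk `ich`
        start0 = max(0, start_frame - ich * chunk_size)
        end0 = min(chunk_size, end_frame - ich * chunk_size)
        # each chunk already holds only the requested channels,
        # so it is copied without channel re-indexing
        out[:, pos:pos + end0 - start0] = chunks[ich][:, start0:end0]
        pos += end0 - start0
    return out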
@@ -100,15 +99,27 @@ def _get_filtered_chunk(self, ind, channel_ids):
chunk0 = self._filtered_cache_chunks.get(code)
else:
chunk0 = None

if chunk0 is not None:
- return chunk0
+ if chunk0.shape[0] == len(channel_ids):
+ return chunk0
+ else:
+ channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids])
+ return chunk0[channel_idxs]

start0 = ind * self._chunk_size
end0 = (ind + 1) * self._chunk_size
- chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0, channel_ids=channel_ids)

if self._cache_chunks:
+ # filter all channels if cache_chunks is used
+ chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0, channel_ids=self.get_channel_ids())
self._filtered_cache_chunks.add(code, chunk1)

+ channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids])
+ chunk1 = chunk1[channel_idxs]
+ else:
+ # otherwise, only filter requested channels
+ chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0, channel_ids=channel_ids)

return chunk1

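The second hunk is the substance of the fix: when cache_chunks is enabled, each chunk is filtered for all channels before it is stored, and the requested channels are sliced out afterwards; on a cache hit, a chunk stored with all channels is likewise sliced down to the requested subset. A hedged usage sketch of the access pattern this supports (the recording variable, channel ids, and frame numbers are illustrative only):

# rec_f: any FilterRecording subclass built with cache_chunks=True
traces_a = rec_f.get_traces(channel_ids=[0, 1], start_frame=0, end_frame=60000)
# the chunks cached by the call above hold all channels, so a request
# for a different subset can be served from the same cached arrays
traces_b = rec_f.get_traces(channel_ids=[2, 3], start_frame=0, end_frame=60000)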
10 changes: 3 additions & 7 deletions spiketoolkit/preprocessing/whiten.py
@@ -39,18 +39,14 @@ def _compute_whitening_matrix(self, seed):
U, S, Ut = np.linalg.svd(AAt, full_matrices=True)
W = (U @ np.diag(1 / np.sqrt(S))) @ Ut

- # proposed by Alessio
- # AAt = data @ data.T / data.shape[1]
- # D, V = np.linalg.eig(AAt)
- # W = np.dot(np.diag(1.0 / np.sqrt(D + 1e-10)), V)
-
return W

def filter_chunk(self, *, start_frame, end_frame, channel_ids):
- chunk = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame, channel_ids=channel_ids)
+ chan_idxs = np.array([self.get_channel_ids().index(chan) for chan in channel_ids])
+ chunk = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame)
chunk = chunk - np.mean(chunk, axis=1, keepdims=True)
chunk2 = self._whitening_matrix @ chunk
- return chunk2
+ return chunk2[chan_idxs]


def whiten(recording, chunk_size=30000, cache_chunks=False, seed=0):
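The retained computation is the standard ZCA whitening matrix, W = U diag(S)^(-1/2) U^T, from the SVD of the channel covariance AAt; applied to mean-centred traces it makes the channel covariance approximately the identity. A small self-contained numpy check of that property on toy data (not spiketoolkit code):

import numpy as np

rng = np.random.RandomState(0)
mix = rng.randn(4, 4)                      # mixing matrix to correlate channels
data = mix @ rng.randn(4, 10000)           # 4 channels x 10000 frames
data = data - np.mean(data, axis=1, keepdims=True)

AAt = data @ data.T / data.shape[1]        # channel covariance
U, S, Ut = np.linalg.svd(AAt, full_matrices=True)
W = (U @ np.diag(1 / np.sqrt(S))) @ Ut     # same expression as in the diff

whitened = W @ data
print(np.round(whitened @ whitened.T / whitened.shape[1], 3))  # ~ identity matrix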
1 change: 1 addition & 0 deletions spiketoolkit/tests/test_preprocessing.py
@@ -71,6 +71,7 @@ def test_bandpass_filter_with_cache():
check_dumping(rec_filtered)
check_dumping(rec_filtered2)
check_dumping(rec_filtered3)
+ check_dumping(rec_filtered4)

shutil.rmtree('test')

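For context, a hedged sketch of the kind of setup test_bandpass_filter_with_cache exercises; the toy recording, the exact filter arguments, and the assumption that bandpass_filter forwards cache_chunks to the underlying FilterRecording are illustrative, not taken from the test file:

import spikeextractors as se
import spiketoolkit as st

rec, _ = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0)

rec_plain = st.preprocessing.bandpass_filter(rec, freq_min=300, freq_max=6000)
rec_cached = st.preprocessing.bandpass_filter(rec, freq_min=300, freq_max=6000,
                                              cache_chunks=True)

# check_dumping is the helper used in the hunk above: every filtered
# variant should survive a dump/load round trip
check_dumping(rec_plain)
check_dumping(rec_cached)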
