This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Merge pull request #384 from SpikeInterface/filter_clean
Only filter requested channel idxs
alejoe91 authored Aug 4, 2020
2 parents 8d6269f + de4da78 commit 8a73b1c
Showing 5 changed files with 54 additions and 73 deletions.
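What the change means for callers: `filter_chunk` implementations now take a `channel_ids` argument, so asking `get_traces` for a subset of channels reads and filters only those channels instead of filtering everything and slicing afterwards. A minimal sketch of the caller-visible API (the toy `NumpyRecordingExtractor` recording and all parameter values are illustrative, not part of this commit):

```python
import numpy as np
import spikeextractors as se
import spiketoolkit as st

# Toy recording: 32 channels, 10 s at 30 kHz (illustrative values)
recording = se.NumpyRecordingExtractor(timeseries=np.random.randn(32, 300000),
                                       sampling_frequency=30000)
bp = st.preprocessing.bandpass_filter(recording, freq_min=300, freq_max=6000)

# Before this commit, each chunk was read and filtered for all 32 channels
# and the two requested rows were sliced out afterwards; now only the two
# requested channels are read and filtered.
traces = bp.get_traces(channel_ids=[0, 1], start_frame=0, end_frame=30000)
print(traces.shape)  # (2, 30000)
```

The whitening preprocessor is the one exception in spirit: it must compute on all channels and can only subset its output (see the note after the whiten.py diff below).
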
35 changes: 7 additions & 28 deletions spiketoolkit/preprocessing/bandpass_filter.py
@@ -2,24 +2,18 @@
 import numpy as np
 from scipy import special
 import spikeextractors as se
-from copy import deepcopy
-
-try:
-    import scipy.signal as ss
-    HAVE_BFR = True
-except ImportError:
-    HAVE_BFR = False
+import scipy.signal as ss
 
 
 class BandpassFilterRecording(FilterRecording):
 
     preprocessor_name = 'BandpassFilter'
-    installed = HAVE_BFR  # check at class level if installed or not
+    installed = True  # check at class level if installed or not
     installation_mesg = "To use the BandpassFilterRecording, install scipy: \n\n pip install scipy\n\n"  # err
 
     def __init__(self, recording, freq_min=300, freq_max=6000, freq_wid=1000, filter_type='fft', order=3,
                  chunk_size=30000, cache_chunks=False):
-        assert HAVE_BFR, "To use the BandpassFilterRecording, install scipy: \n\n pip install scipy\n\n"
         self._freq_min = freq_min
         self._freq_max = freq_max
         self._freq_wid = freq_wid
@@ -43,11 +37,11 @@ def __init__(self, recording, freq_min=300, freq_max=6000, freq_wid=1000, filter
                         'freq_wid': freq_wid, 'filter_type': filter_type, 'order': order,
                         'chunk_size': chunk_size, 'cache_chunks': cache_chunks}
 
-    def filter_chunk(self, *, start_frame, end_frame):
+    def filter_chunk(self, *, start_frame, end_frame, channel_ids):
         padding = 3000
         i1 = start_frame - padding
         i2 = end_frame + padding
-        padded_chunk = self._read_chunk(i1, i2)
+        padded_chunk = self._read_chunk(i1, i2, channel_ids)
         filtered_padded_chunk = self._do_filter(padded_chunk)
         return filtered_padded_chunk[:, start_frame - i1:end_frame - i1]
 
@@ -71,22 +65,6 @@ def _do_filter(self, chunk):
 
         return chunk_filtered
 
-    def _read_chunk(self, i1, i2):
-        M = len(self._recording.get_channel_ids())
-        N = self._recording.get_num_frames()
-        if i1 < 0:
-            i1b = 0
-        else:
-            i1b = i1
-        if i2 > N:
-            i2b = N
-        else:
-            i2b = i2
-        ret = np.zeros((M, i2 - i1))
-        ret[:, i1b - i1:i2b - i1] = self._recording.get_traces(start_frame=i1b, end_frame=i2b)
-
-        return ret
-
 
 def _create_filter_kernel(N, sampling_frequency, freq_min, freq_max, freq_wid=1000):
     # Matches ahb's code /matlab/processors/ms_bandpass_filter.m
@@ -126,7 +104,7 @@ def bandpass_filter(recording, freq_min=300, freq_max=6000, freq_wid=1000, filte
         Low-pass cutoff frequency.
     freq_wid: int or float
         Width of the filter (when type is 'fft').
-    type: str
+    filter_type: str
         'fft' or 'butter'. The 'fft' filter uses a kernel in the frequency domain. The 'butter' filter uses
         scipy butter and filtfilt functions.
     order: int
@@ -137,6 +115,7 @@ def bandpass_filter(recording, freq_min=300, freq_max=6000, freq_wid=1000, filte
        If True, filtered traces are computed and cached all at once on disk in temp file
    cache_chunks: bool (default False).
        If True then each chunk is cached in memory (in a dict)
+
    Returns
    -------
    filter_recording: BandpassFilterRecording
@@ -153,7 +132,7 @@ def bandpass_filter(recording, freq_min=300, freq_max=6000, freq_wid=1000, filte
         filter_type=filter_type,
         order=order,
         chunk_size=chunk_size,
-        cache_chunks=cache_chunks,
+        cache_chunks=cache_chunks
     )
     if cache_to_file:
         return se.CacheRecordingExtractor(bpf_recording, chunk_size=chunk_size)
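
A note on the padding logic kept in `filter_chunk` above: each chunk is read with 3000 extra samples on both sides, filtered, and trimmed back, so filter edge artifacts land in the padding rather than in the returned window. A self-contained sketch under assumed parameters (toy data, a 30 kHz rate, and a stand-in Butterworth filter; the real class selects 'fft' or 'butter' via `filter_type`):

```python
import numpy as np
import scipy.signal as ss

fs = 30000.0
data = np.random.randn(4, 120000)                 # 4 channels of toy raw traces

def filter_chunk(start_frame, end_frame, padding=3000):
    i1, i2 = start_frame - padding, end_frame + padding
    i1b, i2b = max(i1, 0), min(i2, data.shape[1])
    padded = np.zeros((data.shape[0], i2 - i1))   # zero-pad past the edges
    padded[:, i1b - i1:i2b - i1] = data[:, i1b:i2b]
    b, a = ss.butter(3, [300 / (fs / 2), 6000 / (fs / 2)], btype='bandpass')
    filtered = ss.filtfilt(b, a, padded, axis=1)
    return filtered[:, start_frame - i1:end_frame - i1]  # trim the padding off

chunk = filter_chunk(30000, 60000)                # shape (4, 30000)
```
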
50 changes: 38 additions & 12 deletions spiketoolkit/preprocessing/filterrecording.py
@@ -54,11 +54,11 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
         if self._chunk_size is not None:
             ich1 = int(start_frame / self._chunk_size)
             ich2 = int((end_frame - 1) / self._chunk_size)
-            dt = self._recording.get_traces(start_frame=0, end_frame=1).dtype
+            dt = self.get_dtype()
             filtered_chunk = np.zeros((len(channel_ids), int(end_frame-start_frame)), dtype=dt)
             pos = 0
             for ich in range(ich1, ich2 + 1):
-                filtered_chunk0 = self._get_filtered_chunk(ich)
+                filtered_chunk0 = self._get_filtered_chunk(ich, channel_ids)
                 if ich == ich1:
                     start0 = start_frame - ich * self._chunk_size
                 else:
@@ -67,37 +67,63 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
                     end0 = end_frame - ich * self._chunk_size
                 else:
                     end0 = self._chunk_size
-                chan_idx = [self.get_channel_ids().index(chan) for chan in channel_ids]
-                filtered_chunk[:, pos:pos+end0-start0] = filtered_chunk0[chan_idx, start0:end0]
+                filtered_chunk[:, pos:pos+end0-start0] = filtered_chunk0[:, start0:end0]
                 pos += (end0-start0)
         else:
-            filtered_chunk = self.filter_chunk(start_frame=start_frame, end_frame=end_frame)[channel_ids, :]
+            filtered_chunk = self.filter_chunk(start_frame=start_frame, end_frame=end_frame, channel_ids=channel_ids)
         return filtered_chunk.astype(self._dtype)
 
     @abstractmethod
-    def filter_chunk(self, *, start_frame, end_frame):
+    def filter_chunk(self, *, start_frame, end_frame, channel_ids):
         raise NotImplementedError('filter_chunk not implemented')
 
-    def _get_filtered_chunk(self, ind):
+    def _read_chunk(self, i1, i2, channel_ids):
+        num_frames = self._recording.get_num_frames()
+        if i1 < 0:
+            i1b = 0
+        else:
+            i1b = i1
+        if i2 > num_frames:
+            i2b = num_frames
+        else:
+            i2b = i2
+        chunk = np.zeros((len(channel_ids), i2 - i1))
+        chunk[:, i1b - i1:i2b - i1] = self._recording.get_traces(start_frame=i1b, end_frame=i2b,
+                                                                 channel_ids=channel_ids)
+
+        return chunk
+
+    def _get_filtered_chunk(self, ind, channel_ids):
         if self._cache_chunks:
             code = str(ind)
             chunk0 = self._filtered_cache_chunks.get(code)
         else:
             chunk0 = None
 
         if chunk0 is not None:
-            return chunk0
+            if chunk0.shape[0] == len(channel_ids):
+                return chunk0
+            else:
+                channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids])
+                return chunk0[channel_idxs]
 
         start0 = ind * self._chunk_size
         end0 = (ind + 1) * self._chunk_size
-        chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0)
 
         if self._cache_chunks:
+            # filter all channels if cache_chunks is used
+            chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0, channel_ids=self.get_channel_ids())
            self._filtered_cache_chunks.add(code, chunk1)
+            channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids])
+            chunk1 = chunk1[channel_idxs]
+        else:
+            # otherwise, only filter requested channels
+            chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0, channel_ids=channel_ids)
 
         return chunk1
 
 
-class FilteredChunkCache():
+class FilteredChunkCache:
     def __init__(self):
         self._chunks_by_code = dict()
         self._codes = []
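
The rewritten `_get_filtered_chunk` encodes one rule: with `cache_chunks=True` a chunk is filtered for all channels before caching, so any later request can be served by row-indexing the cached array; without caching, only the requested channels are filtered. A toy restatement of that branch logic (standalone code, not the spiketoolkit classes; the stand-in filter just returns zeros):

```python
import numpy as np

all_channel_ids = [0, 1, 2, 3]
chunk_size = 30000
cache = {}

def filter_chunk(start, end, channel_ids):
    # stand-in for the real band-pass/notch computation
    return np.zeros((len(channel_ids), end - start))

def get_filtered_chunk(ind, channel_ids, cache_chunks=True):
    chunk0 = cache.get(ind) if cache_chunks else None
    if chunk0 is not None:
        if chunk0.shape[0] == len(channel_ids):
            return chunk0                                   # cache matches request
        idxs = np.array([all_channel_ids.index(ch) for ch in channel_ids])
        return chunk0[idxs]                                 # subset the cached array
    start, end = ind * chunk_size, (ind + 1) * chunk_size
    if cache_chunks:
        # filter everything once so the cache can serve any later request
        chunk1 = filter_chunk(start, end, all_channel_ids)
        cache[ind] = chunk1
        idxs = np.array([all_channel_ids.index(ch) for ch in channel_ids])
        return chunk1[idxs]
    # otherwise filter only what was asked for
    return filter_chunk(start, end, channel_ids)

print(get_filtered_chunk(0, [1, 3]).shape)                  # (2, 30000)
```
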
31 changes: 5 additions & 26 deletions spiketoolkit/preprocessing/notch_filter.py
@@ -1,22 +1,16 @@
 from .filterrecording import FilterRecording
 import spikeextractors as se
 import numpy as np
-
-try:
-    import scipy.signal as ss
-    HAVE_NFR = True
-except ImportError:
-    HAVE_NFR = False
+import scipy.signal as ss
 
 
 class NotchFilterRecording(FilterRecording):
 
     preprocessor_name = 'NotchFilter'
-    installed = HAVE_NFR  # check at class level if installed or not
-    installation_mesg = "To use the NotchFilterRecording, install scipy: \n\n pip install scipy\n\n"  # error message when not installed
+    installed = True  # check at class level if installed or not
+    installation_mesg = ""  # error message when not installed
 
     def __init__(self, recording, freq=3000, q=30, chunk_size=30000, cache_chunks=False):
-        assert HAVE_NFR, "To use the NotchFilterRecording, install scipy: \n\n pip install scipy\n\n"
         self._freq = freq
         self._q = q
         fn = 0.5 * float(recording.get_sampling_frequency())
@@ -31,11 +25,11 @@ def __init__(self, recording, freq=3000, q=30, chunk_size=30000, cache_chunks=Fa
         self._kwargs = {'recording': recording.make_serialized_dict(), 'freq': freq,
                         'q': q, 'chunk_size': chunk_size, 'cache_chunks': cache_chunks}
 
-    def filter_chunk(self, *, start_frame, end_frame):
+    def filter_chunk(self, *, start_frame, end_frame, channel_ids):
         padding = 3000
         i1 = start_frame - padding
         i2 = end_frame + padding
-        padded_chunk = self._read_chunk(i1, i2)
+        padded_chunk = self._read_chunk(i1, i2, channel_ids)
         filtered_padded_chunk = self._do_filter(padded_chunk)
         return filtered_padded_chunk[:, start_frame - i1:end_frame - i1]
 
@@ -44,21 +38,6 @@ def _do_filter(self, chunk):
 
         return chunk_filtered
 
-    def _read_chunk(self, i1, i2):
-        M = len(self._recording.get_channel_ids())
-        N = self._recording.get_num_frames()
-        if i1 < 0:
-            i1b = 0
-        else:
-            i1b = i1
-        if i2 > N:
-            i2b = N
-        else:
-            i2b = i2
-        ret = np.zeros((M, i2 - i1))
-        ret[:, i1b - i1:i2b - i1] = self._recording.get_traces(start_frame=i1b, end_frame=i2b)
-        return ret
-
 
 def notch_filter(recording, freq=3000, q=30, chunk_size=30000, cache_to_file=False, cache_chunks=False):
     '''
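
For context, a hedged sketch of the kind of computation a notch stage performs, using scipy's `iirnotch` with the defaults shown above (freq=3000, q=30). The diff does not show `_do_filter`'s body, so treat the exact call as an assumption:

```python
import numpy as np
import scipy.signal as ss

fs = 30000.0
fn = 0.5 * fs                         # Nyquist, as computed in __init__ above
b, a = ss.iirnotch(3000 / fn, Q=30)   # notch at 3 kHz with quality factor 30
traces = np.random.randn(4, 30000)
notched = ss.filtfilt(b, a, traces, axis=1)
```
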
10 changes: 3 additions & 7 deletions spiketoolkit/preprocessing/whiten.py
@@ -39,18 +39,14 @@ def _compute_whitening_matrix(self, seed):
         U, S, Ut = np.linalg.svd(AAt, full_matrices=True)
         W = (U @ np.diag(1 / np.sqrt(S))) @ Ut
 
-        # proposed by Alessio
-        # AAt = data @ data.T / data.shape[1]
-        # D, V = np.linalg.eig(AAt)
-        # W = np.dot(np.diag(1.0 / np.sqrt(D + 1e-10)), V)
-
         return W
 
-    def filter_chunk(self, *, start_frame, end_frame):
+    def filter_chunk(self, *, start_frame, end_frame, channel_ids):
+        chan_idxs = np.array([self.get_channel_ids().index(chan) for chan in channel_ids])
         chunk = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame)
         chunk = chunk - np.mean(chunk, axis=1, keepdims=True)
         chunk2 = self._whitening_matrix @ chunk
-        return chunk2
+        return chunk2[chan_idxs]
 
 
 def whiten(recording, chunk_size=30000, cache_chunks=False, seed=0):
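
Why `filter_chunk` here still pulls all channels: the whitening matrix `W` mixes every channel into every output row, so the subset can only be taken after `W @ chunk`. A toy demonstration that `W`, built as in `_compute_whitening_matrix`, whitens the empirical covariance (random data, illustrative shapes):

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal((4, 1000))             # 4 channels, toy traces
data -= data.mean(axis=1, keepdims=True)

AAt = data @ data.T / data.shape[1]               # empirical channel covariance
U, S, Ut = np.linalg.svd(AAt, full_matrices=True)
W = (U @ np.diag(1 / np.sqrt(S))) @ Ut            # same construction as the diff

whitened = W @ data
print(np.round(whitened @ whitened.T / data.shape[1], 6))  # ~ identity

# Rows 0 and 2 of the whitened output still depend on all four input channels:
subset = whitened[np.array([0, 2])]
```
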
1 change: 1 addition & 0 deletions spiketoolkit/tests/test_preprocessing.py
@@ -71,6 +71,7 @@ def test_bandpass_filter_with_cache():
     check_dumping(rec_filtered)
     check_dumping(rec_filtered2)
     check_dumping(rec_filtered3)
+    check_dumping(rec_filtered4)
 
     shutil.rmtree('test')
