diff --git a/spiketoolkit/preprocessing/bandpass_filter.py b/spiketoolkit/preprocessing/bandpass_filter.py index cf0d6166..9554e56b 100644 --- a/spiketoolkit/preprocessing/bandpass_filter.py +++ b/spiketoolkit/preprocessing/bandpass_filter.py @@ -2,24 +2,18 @@ import numpy as np from scipy import special import spikeextractors as se -from copy import deepcopy +import scipy.signal as ss -try: - import scipy.signal as ss - HAVE_BFR = True -except ImportError: - HAVE_BFR = False class BandpassFilterRecording(FilterRecording): preprocessor_name = 'BandpassFilter' - installed = HAVE_BFR # check at class level if installed or not + installed = True # check at class level if installed or not installation_mesg = "To use the BandpassFilterRecording, install scipy: \n\n pip install scipy\n\n" # err def __init__(self, recording, freq_min=300, freq_max=6000, freq_wid=1000, filter_type='fft', order=3, chunk_size=30000, cache_chunks=False): - assert HAVE_BFR, "To use the BandpassFilterRecording, install scipy: \n\n pip install scipy\n\n" self._freq_min = freq_min self._freq_max = freq_max self._freq_wid = freq_wid @@ -43,11 +37,11 @@ def __init__(self, recording, freq_min=300, freq_max=6000, freq_wid=1000, filter 'freq_wid': freq_wid, 'filter_type': filter_type, 'order': order, 'chunk_size': chunk_size, 'cache_chunks': cache_chunks} - def filter_chunk(self, *, start_frame, end_frame): + def filter_chunk(self, *, start_frame, end_frame, channel_ids): padding = 3000 i1 = start_frame - padding i2 = end_frame + padding - padded_chunk = self._read_chunk(i1, i2) + padded_chunk = self._read_chunk(i1, i2, channel_ids) filtered_padded_chunk = self._do_filter(padded_chunk) return filtered_padded_chunk[:, start_frame - i1:end_frame - i1] @@ -71,22 +65,6 @@ def _do_filter(self, chunk): return chunk_filtered - def _read_chunk(self, i1, i2): - M = len(self._recording.get_channel_ids()) - N = self._recording.get_num_frames() - if i1 < 0: - i1b = 0 - else: - i1b = i1 - if i2 > N: - i2b = N - else: - i2b = i2 - ret = np.zeros((M, i2 - i1)) - ret[:, i1b - i1:i2b - i1] = self._recording.get_traces(start_frame=i1b, end_frame=i2b) - - return ret - def _create_filter_kernel(N, sampling_frequency, freq_min, freq_max, freq_wid=1000): # Matches ahb's code /matlab/processors/ms_bandpass_filter.m @@ -126,7 +104,7 @@ def bandpass_filter(recording, freq_min=300, freq_max=6000, freq_wid=1000, filte Low-pass cutoff frequency. freq_wid: int or float Width of the filter (when type is 'fft'). - type: str + filter_type: str 'fft' or 'butter'. The 'fft' filter uses a kernel in the frequency domain. The 'butter' filter uses scipy butter and filtfilt functions. order: int @@ -137,6 +115,7 @@ def bandpass_filter(recording, freq_min=300, freq_max=6000, freq_wid=1000, filte If True, filtered traces are computed and cached all at once on disk in temp file cache_chunks: bool (default False). If True then each chunk is cached in memory (in a dict) + Returns ------- filter_recording: BandpassFilterRecording @@ -153,7 +132,7 @@ def bandpass_filter(recording, freq_min=300, freq_max=6000, freq_wid=1000, filte filter_type=filter_type, order=order, chunk_size=chunk_size, - cache_chunks=cache_chunks, + cache_chunks=cache_chunks ) if cache_to_file: return se.CacheRecordingExtractor(bpf_recording, chunk_size=chunk_size) diff --git a/spiketoolkit/preprocessing/filterrecording.py b/spiketoolkit/preprocessing/filterrecording.py index ec81bb69..08b17658 100644 --- a/spiketoolkit/preprocessing/filterrecording.py +++ b/spiketoolkit/preprocessing/filterrecording.py @@ -54,11 +54,11 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None): if self._chunk_size is not None: ich1 = int(start_frame / self._chunk_size) ich2 = int((end_frame - 1) / self._chunk_size) - dt = self._recording.get_traces(start_frame=0, end_frame=1).dtype + dt = self.get_dtype() filtered_chunk = np.zeros((len(channel_ids), int(end_frame-start_frame)), dtype=dt) pos = 0 for ich in range(ich1, ich2 + 1): - filtered_chunk0 = self._get_filtered_chunk(ich) + filtered_chunk0 = self._get_filtered_chunk(ich, channel_ids) if ich == ich1: start0 = start_frame - ich * self._chunk_size else: @@ -67,37 +67,63 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None): end0 = end_frame - ich * self._chunk_size else: end0 = self._chunk_size - chan_idx = [self.get_channel_ids().index(chan) for chan in channel_ids] - filtered_chunk[:, pos:pos+end0-start0] = filtered_chunk0[chan_idx, start0:end0] + filtered_chunk[:, pos:pos+end0-start0] = filtered_chunk0[:, start0:end0] pos += (end0-start0) else: - filtered_chunk = self.filter_chunk(start_frame=start_frame, end_frame=end_frame)[channel_ids, :] + filtered_chunk = self.filter_chunk(start_frame=start_frame, end_frame=end_frame, channel_ids=channel_ids) return filtered_chunk.astype(self._dtype) @abstractmethod - def filter_chunk(self, *, start_frame, end_frame): + def filter_chunk(self, *, start_frame, end_frame, channel_ids): raise NotImplementedError('filter_chunk not implemented') - def _get_filtered_chunk(self, ind): + def _read_chunk(self, i1, i2, channel_ids): + num_frames = self._recording.get_num_frames() + if i1 < 0: + i1b = 0 + else: + i1b = i1 + if i2 > num_frames: + i2b = num_frames + else: + i2b = i2 + chunk = np.zeros((len(channel_ids), i2 - i1)) + chunk[:, i1b - i1:i2b - i1] = self._recording.get_traces(start_frame=i1b, end_frame=i2b, + channel_ids=channel_ids) + + return chunk + + def _get_filtered_chunk(self, ind, channel_ids): if self._cache_chunks: code = str(ind) chunk0 = self._filtered_cache_chunks.get(code) else: chunk0 = None + if chunk0 is not None: - return chunk0 + if chunk0.shape[0] == len(channel_ids): + return chunk0 + else: + channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids]) + return chunk0[channel_idxs] start0 = ind * self._chunk_size end0 = (ind + 1) * self._chunk_size - chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0) + if self._cache_chunks: + # filter all channels if cache_chunks is used + chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0, channel_ids=self.get_channel_ids()) self._filtered_cache_chunks.add(code, chunk1) - + channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids]) + chunk1 = chunk1[channel_idxs] + else: + # otherwise, only filter requested channels + chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0, channel_ids=channel_ids) + return chunk1 - -class FilteredChunkCache(): +class FilteredChunkCache: def __init__(self): self._chunks_by_code = dict() self._codes = [] diff --git a/spiketoolkit/preprocessing/notch_filter.py b/spiketoolkit/preprocessing/notch_filter.py index 1b263ded..f9b2f98d 100644 --- a/spiketoolkit/preprocessing/notch_filter.py +++ b/spiketoolkit/preprocessing/notch_filter.py @@ -1,22 +1,16 @@ from .filterrecording import FilterRecording import spikeextractors as se import numpy as np - -try: - import scipy.signal as ss - HAVE_NFR = True -except ImportError: - HAVE_NFR = False +import scipy.signal as ss class NotchFilterRecording(FilterRecording): preprocessor_name = 'NotchFilter' - installed = HAVE_NFR # check at class level if installed or not - installation_mesg = "To use the NotchFilterRecording, install scipy: \n\n pip install scipy\n\n" # error message when not installed + installed = True # check at class level if installed or not + installation_mesg = "" # error message when not installed def __init__(self, recording, freq=3000, q=30, chunk_size=30000, cache_chunks=False): - assert HAVE_NFR, "To use the NotchFilterRecording, install scipy: \n\n pip install scipy\n\n" self._freq = freq self._q = q fn = 0.5 * float(recording.get_sampling_frequency()) @@ -31,11 +25,11 @@ def __init__(self, recording, freq=3000, q=30, chunk_size=30000, cache_chunks=Fa self._kwargs = {'recording': recording.make_serialized_dict(), 'freq': freq, 'q': q, 'chunk_size': chunk_size, 'cache_chunks': cache_chunks} - def filter_chunk(self, *, start_frame, end_frame): + def filter_chunk(self, *, start_frame, end_frame, channel_ids): padding = 3000 i1 = start_frame - padding i2 = end_frame + padding - padded_chunk = self._read_chunk(i1, i2) + padded_chunk = self._read_chunk(i1, i2, channel_ids) filtered_padded_chunk = self._do_filter(padded_chunk) return filtered_padded_chunk[:, start_frame - i1:end_frame - i1] @@ -44,21 +38,6 @@ def _do_filter(self, chunk): return chunk_filtered - def _read_chunk(self, i1, i2): - M = len(self._recording.get_channel_ids()) - N = self._recording.get_num_frames() - if i1 < 0: - i1b = 0 - else: - i1b = i1 - if i2 > N: - i2b = N - else: - i2b = i2 - ret = np.zeros((M, i2 - i1)) - ret[:, i1b - i1:i2b - i1] = self._recording.get_traces(start_frame=i1b, end_frame=i2b) - return ret - def notch_filter(recording, freq=3000, q=30, chunk_size=30000, cache_to_file=False, cache_chunks=False): ''' diff --git a/spiketoolkit/preprocessing/whiten.py b/spiketoolkit/preprocessing/whiten.py index 3c5acfc1..c018dbba 100644 --- a/spiketoolkit/preprocessing/whiten.py +++ b/spiketoolkit/preprocessing/whiten.py @@ -39,18 +39,14 @@ def _compute_whitening_matrix(self, seed): U, S, Ut = np.linalg.svd(AAt, full_matrices=True) W = (U @ np.diag(1 / np.sqrt(S))) @ Ut - # proposed by Alessio - # AAt = data @ data.T / data.shape[1] - # D, V = np.linalg.eig(AAt) - # W = np.dot(np.diag(1.0 / np.sqrt(D + 1e-10)), V) - return W - def filter_chunk(self, *, start_frame, end_frame): + def filter_chunk(self, *, start_frame, end_frame, channel_ids): + chan_idxs = np.array([self.get_channel_ids().index(chan) for chan in channel_ids]) chunk = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame) chunk = chunk - np.mean(chunk, axis=1, keepdims=True) chunk2 = self._whitening_matrix @ chunk - return chunk2 + return chunk2[chan_idxs] def whiten(recording, chunk_size=30000, cache_chunks=False, seed=0): diff --git a/spiketoolkit/tests/test_preprocessing.py b/spiketoolkit/tests/test_preprocessing.py index 1f287a75..e37e112b 100644 --- a/spiketoolkit/tests/test_preprocessing.py +++ b/spiketoolkit/tests/test_preprocessing.py @@ -71,6 +71,7 @@ def test_bandpass_filter_with_cache(): check_dumping(rec_filtered) check_dumping(rec_filtered2) check_dumping(rec_filtered3) + check_dumping(rec_filtered4) shutil.rmtree('test')