This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Merge pull request #465 from SpikeInterface/fix_return_scaled

Fix return scaled

alejoe91 authored Mar 10, 2021
2 parents bab9616 + 7426c5a commit 3fc4099
Showing 15 changed files with 45 additions and 27 deletions.
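For context: `return_scaled=True` asks a RecordingExtractor for traces converted to physical units (microvolts) via the per-channel gains and offsets, while `return_scaled=False` returns the raw stored samples. Before this fix, several preprocessors dropped or hard-coded the flag when reading from the wrapped recording. A minimal sketch of the intended semantics, with a toy recording and an illustrative gain value (not part of this diff):

    import numpy as np
    import spikeextractors as se

    # Toy recording: 4 channels, 1 s at 30 kHz, stored as int16 ADC counts
    raw = (np.random.randn(4, 30000) * 100).astype(np.int16)
    recording = se.NumpyRecordingExtractor(timeseries=raw, sampling_frequency=30000)
    recording.set_channel_gains([0.195] * 4)  # counts -> microvolts

    unscaled = recording.get_traces(return_scaled=False)  # raw int16 counts
    scaled = recording.get_traces(return_scaled=True)     # float: counts * gain + offset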
4 changes: 2 additions & 2 deletions spiketoolkit/preprocessing/bandpass_filter.py
@@ -32,11 +32,11 @@ def __init__(self, recording, freq_min=300, freq_max=6000, freq_wid=1000, filter_type='butter',
                         'freq_wid': freq_wid, 'filter_type': filter_type, 'order': order,
                         'chunk_size': chunk_size, 'cache_chunks': cache_chunks}
 
-    def filter_chunk(self, *, start_frame, end_frame, channel_ids):
+    def filter_chunk(self, start_frame, end_frame, channel_ids, return_scaled):
         padding = 3000
         i1 = start_frame - padding
         i2 = end_frame + padding
-        padded_chunk = self._read_chunk(i1, i2, channel_ids)
+        padded_chunk = self._read_chunk(i1, i2, channel_ids, return_scaled)
         filtered_padded_chunk = self._do_filter(padded_chunk)
         return filtered_padded_chunk[:, start_frame - i1:end_frame - i1]
1 change: 1 addition & 0 deletions spiketoolkit/preprocessing/basepreprocessorrecording.py
@@ -14,6 +14,7 @@ def __init__(self, recording):
         self.copy_epochs(recording)
         self.copy_times(recording)
 
+        # avoid rescaling twice
         self.set_channel_gains(1)
         self.set_channel_offsets(0)
 
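Resetting the wrapper to unit gain and zero offset prevents double scaling: the preprocessor reads already-scaled traces from its parent, so its own gain must be a no-op. A sketch of the hazard, reusing the toy recording above (gain value illustrative):

    import numpy as np
    import spiketoolkit as st

    bp = st.preprocessing.bandpass_filter(recording)

    # The parent applies its 0.195 gain inside get_traces; if the wrapper
    # also reported gain=0.195, bp.get_traces(return_scaled=True) would
    # return counts * 0.195 * 0.195. With gain=1 and offset=0 the
    # conversion happens exactly once.
    assert np.all(np.array(bp.get_channel_gains()) == 1)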
4 changes: 4 additions & 0 deletions spiketoolkit/preprocessing/blank_saturation.py
@@ -29,6 +29,8 @@ def __init__(self, recording, threshold=None, seed=0):
             self._lower = False
         else:
             self._lower = True
+        self.has_unscaled = False
+
         self._kwargs = {'recording': recording.make_serialized_dict(), 'threshold': threshold, 'seed': seed}
 
     def _get_random_data_for_scaling(self, num_chunks=50, chunk_size=500, seed=0):
@@ -43,6 +45,8 @@ def _get_random_data_for_scaling(self, num_chunks=50, chunk_size=500, seed=0):
 
     @check_get_traces_args
     def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
+        assert return_scaled, "'blank_saturation' only supports return_scaled=True"
+
         traces = self._recording.get_traces(channel_ids=channel_ids,
                                             start_frame=start_frame,
                                             end_frame=end_frame,
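Preprocessors like this one derive their parameters from scaled data and produce inherently float output, so they cannot honor `return_scaled=False`; hence the new `has_unscaled = False` flag and the assert. A minimal sketch of the pattern (the class and its name are hypothetical, not from this diff):

    class FloatOnlyRecording(BasePreprocessorRecordingExtractor):
        def __init__(self, recording):
            BasePreprocessorRecordingExtractor.__init__(self, recording)
            self.has_unscaled = False  # output exists only in scaled (float) form

        @check_get_traces_args
        def get_traces(self, channel_ids=None, start_frame=None, end_frame=None,
                       return_scaled=True):
            assert return_scaled, "'float_only' only supports return_scaled=True"
            traces = self._recording.get_traces(channel_ids=channel_ids,
                                                start_frame=start_frame,
                                                end_frame=end_frame,
                                                return_scaled=return_scaled)
            return traces  # a float-valued transform would be applied here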
6 changes: 1 addition & 5 deletions spiketoolkit/preprocessing/center.py
@@ -33,12 +33,8 @@ def __init__(self, recording, mode, seconds, n_snippets):
             self._offset = -np.mean(traces, axis=1)
         else:
             self._offset = -np.median(traces, axis=1)
-        dtype = str(recording.get_dtype())
-        if 'numpy' in dtype:
-            dtype = str(dtype).replace("<class '", "").replace("'>", "")
-            # drop 'numpy'
-            dtype = dtype.split('.')[1]
+        dtype = np.dtype(recording.get_dtype()).name
         if 'uint' in dtype:
             dtype = dtype[1:]
         TransformRecording.__init__(self, recording, scalar=self._scalar, offset=self._offset, dtype=dtype)
         self._kwargs = {'recording': recording.make_serialized_dict(), 'mode': mode, 'seconds': seconds,
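The replacement relies on `np.dtype(...).name`, which canonicalizes any dtype spelling in one step instead of parsing `"<class 'numpy.uint16'>"` strings. For example:

    import numpy as np

    str(np.uint16)            # "<class 'numpy.uint16'>" -- what the old code parsed
    np.dtype(np.uint16).name  # 'uint16'
    np.dtype('uint16').name   # 'uint16'

    # Unsigned dtypes map to the signed type of the same width, since
    # centered traces can be negative:
    dtype = np.dtype(np.uint16).name
    if 'uint' in dtype:
        dtype = dtype[1:]     # 'int16'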
3 changes: 3 additions & 0 deletions spiketoolkit/preprocessing/clip.py
@@ -14,10 +14,13 @@ def __init__(self, recording, a_min=None, a_max=None):
         self._a_min = a_min
         self._a_max = a_max
         BasePreprocessorRecordingExtractor.__init__(self, recording)
+        self.has_unscaled = False
         self._kwargs = {'recording': recording.make_serialized_dict(), 'a_min': a_min, 'a_max': a_max}
 
     @check_get_traces_args
     def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
+        assert return_scaled, "'clip' only supports return_scaled=True"
+
         traces = self._recording.get_traces(channel_ids=channel_ids,
                                             start_frame=start_frame,
                                             end_frame=end_frame,
6 changes: 4 additions & 2 deletions spiketoolkit/preprocessing/common_reference.py
@@ -42,7 +42,8 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
         if self._groups is None:
             if self.verbose:
                 print('Common median reference using all channels')
-            traces = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame, return_scaled=return_scaled)
+            traces = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame,
+                                                return_scaled=return_scaled)
             traces = traces - np.median(traces, axis=0, keepdims=True)
             return traces[channel_idxs].astype(self._dtype)
         else:
@@ -68,7 +69,8 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
             if self.verbose:
                 print('Common average reference using all channels')
             if self._groups is None:
-                traces = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame, return_scaled=return_scaled)
+                traces = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame,
+                                                    return_scaled=return_scaled)
                 traces = traces - np.mean(traces, axis=0, keepdims=True)
                 return traces[channel_idxs].astype(self._dtype)
             else:
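Both branches subtract a per-frame reference computed across channels; the split call is a pure line-length reformat that now also forwards `return_scaled` to the parent. How the two references behave, on a toy array:

    import numpy as np

    traces = np.random.randn(32, 30000)  # channels x frames

    # Common median reference (CMR): robust to a few outlier channels
    cmr = traces - np.median(traces, axis=0, keepdims=True)

    # Common average reference (CAR): plain mean across channels
    car = traces - np.mean(traces, axis=0, keepdims=True)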
13 changes: 7 additions & 6 deletions spiketoolkit/preprocessing/filterrecording.py
@@ -46,7 +46,7 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
         filtered_chunk = np.zeros((len(channel_ids), int(end_frame-start_frame)), dtype=dt)
         pos = 0
         for ich in range(ich1, ich2 + 1):
-            filtered_chunk0 = self._get_filtered_chunk(ich, channel_ids)
+            filtered_chunk0 = self._get_filtered_chunk(ich, channel_ids, return_scaled)
             if ich == ich1:
                 start0 = start_frame - ich * self._chunk_size
             else:
@@ -62,10 +62,10 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
         return filtered_chunk.astype(self._dtype)
 
     @abstractmethod
-    def filter_chunk(self, *, start_frame, end_frame, channel_ids):
+    def filter_chunk(self, *, start_frame, end_frame, channel_ids, return_scaled):
         raise NotImplementedError('filter_chunk not implemented')
 
-    def _read_chunk(self, i1, i2, channel_ids):
+    def _read_chunk(self, i1, i2, channel_ids, return_scaled=True):
         num_frames = self._recording.get_num_frames()
         if i1 < 0:
             i1b = 0
@@ -77,11 +77,11 @@ def _read_chunk(self, i1, i2, channel_ids):
             i2b = i2
         chunk = np.zeros((len(channel_ids), i2 - i1))
         chunk[:, i1b - i1:i2b - i1] = self._recording.get_traces(start_frame=i1b, end_frame=i2b,
-                                                                 channel_ids=channel_ids)
+                                                                 channel_ids=channel_ids, return_scaled=return_scaled)
 
         return chunk
 
-    def _get_filtered_chunk(self, ind, channel_ids):
+    def _get_filtered_chunk(self, ind, channel_ids, return_scaled):
         if self._cache_chunks:
             code = str(ind)
             chunk0 = self._filtered_cache_chunks.get(code)
@@ -106,7 +106,8 @@ def _get_filtered_chunk(self, ind, channel_ids):
                 chunk1 = chunk1[channel_idxs]
             else:
                 # otherwise, only filter requested channels
-                chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0, channel_ids=channel_ids)
+                chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0, channel_ids=channel_ids,
+                                           return_scaled=return_scaled)
 
         return chunk1
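The chunking scheme reads `padding` extra frames on each side of the requested window, filters the padded block, and trims the margins, so filter edge transients fall outside the returned slice. A self-contained sketch of the idea (filter design values illustrative):

    import numpy as np
    from scipy.signal import butter, filtfilt

    fs = 30000.0
    traces = np.random.randn(4, 10 * int(fs))  # channels x frames
    start_frame, end_frame, padding = 60000, 90000, 3000

    # Read a padded window, clipped to the valid range (zero-padding of
    # out-of-range frames omitted for brevity)
    i1 = max(start_frame - padding, 0)
    i2 = min(end_frame + padding, traces.shape[1])
    padded = traces[:, i1:i2]

    # Band-pass and trim: edge transients stay inside the discarded margin
    b, a = butter(3, [300 / (fs / 2), 6000 / (fs / 2)], btype='bandpass')
    filtered = filtfilt(b, a, padded, axis=1)
    chunk = filtered[:, start_frame - i1:end_frame - i1]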
4 changes: 2 additions & 2 deletions spiketoolkit/preprocessing/highpass_filter.py
@@ -42,11 +42,11 @@ def __init__(self, recording, freq_min=300, freq_wid=1000, filter_type='butter',
                         'freq_wid': freq_wid, 'filter_type': filter_type, 'order': order,
                         'chunk_size': chunk_size, 'cache_chunks': cache_chunks}
 
-    def filter_chunk(self, *, start_frame, end_frame, channel_ids):
+    def filter_chunk(self, *, start_frame, end_frame, channel_ids, return_scaled):
         padding = 3000
         i1 = start_frame - self._padding
         i2 = end_frame + self._padding
-        padded_chunk = self._read_chunk(i1, i2, channel_ids)
+        padded_chunk = self._read_chunk(i1, i2, channel_ids, return_scaled)
         filtered_padded_chunk = self._do_filter(padded_chunk)
         return filtered_padded_chunk[:, start_frame - i1:end_frame - i1]
2 changes: 1 addition & 1 deletion spiketoolkit/preprocessing/mask.py
@@ -23,7 +23,7 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
         traces = self._recording.get_traces(channel_ids=channel_ids,
                                             start_frame=start_frame,
                                             end_frame=end_frame,
-                                            return_scaled=True)
+                                            return_scaled=return_scaled)
 
         traces = traces.copy()  # takes care of memmap objects
         traces[:, ~self._mask[start_frame:end_frame]] = 0.0
5 changes: 4 additions & 1 deletion spiketoolkit/preprocessing/normalize_by_quantile.py
@@ -16,6 +16,7 @@ def __init__(self, recording, scale=1.0, median=0.0, q1=0.01, q2=0.99, seed=0):
 
         self._scalar = scale / pre_scale
         self._offset = median - pre_median * self._scalar
+        self.has_unscaled = False
         self._kwargs = {'recording': recording.make_serialized_dict(), 'scale': scale, 'median': median,
                         'q1': q1, 'q2': q2, 'seed': seed}
 
@@ -31,10 +32,12 @@ def _get_random_data_for_scaling(self, num_chunks=50, chunk_size=500, seed=0):
 
     @check_get_traces_args
     def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
+        assert return_scaled, "'normalize_by_quantile' only supports return_scaled=True"
+
         traces = self._recording.get_traces(channel_ids=channel_ids,
                                             start_frame=start_frame,
                                             end_frame=end_frame,
-                                            return_scaled=True)
+                                            return_scaled=return_scaled)
         return traces * self._scalar + self._offset
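The transform is an affine map, with `scalar` and `offset` estimated from random chunks so that the inter-quantile range is rescaled to `scale` and the median moved to `median`. Roughly (estimation details simplified):

    import numpy as np

    data = np.random.randn(100000)  # stand-in for the sampled traces
    q1_v, pre_median, q2_v = np.quantile(data, [0.01, 0.5, 0.99])
    pre_scale = (q2_v - q1_v) / 2

    scale, median = 1.0, 0.0        # requested output statistics
    scalar = scale / pre_scale
    offset = median - pre_median * scalar
    normalized = data * scalar + offset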
4 changes: 2 additions & 2 deletions spiketoolkit/preprocessing/notch_filter.py
@@ -19,11 +19,11 @@ def __init__(self, recording, freq=3000, q=30, chunk_size=30000, cache_chunks=False):
         self._kwargs = {'recording': recording.make_serialized_dict(), 'freq': freq,
                         'q': q, 'chunk_size': chunk_size, 'cache_chunks': cache_chunks}
 
-    def filter_chunk(self, *, start_frame, end_frame, channel_ids):
+    def filter_chunk(self, start_frame, end_frame, channel_ids, return_scaled):
         padding = 3000
         i1 = start_frame - padding
         i2 = end_frame + padding
-        padded_chunk = self._read_chunk(i1, i2, channel_ids)
+        padded_chunk = self._read_chunk(i1, i2, channel_ids, return_scaled)
         filtered_padded_chunk = self._do_filter(padded_chunk)
         return filtered_padded_chunk[:, start_frame - i1:end_frame - i1]
8 changes: 5 additions & 3 deletions spiketoolkit/preprocessing/remove_artifacts.py
@@ -21,7 +21,9 @@ def __init__(self, recording, triggers, ms_before=0.5, ms_after=3.0, mode='zeros',
 
     @check_get_traces_args
     def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
-        traces = self._recording.get_traces(channel_ids=channel_ids, start_frame=start_frame, end_frame=end_frame,
+        traces = self._recording.get_traces(channel_ids=channel_ids,
+                                            start_frame=start_frame,
+                                            end_frame=end_frame,
                                             return_scaled=return_scaled)
         triggers = self._triggers[(self._triggers > start_frame) & (self._triggers < end_frame)] - start_frame
 
@@ -107,13 +109,13 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
 
         interp_traces = np.vstack((pre_vals, post_vals)).T
 
-        if (self._mode == 'cubic') & (len(all_idx) >= 5):
+        if self._mode == 'cubic' and len(all_idx) >= 5:
             # Enough fit points present on either side to do cubic spline fit:
             interp_function = interp1d(all_idx, interp_traces, self._mode,
                                        bounds_error=False,
                                        fill_value='extrapolate')
             traces[:, gap_idx] = interp_function(gap_idx)
-        elif (self._mode == 'linear') & (len(all_idx) >= 2):
+        elif self._mode == 'linear' and len(all_idx) >= 2:
             # Enough fit points present for a linear fit
             interp_function = interp1d(all_idx, interp_traces, self._mode, bounds_error=False,
                                        fill_value='extrapolate')
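Swapping `&` for `and` is the idiomatic fix for scalar booleans (the bitwise form only worked because both comparisons were parenthesized). The interpolation across an artifact gap works as follows, with toy indices standing in for the trigger window:

    import numpy as np
    from scipy.interpolate import interp1d

    traces = np.random.randn(4, 1000)
    gap_idx = np.arange(480, 520)                  # frames blanked by the artifact
    all_idx = np.array([470, 475, 479, 520, 525])  # fit points on both sides
    interp_traces = traces[:, all_idx]

    kind = 'cubic' if len(all_idx) >= 5 else 'linear'
    f = interp1d(all_idx, interp_traces, kind, bounds_error=False,
                 fill_value='extrapolate')
    traces[:, gap_idx] = f(gap_idx)                # fill the gap smoothly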
2 changes: 1 addition & 1 deletion spiketoolkit/preprocessing/resample.py
@@ -30,7 +30,7 @@ def get_num_frames(self):
         return int(self._recording.get_num_frames() / self._recording.get_sampling_frequency() * self._resample_rate)
 
     # avoid filtering one sample
-    def get_dtype(self):
+    def get_dtype(self, return_scaled=True):
         return self._dtype
 
     @check_get_traces_args
3 changes: 3 additions & 0 deletions spiketoolkit/preprocessing/transform.py
@@ -17,12 +17,15 @@ def __init__(self, recording, scalar=1., offset=0., dtype=None):
         else:
             self._dtype = dtype
         BasePreprocessorRecordingExtractor.__init__(self, recording)
+        self.has_unscaled = False
 
         self._kwargs = {'recording': recording.make_serialized_dict(), 'scalar': scalar, 'offset': offset,
                         'dtype': dtype}
 
     @check_get_traces_args
     def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
+        assert return_scaled, "'transform' only supports return_scaled=True"
+
         traces = self._recording.get_traces(channel_ids=channel_ids, start_frame=start_frame, end_frame=end_frame,
                                             return_scaled=return_scaled)
         if isinstance(self._scalar, (int, float, np.integer, np.float)):
7 changes: 5 additions & 2 deletions spiketoolkit/preprocessing/whiten.py
@@ -8,6 +8,7 @@ class WhitenRecording(FilterRecording):
     def __init__(self, recording, chunk_size=30000, cache_chunks=False, seed=0):
         FilterRecording.__init__(self, recording=recording, chunk_size=chunk_size, cache_chunks=cache_chunks)
         self._whitening_matrix = self._compute_whitening_matrix(seed=seed)
+        self.has_unscaled = False
         self._kwargs = {'recording': recording.make_serialized_dict(), 'chunk_size': chunk_size,
                         'cache_chunks': cache_chunks, 'seed': seed}
 
@@ -35,9 +36,11 @@ def _compute_whitening_matrix(self, seed):
 
         return W
 
-    def filter_chunk(self, *, start_frame, end_frame, channel_ids):
+    def filter_chunk(self, start_frame, end_frame, channel_ids, return_scaled):
+        assert return_scaled, "'whiten' only supports return_scaled=True"
+
         chan_idxs = np.array([self.get_channel_ids().index(chan) for chan in channel_ids])
-        chunk = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame)
+        chunk = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame, return_scaled=return_scaled)
         chunk = chunk - np.mean(chunk, axis=1, keepdims=True)
         chunk2 = self._whitening_matrix @ chunk
         return chunk2[chan_idxs]
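Whitening has to see scaled data so that the covariance, and hence the whitening matrix, is estimated in consistent units; that is why this preprocessor also rejects `return_scaled=False`. A rough sketch of building and applying a whitening matrix (the regularization constant is illustrative and may differ from `_compute_whitening_matrix`):

    import numpy as np

    data = np.random.randn(32, 100000)              # channels x frames (scaled)
    data = data - data.mean(axis=1, keepdims=True)
    cov = data @ data.T / data.shape[1]             # channel covariance
    U, S, Ut = np.linalg.svd(cov)
    W = U @ np.diag(1.0 / np.sqrt(S + 1e-10)) @ Ut  # whitening matrix

    chunk = data[:, :30000]
    chunk = chunk - chunk.mean(axis=1, keepdims=True)
    whitened = W @ chunk                            # ~identity covariance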
