Skip to content

Commit

Permalink
Pep8 for baserawio I
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 authored Nov 25, 2023
1 parent e37823e commit bbccb4a
Showing 1 changed file with 34 additions and 34 deletions.
68 changes: 34 additions & 34 deletions neo/rawio/baserawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class BaseRawIO:

rawmode = None # one key from possible_raw_modes

def __init__(self, use_cache: bool=False, cache_path='same_as_resource', **kargs):
def __init__(self, use_cache: bool = False, cache_path = 'same_as_resource', **kargs):
"""
:TODO: Why multi-file would have a single filename is confusing here - shouldn't
the name of this argument be filenames_list or filenames_base or similar?
Expand Down Expand Up @@ -474,7 +474,7 @@ def channel_id_to_index(self, stream_index: int, channel_ids: list[str]):
channel_indexes = np.array([chan_ids.index(chan_id) for chan_id in channel_ids])
return channel_indexes

def _get_channel_indexes(self, stream_index:int, channel_indexes: list[int]|None, channel_names: list[str]|None, channel_ids: list[str]|None):
def _get_channel_indexes(self, stream_index: int, channel_indexes: list[int] | None, channel_names: list[str] | None, channel_ids: list[str] | None):
"""
Select channel_indexes for a stream based on channel_indexes/channel_names/channel_ids
depending which is not None.
Expand All @@ -485,7 +485,7 @@ def _get_channel_indexes(self, stream_index:int, channel_indexes: list[int]|None
channel_indexes = self.channel_id_to_index(stream_index, channel_ids)
return channel_indexes

def _get_stream_index_from_arg(self, stream_index_arg):
def _get_stream_index_from_arg(self, stream_index_arg: int | None):
if stream_index_arg is None:
assert self.header['signal_streams'].size == 1
stream_index = 0
Expand All @@ -494,7 +494,7 @@ def _get_stream_index_from_arg(self, stream_index_arg):
stream_index = stream_index_arg
return stream_index

def get_signal_size(self, block_index: int, seg_index: int, stream_index: int|None=None):
def get_signal_size(self, block_index: int, seg_index: int, stream_index: int | None = None):
"""
Retrieve the length of a single section of the channels in a stream.
:param block_index:
Expand All @@ -505,7 +505,7 @@ def get_signal_size(self, block_index: int, seg_index: int, stream_index: int|No
stream_index = self._get_stream_index_from_arg(stream_index)
return self._get_signal_size(block_index, seg_index, stream_index)

def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int|None=None):
def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int | None = None):
"""
Retrieve the t_start of a single section of the channels in a stream.
:param block_index:
Expand All @@ -516,7 +516,7 @@ def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int
stream_index = self._get_stream_index_from_arg(stream_index)
return self._get_signal_t_start(block_index, seg_index, stream_index)

def get_signal_sampling_rate(self, stream_index: int| None=None):
def get_signal_sampling_rate(self, stream_index: int | None = None):
"""
Retrieve sampling rate for a stream and all channels in that stream.
:param stream_index:
Expand All @@ -529,9 +529,9 @@ def get_signal_sampling_rate(self, stream_index: int| None=None):
sr = signal_channels[0]['sampling_rate']
return float(sr)

def get_analogsignal_chunk(self, block_index: int =0, seg_index: int =0, i_start: int|None =None, i_stop: int|None =None,
stream_index: int|None =None, channel_indexes: list[int]|None =None, channel_names: list[str]|None =None,
channel_ids: list[str]|None=None, prefer_slice:bool=False):
def get_analogsignal_chunk(self, block_index: int = 0, seg_index: int = 0, i_start: int | None = None, i_stop: int | None = None,
stream_index: int | None = None, channel_indexes: list[int] | None = None, channel_names: list[str] | None = None,
channel_ids: list[str] | None = None, prefer_slice: bool = False):
"""
Return a chunk of raw signal as a Numpy array. columns correspond to samples from a
section of a single channel of recording. The channels are chosen either by channel_names,
Expand Down Expand Up @@ -588,8 +588,8 @@ def get_analogsignal_chunk(self, block_index: int =0, seg_index: int =0, i_start

return raw_chunk

def rescale_signal_raw_to_float(self, raw_signal: np.ndarray, dtype: np.dtype='float32', stream_index: int|None=None,
channel_indexes: list[int]|None=None, channel_names: list[str]|None=None, channel_ids: list[str]|None=None):
def rescale_signal_raw_to_float(self, raw_signal: np.ndarray, dtype: np.dtype = 'float32', stream_index: int | None = None,
channel_indexes: list[int] | None = None, channel_names: list[str] | None = None, channel_ids: list[str] | None = None):
"""
Rescale a chunk of raw signals which are provided as a Numpy array. These are normally
returned by a call to get_analogsignal_chunk. The channels are specified either by
Expand Down Expand Up @@ -628,11 +628,11 @@ def rescale_signal_raw_to_float(self, raw_signal: np.ndarray, dtype: np.dtype='f
return float_signal

# spiketrain and unit zone
def spike_count(self, block_index: int=0, seg_index: int=0, spike_channel_index:int =0):
def spike_count(self, block_index: int = 0, seg_index: int = 0, spike_channel_index: int = 0):
return self._spike_count(block_index, seg_index, spike_channel_index)

def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0,
t_start=None, t_stop=None):
def get_spike_timestamps(self, block_index:int = 0, seg_index: int = 0, spike_channel_index: int = 0,
t_start: float | None = None, t_stop: float | None = None):
"""
The timestamp datatype is as close to the format itself. Sometimes float/int32/int64.
Sometimes it is the index on the signal but not always.
Expand All @@ -644,21 +644,21 @@ def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0
spike_channel_index, t_start, t_stop)
return timestamp

def rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype='float64'):
def rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype = 'float64'):
"""
Rescale spike timestamps to seconds.
"""
return self._rescale_spike_timestamp(spike_timestamps, dtype)

# spiketrain waveform zone
def get_spike_raw_waveforms(self, block_index: int=0, seg_index: int=0, spike_channel_index: int=0,
t_start: float|None =None, t_stop: float|None=None):
def get_spike_raw_waveforms(self, block_index: int = 0, seg_index: int = 0, spike_channel_index: int = 0,
t_start: float | None = None, t_stop: float | None = None):
wf = self._get_spike_raw_waveforms(block_index, seg_index,
spike_channel_index, t_start, t_stop)
return wf

def rescale_waveforms_to_float(self, raw_waveforms: np.ndarray, dtype: np.dtype ='float32',
spike_channel_index: int =0):
def rescale_waveforms_to_float(self, raw_waveforms: np.ndarray, dtype: np.dtype = 'float32',
spike_channel_index: int = 0):
wf_gain = self.header['spike_channels']['wf_gain'][spike_channel_index]
wf_offset = self.header['spike_channels']['wf_offset'][spike_channel_index]

Expand All @@ -672,11 +672,11 @@ def rescale_waveforms_to_float(self, raw_waveforms: np.ndarray, dtype: np.dtype
return float_waveforms

# event and epoch zone
def event_count(self, block_index:int =0, seg_index: int =0, event_channel_index: int =0):
def event_count(self, block_index: int = 0, seg_index: int = 0, event_channel_index: int = 0):
return self._event_count(block_index, seg_index, event_channel_index)

def get_event_timestamps(self, block_index: int =0, seg_index: int =0, event_channel_index: int =0,
t_start: float|None =None, t_stop: float|None=None):
def get_event_timestamps(self, block_index: int = 0, seg_index: int = 0, event_channel_index: int = 0,
t_start: float | None = None, t_stop: float | None = None):
"""
The timestamp datatype is as close to the format itself. Sometimes float/int32/int64.
Sometimes it is the index on the signal but not always.
Expand All @@ -694,21 +694,21 @@ def get_event_timestamps(self, block_index: int =0, seg_index: int =0, event_cha
block_index, seg_index, event_channel_index, t_start, t_stop)
return timestamp, durations, labels

def rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype='float64',
event_channel_index:int =0):
def rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype= 'float64',
event_channel_index:int = 0):
"""
Rescale event timestamps to seconds.
"""
return self._rescale_event_timestamp(event_timestamps, dtype, event_channel_index)

def rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype ='float64',
event_channel_index:int =0):
def rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype = 'float64',
event_channel_index: int = 0):
"""
Rescale epoch raw duration to seconds.
"""
return self._rescale_epoch_duration(raw_duration, dtype, event_channel_index)

def setup_cache(self, cache_path: 'home'|'same_as_resource', **init_kargs):
def setup_cache(self, cache_path: 'home' | 'same_as_resource', **init_kargs):
try:
import joblib
except ImportError:
Expand Down Expand Up @@ -780,7 +780,7 @@ def _source_name(self):
def _segment_t_start(self, block_index: int, seg_index: int):
raise (NotImplementedError)

def _segment_t_stop(self, block_index:int , seg_index: int):
def _segment_t_stop(self, block_index: int , seg_index: int):
raise (NotImplementedError)

###
Expand All @@ -801,8 +801,8 @@ def _get_signal_t_start(self, block_index: int, seg_index: int, stream_index: in
"""
raise (NotImplementedError)

def _get_analogsignal_chunk(self, block_index: int, seg_index: int, i_start: int|None, i_stop: int|None,
stream_index: int, channel_indexes: list[int]|None):
def _get_analogsignal_chunk(self, block_index: int, seg_index: int, i_start: int | None, i_stop: int | None,
stream_index: int, channel_indexes: list[int] | None):
"""
Return the samples from a set of AnalogSignals indexed
by stream_index and channel_indexes (local index inner stream).
Expand All @@ -820,7 +820,7 @@ def _spike_count(self, block_index: int, seg_index: int, spike_channel_index: in
raise (NotImplementedError)

def _get_spike_timestamps(self, block_index: int, seg_index: int,
spike_channel_index: int, t_start: float|None, t_stop: float|None):
spike_channel_index: int, t_start: float | None, t_stop: float | None):
raise (NotImplementedError)

def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype):
Expand All @@ -829,15 +829,15 @@ def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype
###
# spike waveforms zone
def _get_spike_raw_waveforms(self, block_index: int, seg_index: int,
spike_channel_index: int, t_start: float|None, t_stop: float|None):
spike_channel_index: int, t_start: float | None, t_stop: float | None):
raise (NotImplementedError)

###
# event and epoch zone
def _event_count(self, block_index: int, seg_index: int, event_channel_index: int):
raise (NotImplementedError)

def _get_event_timestamps(self, block_index: int, seg_index: int, event_channel_index: int, t_start: float|None, t_stop: float|None):
def _get_event_timestamps(self, block_index: int, seg_index: int, event_channel_index: int, t_start: float | None, t_stop: float | None):
raise (NotImplementedError)

def _rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype):
Expand All @@ -847,7 +847,7 @@ def _rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype):
raise (NotImplementedError)


def pprint_vector(vector, lim: int=8):
def pprint_vector(vector, lim: int = 8):
vector = np.asarray(vector)
assert vector.ndim == 1
if len(vector) > lim:
Expand Down

0 comments on commit bbccb4a

Please sign in to comment.