From bbccb4aec1edf1091a39dd82a9a1c1bdd2182617 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Sat, 25 Nov 2023 16:01:42 -0500 Subject: [PATCH] Pep8 for baserawio I --- neo/rawio/baserawio.py | 68 +++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/neo/rawio/baserawio.py b/neo/rawio/baserawio.py index bcd7a8779..1124e795b 100644 --- a/neo/rawio/baserawio.py +++ b/neo/rawio/baserawio.py @@ -134,7 +134,7 @@ class BaseRawIO: rawmode = None # one key from possible_raw_modes - def __init__(self, use_cache: bool=False, cache_path='same_as_resource', **kargs): + def __init__(self, use_cache: bool = False, cache_path = 'same_as_resource', **kargs): """ :TODO: Why multi-file would have a single filename is confusing here - shouldn't the name of this argument be filenames_list or filenames_base or similar? @@ -474,7 +474,7 @@ def channel_id_to_index(self, stream_index: int, channel_ids: list[str]): channel_indexes = np.array([chan_ids.index(chan_id) for chan_id in channel_ids]) return channel_indexes - def _get_channel_indexes(self, stream_index:int, channel_indexes: list[int]|None, channel_names: list[str]|None, channel_ids: list[str]|None): + def _get_channel_indexes(self, stream_index: int, channel_indexes: list[int] | None, channel_names: list[str] | None, channel_ids: list[str] | None): """ Select channel_indexes for a stream based on channel_indexes/channel_names/channel_ids depending which is not None. @@ -485,7 +485,7 @@ def _get_channel_indexes(self, stream_index:int, channel_indexes: list[int]|None channel_indexes = self.channel_id_to_index(stream_index, channel_ids) return channel_indexes - def _get_stream_index_from_arg(self, stream_index_arg): + def _get_stream_index_from_arg(self, stream_index_arg: int | None): if stream_index_arg is None: assert self.header['signal_streams'].size == 1 stream_index = 0 @@ -494,7 +494,7 @@ def _get_stream_index_from_arg(self, stream_index_arg): stream_index = stream_index_arg return stream_index - def get_signal_size(self, block_index: int, seg_index: int, stream_index: int|None=None): + def get_signal_size(self, block_index: int, seg_index: int, stream_index: int | None = None): """ Retrieve the length of a single section of the channels in a stream. :param block_index: @@ -505,7 +505,7 @@ def get_signal_size(self, block_index: int, seg_index: int, stream_index: int|No stream_index = self._get_stream_index_from_arg(stream_index) return self._get_signal_size(block_index, seg_index, stream_index) - def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int|None=None): + def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int | None = None): """ Retrieve the t_start of a single section of the channels in a stream. :param block_index: @@ -516,7 +516,7 @@ def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int stream_index = self._get_stream_index_from_arg(stream_index) return self._get_signal_t_start(block_index, seg_index, stream_index) - def get_signal_sampling_rate(self, stream_index: int| None=None): + def get_signal_sampling_rate(self, stream_index: int | None = None): """ Retrieve sampling rate for a stream and all channels in that stream. :param stream_index: @@ -529,9 +529,9 @@ def get_signal_sampling_rate(self, stream_index: int| None=None): sr = signal_channels[0]['sampling_rate'] return float(sr) - def get_analogsignal_chunk(self, block_index: int =0, seg_index: int =0, i_start: int|None =None, i_stop: int|None =None, - stream_index: int|None =None, channel_indexes: list[int]|None =None, channel_names: list[str]|None =None, - channel_ids: list[str]|None=None, prefer_slice:bool=False): + def get_analogsignal_chunk(self, block_index: int = 0, seg_index: int = 0, i_start: int | None = None, i_stop: int | None = None, + stream_index: int | None = None, channel_indexes: list[int] | None = None, channel_names: list[str] | None = None, + channel_ids: list[str] | None = None, prefer_slice: bool = False): """ Return a chunk of raw signal as a Numpy array. columns correspond to samples from a section of a single channel of recording. The channels are chosen either by channel_names, @@ -588,8 +588,8 @@ def get_analogsignal_chunk(self, block_index: int =0, seg_index: int =0, i_start return raw_chunk - def rescale_signal_raw_to_float(self, raw_signal: np.ndarray, dtype: np.dtype='float32', stream_index: int|None=None, - channel_indexes: list[int]|None=None, channel_names: list[str]|None=None, channel_ids: list[str]|None=None): + def rescale_signal_raw_to_float(self, raw_signal: np.ndarray, dtype: np.dtype = 'float32', stream_index: int | None = None, + channel_indexes: list[int] | None = None, channel_names: list[str] | None = None, channel_ids: list[str] | None = None): """ Rescale a chunk of raw signals which are provided as a Numpy array. These are normally returned by a call to get_analogsignal_chunk. The channels are specified either by @@ -628,11 +628,11 @@ def rescale_signal_raw_to_float(self, raw_signal: np.ndarray, dtype: np.dtype='f return float_signal # spiketrain and unit zone - def spike_count(self, block_index: int=0, seg_index: int=0, spike_channel_index:int =0): + def spike_count(self, block_index: int = 0, seg_index: int = 0, spike_channel_index: int = 0): return self._spike_count(block_index, seg_index, spike_channel_index) - def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0, - t_start=None, t_stop=None): + def get_spike_timestamps(self, block_index:int = 0, seg_index: int = 0, spike_channel_index: int = 0, + t_start: float | None = None, t_stop: float | None = None): """ The timestamp datatype is as close to the format itself. Sometimes float/int32/int64. Sometimes it is the index on the signal but not always. @@ -644,21 +644,21 @@ def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0 spike_channel_index, t_start, t_stop) return timestamp - def rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype='float64'): + def rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype = 'float64'): """ Rescale spike timestamps to seconds. """ return self._rescale_spike_timestamp(spike_timestamps, dtype) # spiketrain waveform zone - def get_spike_raw_waveforms(self, block_index: int=0, seg_index: int=0, spike_channel_index: int=0, - t_start: float|None =None, t_stop: float|None=None): + def get_spike_raw_waveforms(self, block_index: int = 0, seg_index: int = 0, spike_channel_index: int = 0, + t_start: float | None = None, t_stop: float | None = None): wf = self._get_spike_raw_waveforms(block_index, seg_index, spike_channel_index, t_start, t_stop) return wf - def rescale_waveforms_to_float(self, raw_waveforms: np.ndarray, dtype: np.dtype ='float32', - spike_channel_index: int =0): + def rescale_waveforms_to_float(self, raw_waveforms: np.ndarray, dtype: np.dtype = 'float32', + spike_channel_index: int = 0): wf_gain = self.header['spike_channels']['wf_gain'][spike_channel_index] wf_offset = self.header['spike_channels']['wf_offset'][spike_channel_index] @@ -672,11 +672,11 @@ def rescale_waveforms_to_float(self, raw_waveforms: np.ndarray, dtype: np.dtype return float_waveforms # event and epoch zone - def event_count(self, block_index:int =0, seg_index: int =0, event_channel_index: int =0): + def event_count(self, block_index: int = 0, seg_index: int = 0, event_channel_index: int = 0): return self._event_count(block_index, seg_index, event_channel_index) - def get_event_timestamps(self, block_index: int =0, seg_index: int =0, event_channel_index: int =0, - t_start: float|None =None, t_stop: float|None=None): + def get_event_timestamps(self, block_index: int = 0, seg_index: int = 0, event_channel_index: int = 0, + t_start: float | None = None, t_stop: float | None = None): """ The timestamp datatype is as close to the format itself. Sometimes float/int32/int64. Sometimes it is the index on the signal but not always. @@ -694,21 +694,21 @@ def get_event_timestamps(self, block_index: int =0, seg_index: int =0, event_cha block_index, seg_index, event_channel_index, t_start, t_stop) return timestamp, durations, labels - def rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype='float64', - event_channel_index:int =0): + def rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype= 'float64', + event_channel_index:int = 0): """ Rescale event timestamps to seconds. """ return self._rescale_event_timestamp(event_timestamps, dtype, event_channel_index) - def rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype ='float64', - event_channel_index:int =0): + def rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype = 'float64', + event_channel_index: int = 0): """ Rescale epoch raw duration to seconds. """ return self._rescale_epoch_duration(raw_duration, dtype, event_channel_index) - def setup_cache(self, cache_path: 'home'|'same_as_resource', **init_kargs): + def setup_cache(self, cache_path: 'home' | 'same_as_resource', **init_kargs): try: import joblib except ImportError: @@ -780,7 +780,7 @@ def _source_name(self): def _segment_t_start(self, block_index: int, seg_index: int): raise (NotImplementedError) - def _segment_t_stop(self, block_index:int , seg_index: int): + def _segment_t_stop(self, block_index: int , seg_index: int): raise (NotImplementedError) ### @@ -801,8 +801,8 @@ def _get_signal_t_start(self, block_index: int, seg_index: int, stream_index: in """ raise (NotImplementedError) - def _get_analogsignal_chunk(self, block_index: int, seg_index: int, i_start: int|None, i_stop: int|None, - stream_index: int, channel_indexes: list[int]|None): + def _get_analogsignal_chunk(self, block_index: int, seg_index: int, i_start: int | None, i_stop: int | None, + stream_index: int, channel_indexes: list[int] | None): """ Return the samples from a set of AnalogSignals indexed by stream_index and channel_indexes (local index inner stream). @@ -820,7 +820,7 @@ def _spike_count(self, block_index: int, seg_index: int, spike_channel_index: in raise (NotImplementedError) def _get_spike_timestamps(self, block_index: int, seg_index: int, - spike_channel_index: int, t_start: float|None, t_stop: float|None): + spike_channel_index: int, t_start: float | None, t_stop: float | None): raise (NotImplementedError) def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype): @@ -829,7 +829,7 @@ def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype ### # spike waveforms zone def _get_spike_raw_waveforms(self, block_index: int, seg_index: int, - spike_channel_index: int, t_start: float|None, t_stop: float|None): + spike_channel_index: int, t_start: float | None, t_stop: float | None): raise (NotImplementedError) ### @@ -837,7 +837,7 @@ def _get_spike_raw_waveforms(self, block_index: int, seg_index: int, def _event_count(self, block_index: int, seg_index: int, event_channel_index: int): raise (NotImplementedError) - def _get_event_timestamps(self, block_index: int, seg_index: int, event_channel_index: int, t_start: float|None, t_stop: float|None): + def _get_event_timestamps(self, block_index: int, seg_index: int, event_channel_index: int, t_start: float | None, t_stop: float | None): raise (NotImplementedError) def _rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype): @@ -847,7 +847,7 @@ def _rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype): raise (NotImplementedError) -def pprint_vector(vector, lim: int=8): +def pprint_vector(vector, lim: int = 8): vector = np.asarray(vector) assert vector.ndim == 1 if len(vector) > lim: