Skip to content

Commit

Permalink
PEP 8 compliance
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Nov 26, 2023
1 parent 29a6380 commit 344d36c
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 32 deletions.
4 changes: 2 additions & 2 deletions neo/io/baseio.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class BaseIO:

mode = 'file' # or 'fake' or 'dir' or 'database'

def __init__(self, filename: str | Path =None, **kargs):
def __init__(self, filename: str | Path = None, **kargs):
self.filename = str(filename)
# create a logger for the IO class
fullname = self.__class__.__module__ + '.' + self.__class__.__name__
Expand All @@ -113,7 +113,7 @@ def __init__(self, filename: str | Path =None, **kargs):
corelogger.addHandler(logging_handler)

######## General read/write methods #######################
def read(self, lazy:bool=False, **kargs):
def read(self, lazy: bool = False, **kargs):
"""
Return all data from the file as a list of Blocks
"""
Expand Down
93 changes: 71 additions & 22 deletions neo/rawio/baserawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class BaseRawIO:

rawmode = None # one key from possible_raw_modes

def __init__(self, use_cache: bool = False, cache_path = 'same_as_resource', **kargs):
def __init__(self, use_cache: bool = False, cache_path: str = 'same_as_resource', **kargs):
"""
:TODO: Why multi-file would have a single filename is confusing here - shouldn't
the name of this argument be filenames_list or filenames_base or similar?
Expand Down Expand Up @@ -474,7 +474,11 @@ def channel_id_to_index(self, stream_index: int, channel_ids: list[str]):
channel_indexes = np.array([chan_ids.index(chan_id) for chan_id in channel_ids])
return channel_indexes

def _get_channel_indexes(self, stream_index: int, channel_indexes: list[int] | None, channel_names: list[str] | None, channel_ids: list[str] | None):
def _get_channel_indexes(self,
stream_index: int,
channel_indexes: list[int] | None,
channel_names: list[str] | None,
channel_ids: list[str] | None):
"""
Select channel_indexes for a stream based on channel_indexes/channel_names/channel_ids
depending which is not None.
Expand Down Expand Up @@ -505,7 +509,10 @@ def get_signal_size(self, block_index: int, seg_index: int, stream_index: int |
stream_index = self._get_stream_index_from_arg(stream_index)
return self._get_signal_size(block_index, seg_index, stream_index)

def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int | None = None):
def get_signal_t_start(self,
block_index: int,
seg_index: int,
stream_index: int | None = None):
"""
Retrieve the t_start of a single section of the channels in a stream.
:param block_index:
Expand All @@ -529,9 +536,16 @@ def get_signal_sampling_rate(self, stream_index: int | None = None):
sr = signal_channels[0]['sampling_rate']
return float(sr)

def get_analogsignal_chunk(self, block_index: int = 0, seg_index: int = 0, i_start: int | None = None, i_stop: int | None = None,
stream_index: int | None = None, channel_indexes: list[int] | None = None, channel_names: list[str] | None = None,
channel_ids: list[str] | None = None, prefer_slice: bool = False):
def get_analogsignal_chunk(self,
block_index: int = 0,
seg_index: int = 0,
i_start: int | None = None,
i_stop: int | None = None,
stream_index: int | None = None,
channel_indexes: list[int] | None = None,
channel_names: list[str] | None = None,
channel_ids: list[str] | None = None,
prefer_slice: bool = False):
"""
Return a chunk of raw signal as a Numpy array. columns correspond to samples from a
section of a single channel of recording. The channels are chosen either by channel_names,
Expand Down Expand Up @@ -588,8 +602,13 @@ def get_analogsignal_chunk(self, block_index: int = 0, seg_index: int = 0, i_sta

return raw_chunk

def rescale_signal_raw_to_float(self, raw_signal: np.ndarray, dtype: np.dtype = 'float32', stream_index: int | None = None,
channel_indexes: list[int] | None = None, channel_names: list[str] | None = None, channel_ids: list[str] | None = None):
def rescale_signal_raw_to_float(self,
raw_signal: np.ndarray,
dtype: np.dtype = 'float32',
stream_index: int | None = None,
channel_indexes: list[int] | None = None,
channel_names: list[str] | None = None,
channel_ids: list[str] | None = None):
"""
Rescale a chunk of raw signals which are provided as a Numpy array. These are normally
returned by a call to get_analogsignal_chunk. The channels are specified either by
Expand Down Expand Up @@ -631,8 +650,12 @@ def rescale_signal_raw_to_float(self, raw_signal: np.ndarray, dtype: np.dtype =
def spike_count(self, block_index: int = 0, seg_index: int = 0, spike_channel_index: int = 0):
return self._spike_count(block_index, seg_index, spike_channel_index)

def get_spike_timestamps(self, block_index:int = 0, seg_index: int = 0, spike_channel_index: int = 0,
t_start: float | None = None, t_stop: float | None = None):
def get_spike_timestamps(self,
block_index: int = 0,
seg_index: int = 0,
spike_channel_index: int = 0,
t_start: float | None = None,
t_stop: float | None = None):
"""
The timestamp datatype is as close to the format itself. Sometimes float/int32/int64.
Sometimes it is the index on the signal but not always.
Expand All @@ -651,8 +674,12 @@ def rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype
return self._rescale_spike_timestamp(spike_timestamps, dtype)

# spiketrain waveform zone
def get_spike_raw_waveforms(self, block_index: int = 0, seg_index: int = 0, spike_channel_index: int = 0,
t_start: float | None = None, t_stop: float | None = None):
def get_spike_raw_waveforms(self,
block_index: int = 0,
seg_index: int = 0,
spike_channel_index: int = 0,
t_start: float | None = None,
t_stop: float | None = None):
wf = self._get_spike_raw_waveforms(block_index, seg_index,
spike_channel_index, t_start, t_stop)
return wf
Expand All @@ -675,8 +702,12 @@ def rescale_waveforms_to_float(self, raw_waveforms: np.ndarray, dtype: np.dtype
def event_count(self, block_index: int = 0, seg_index: int = 0, event_channel_index: int = 0):
return self._event_count(block_index, seg_index, event_channel_index)

def get_event_timestamps(self, block_index: int = 0, seg_index: int = 0, event_channel_index: int = 0,
t_start: float | None = None, t_stop: float | None = None):
def get_event_timestamps(self,
block_index: int = 0,
seg_index: int = 0,
event_channel_index: int = 0,
t_start: float | None = None,
t_stop: float | None = None):
"""
The timestamp datatype is as close to the format itself. Sometimes float/int32/int64.
Sometimes it is the index on the signal but not always.
Expand Down Expand Up @@ -780,7 +811,7 @@ def _source_name(self):
def _segment_t_start(self, block_index: int, seg_index: int):
raise (NotImplementedError)

def _segment_t_stop(self, block_index: int , seg_index: int):
def _segment_t_stop(self, block_index: int, seg_index: int):
raise (NotImplementedError)

###
Expand All @@ -801,8 +832,13 @@ def _get_signal_t_start(self, block_index: int, seg_index: int, stream_index: in
"""
raise (NotImplementedError)

def _get_analogsignal_chunk(self, block_index: int, seg_index: int, i_start: int | None, i_stop: int | None,
stream_index: int, channel_indexes: list[int] | None):
def _get_analogsignal_chunk(self,
block_index: int,
seg_index: int,
i_start: int | None,
i_stop: int | None,
stream_index: int,
channel_indexes: list[int] | None):
"""
Return the samples from a set of AnalogSignals indexed
by stream_index and channel_indexes (local index inner stream).
Expand All @@ -819,25 +855,38 @@ def _get_analogsignal_chunk(self, block_index: int, seg_index: int, i_start: int
def _spike_count(self, block_index: int, seg_index: int, spike_channel_index: int):
raise (NotImplementedError)

def _get_spike_timestamps(self, block_index: int, seg_index: int,
spike_channel_index: int, t_start: float | None, t_stop: float | None):
def _get_spike_timestamps(self,
block_index: int,
seg_index: int,
spike_channel_index: int,
t_start: float | None,
t_stop: float | None):
raise (NotImplementedError)

def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype):
raise (NotImplementedError)

###
# spike waveforms zone
def _get_spike_raw_waveforms(self, block_index: int, seg_index: int,
spike_channel_index: int, t_start: float | None, t_stop: float | None):
def _get_spike_raw_waveforms(self,
block_index: int,
seg_index: int,
spike_channel_index: int,
t_start: float | None,
t_stop: float | None):
raise (NotImplementedError)

###
# event and epoch zone
def _event_count(self, block_index: int, seg_index: int, event_channel_index: int):
raise (NotImplementedError)

def _get_event_timestamps(self, block_index: int, seg_index: int, event_channel_index: int, t_start: float | None, t_stop: float | None):
def _get_event_timestamps(self,
block_index: int,
seg_index: int,
event_channel_index: int,
t_start: float | None,
t_stop: float | None):
raise (NotImplementedError)

def _rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype):
Expand Down
35 changes: 27 additions & 8 deletions neo/rawio/examplerawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class ExampleRawIO(BaseRawIO):
extensions = ['fake']
rawmode = 'one-file'

def __init__(self, filename: str|Path =''):
def __init__(self, filename: str | Path = ''):
BaseRawIO.__init__(self)
# note that this filename is ued in self._source_name
self.filename = filename
Expand Down Expand Up @@ -267,8 +267,13 @@ def _get_signal_t_start(self, block_index: int, seg_index: int, stream_index: in
# this is not always the case
return self._segment_t_start(block_index, seg_index)

def _get_analogsignal_chunk(self, block_index:int, seg_index:int, i_start: int | None, i_stop: int | None,
stream_index: int, channel_indexes: np.ndarray|list|slice|None):
def _get_analogsignal_chunk(self,
block_index: int,
seg_index: int,
i_start: int | None,
i_stop: int | None,
stream_index: int,
channel_indexes: np.ndarray | list | slice | None):
# this must return a signal chunk in a signal stream
# limited with i_start/i_stop (can be None)
# channel_indexes can be None (=all channel in the stream) or a list or numpy.array
Expand Down Expand Up @@ -308,14 +313,19 @@ def _get_analogsignal_chunk(self, block_index:int, seg_index:int, i_start: int |
raw_signals = np.zeros((i_stop - i_start, nb_chan), dtype='int16')
return raw_signals

def _spike_count(self, block_index:int, seg_index:int, spike_channel_index:int):
def _spike_count(self, block_index: int, seg_index: int, spike_channel_index: int):
# Must return the nb of spikes for given (block_index, seg_index, spike_channel_index)
# we are lucky: our units have all the same nb of spikes!!
# it is not always the case
nb_spikes = 20
return nb_spikes

def _get_spike_timestamps(self, block_index: int, seg_index: int, spike_channel_index: int, t_start: float|None, t_stop: float|None):
def _get_spike_timestamps(self,
block_index: int,
seg_index: int,
spike_channel_index: int,
t_start: float | None,
t_stop: float | None):
# In our IO, timestamp are internally coded 'int64' and they
# represent the index of the signals 10kHz
# we are lucky: spikes have the same discharge in all segments!!
Expand Down Expand Up @@ -343,8 +353,12 @@ def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype
spike_times /= 10000. # because 10kHz
return spike_times

def _get_spike_raw_waveforms(self, block_index: int, seg_index: int, spike_channel_index: int,
t_start: float|None, t_stop: float|None):
def _get_spike_raw_waveforms(self,
block_index: int,
seg_index: int,
spike_channel_index: int,
t_start: float | None,
t_stop: float | None):
# this must return a 3D numpy array (nb_spike, nb_channel, nb_sample)
# in the original dtype
# this must be as fast as possible.
Expand Down Expand Up @@ -380,7 +394,12 @@ def _event_count(self, block_index: int, seg_index: int, event_channel_index: in
# epoch channel
return 10

def _get_event_timestamps(self, block_index: int, seg_index: int, event_channel_index: int, t_start: float|None, t_stop: float|None):
def _get_event_timestamps(self,
block_index: int,
seg_index: int,
event_channel_index: int,
t_start: float | None,
t_stop: float | None):
# the main difference between spike channel and event channel
# is that for event channels we have 3D numpy array (timestamp, durations, labels) where
# durations must be None for 'event'
Expand Down

0 comments on commit 344d36c

Please sign in to comment.