diff --git a/neo/io/baseio.py b/neo/io/baseio.py index cc2b35c26..2ce5b07bf 100644 --- a/neo/io/baseio.py +++ b/neo/io/baseio.py @@ -98,7 +98,7 @@ class BaseIO: mode = 'file' # or 'fake' or 'dir' or 'database' - def __init__(self, filename: str | Path =None, **kargs): + def __init__(self, filename: str | Path = None, **kargs): self.filename = str(filename) # create a logger for the IO class fullname = self.__class__.__module__ + '.' + self.__class__.__name__ @@ -113,7 +113,7 @@ def __init__(self, filename: str | Path =None, **kargs): corelogger.addHandler(logging_handler) ######## General read/write methods ####################### - def read(self, lazy:bool=False, **kargs): + def read(self, lazy: bool = False, **kargs): """ Return all data from the file as a list of Blocks """ diff --git a/neo/rawio/baserawio.py b/neo/rawio/baserawio.py index 1124e795b..e26c6ca6b 100644 --- a/neo/rawio/baserawio.py +++ b/neo/rawio/baserawio.py @@ -134,7 +134,7 @@ class BaseRawIO: rawmode = None # one key from possible_raw_modes - def __init__(self, use_cache: bool = False, cache_path = 'same_as_resource', **kargs): + def __init__(self, use_cache: bool = False, cache_path: str = 'same_as_resource', **kargs): """ :TODO: Why multi-file would have a single filename is confusing here - shouldn't the name of this argument be filenames_list or filenames_base or similar? @@ -474,7 +474,11 @@ def channel_id_to_index(self, stream_index: int, channel_ids: list[str]): channel_indexes = np.array([chan_ids.index(chan_id) for chan_id in channel_ids]) return channel_indexes - def _get_channel_indexes(self, stream_index: int, channel_indexes: list[int] | None, channel_names: list[str] | None, channel_ids: list[str] | None): + def _get_channel_indexes(self, + stream_index: int, + channel_indexes: list[int] | None, + channel_names: list[str] | None, + channel_ids: list[str] | None): """ Select channel_indexes for a stream based on channel_indexes/channel_names/channel_ids depending which is not None. @@ -505,7 +509,10 @@ def get_signal_size(self, block_index: int, seg_index: int, stream_index: int | stream_index = self._get_stream_index_from_arg(stream_index) return self._get_signal_size(block_index, seg_index, stream_index) - def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int | None = None): + def get_signal_t_start(self, + block_index: int, + seg_index: int, + stream_index: int | None = None): """ Retrieve the t_start of a single section of the channels in a stream. :param block_index: @@ -529,9 +536,16 @@ def get_signal_sampling_rate(self, stream_index: int | None = None): sr = signal_channels[0]['sampling_rate'] return float(sr) - def get_analogsignal_chunk(self, block_index: int = 0, seg_index: int = 0, i_start: int | None = None, i_stop: int | None = None, - stream_index: int | None = None, channel_indexes: list[int] | None = None, channel_names: list[str] | None = None, - channel_ids: list[str] | None = None, prefer_slice: bool = False): + def get_analogsignal_chunk(self, + block_index: int = 0, + seg_index: int = 0, + i_start: int | None = None, + i_stop: int | None = None, + stream_index: int | None = None, + channel_indexes: list[int] | None = None, + channel_names: list[str] | None = None, + channel_ids: list[str] | None = None, + prefer_slice: bool = False): """ Return a chunk of raw signal as a Numpy array. columns correspond to samples from a section of a single channel of recording. The channels are chosen either by channel_names, @@ -588,8 +602,13 @@ def get_analogsignal_chunk(self, block_index: int = 0, seg_index: int = 0, i_sta return raw_chunk - def rescale_signal_raw_to_float(self, raw_signal: np.ndarray, dtype: np.dtype = 'float32', stream_index: int | None = None, - channel_indexes: list[int] | None = None, channel_names: list[str] | None = None, channel_ids: list[str] | None = None): + def rescale_signal_raw_to_float(self, + raw_signal: np.ndarray, + dtype: np.dtype = 'float32', + stream_index: int | None = None, + channel_indexes: list[int] | None = None, + channel_names: list[str] | None = None, + channel_ids: list[str] | None = None): """ Rescale a chunk of raw signals which are provided as a Numpy array. These are normally returned by a call to get_analogsignal_chunk. The channels are specified either by @@ -631,8 +650,12 @@ def rescale_signal_raw_to_float(self, raw_signal: np.ndarray, dtype: np.dtype = def spike_count(self, block_index: int = 0, seg_index: int = 0, spike_channel_index: int = 0): return self._spike_count(block_index, seg_index, spike_channel_index) - def get_spike_timestamps(self, block_index:int = 0, seg_index: int = 0, spike_channel_index: int = 0, - t_start: float | None = None, t_stop: float | None = None): + def get_spike_timestamps(self, + block_index: int = 0, + seg_index: int = 0, + spike_channel_index: int = 0, + t_start: float | None = None, + t_stop: float | None = None): """ The timestamp datatype is as close to the format itself. Sometimes float/int32/int64. Sometimes it is the index on the signal but not always. @@ -651,8 +674,12 @@ def rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype return self._rescale_spike_timestamp(spike_timestamps, dtype) # spiketrain waveform zone - def get_spike_raw_waveforms(self, block_index: int = 0, seg_index: int = 0, spike_channel_index: int = 0, - t_start: float | None = None, t_stop: float | None = None): + def get_spike_raw_waveforms(self, + block_index: int = 0, + seg_index: int = 0, + spike_channel_index: int = 0, + t_start: float | None = None, + t_stop: float | None = None): wf = self._get_spike_raw_waveforms(block_index, seg_index, spike_channel_index, t_start, t_stop) return wf @@ -675,8 +702,12 @@ def rescale_waveforms_to_float(self, raw_waveforms: np.ndarray, dtype: np.dtype def event_count(self, block_index: int = 0, seg_index: int = 0, event_channel_index: int = 0): return self._event_count(block_index, seg_index, event_channel_index) - def get_event_timestamps(self, block_index: int = 0, seg_index: int = 0, event_channel_index: int = 0, - t_start: float | None = None, t_stop: float | None = None): + def get_event_timestamps(self, + block_index: int = 0, + seg_index: int = 0, + event_channel_index: int = 0, + t_start: float | None = None, + t_stop: float | None = None): """ The timestamp datatype is as close to the format itself. Sometimes float/int32/int64. Sometimes it is the index on the signal but not always. @@ -780,7 +811,7 @@ def _source_name(self): def _segment_t_start(self, block_index: int, seg_index: int): raise (NotImplementedError) - def _segment_t_stop(self, block_index: int , seg_index: int): + def _segment_t_stop(self, block_index: int, seg_index: int): raise (NotImplementedError) ### @@ -801,8 +832,13 @@ def _get_signal_t_start(self, block_index: int, seg_index: int, stream_index: in """ raise (NotImplementedError) - def _get_analogsignal_chunk(self, block_index: int, seg_index: int, i_start: int | None, i_stop: int | None, - stream_index: int, channel_indexes: list[int] | None): + def _get_analogsignal_chunk(self, + block_index: int, + seg_index: int, + i_start: int | None, + i_stop: int | None, + stream_index: int, + channel_indexes: list[int] | None): """ Return the samples from a set of AnalogSignals indexed by stream_index and channel_indexes (local index inner stream). @@ -819,8 +855,12 @@ def _get_analogsignal_chunk(self, block_index: int, seg_index: int, i_start: int def _spike_count(self, block_index: int, seg_index: int, spike_channel_index: int): raise (NotImplementedError) - def _get_spike_timestamps(self, block_index: int, seg_index: int, - spike_channel_index: int, t_start: float | None, t_stop: float | None): + def _get_spike_timestamps(self, + block_index: int, + seg_index: int, + spike_channel_index: int, + t_start: float | None, + t_stop: float | None): raise (NotImplementedError) def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype): @@ -828,8 +868,12 @@ def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype ### # spike waveforms zone - def _get_spike_raw_waveforms(self, block_index: int, seg_index: int, - spike_channel_index: int, t_start: float | None, t_stop: float | None): + def _get_spike_raw_waveforms(self, + block_index: int, + seg_index: int, + spike_channel_index: int, + t_start: float | None, + t_stop: float | None): raise (NotImplementedError) ### @@ -837,7 +881,12 @@ def _get_spike_raw_waveforms(self, block_index: int, seg_index: int, def _event_count(self, block_index: int, seg_index: int, event_channel_index: int): raise (NotImplementedError) - def _get_event_timestamps(self, block_index: int, seg_index: int, event_channel_index: int, t_start: float | None, t_stop: float | None): + def _get_event_timestamps(self, + block_index: int, + seg_index: int, + event_channel_index: int, + t_start: float | None, + t_stop: float | None): raise (NotImplementedError) def _rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype): diff --git a/neo/rawio/examplerawio.py b/neo/rawio/examplerawio.py index d7c288a55..25dabb27a 100644 --- a/neo/rawio/examplerawio.py +++ b/neo/rawio/examplerawio.py @@ -85,7 +85,7 @@ class ExampleRawIO(BaseRawIO): extensions = ['fake'] rawmode = 'one-file' - def __init__(self, filename: str|Path =''): + def __init__(self, filename: str | Path = ''): BaseRawIO.__init__(self) # note that this filename is ued in self._source_name self.filename = filename @@ -267,8 +267,13 @@ def _get_signal_t_start(self, block_index: int, seg_index: int, stream_index: in # this is not always the case return self._segment_t_start(block_index, seg_index) - def _get_analogsignal_chunk(self, block_index:int, seg_index:int, i_start: int | None, i_stop: int | None, - stream_index: int, channel_indexes: np.ndarray|list|slice|None): + def _get_analogsignal_chunk(self, + block_index: int, + seg_index: int, + i_start: int | None, + i_stop: int | None, + stream_index: int, + channel_indexes: np.ndarray | list | slice | None): # this must return a signal chunk in a signal stream # limited with i_start/i_stop (can be None) # channel_indexes can be None (=all channel in the stream) or a list or numpy.array @@ -308,14 +313,19 @@ def _get_analogsignal_chunk(self, block_index:int, seg_index:int, i_start: int | raw_signals = np.zeros((i_stop - i_start, nb_chan), dtype='int16') return raw_signals - def _spike_count(self, block_index:int, seg_index:int, spike_channel_index:int): + def _spike_count(self, block_index: int, seg_index: int, spike_channel_index: int): # Must return the nb of spikes for given (block_index, seg_index, spike_channel_index) # we are lucky: our units have all the same nb of spikes!! # it is not always the case nb_spikes = 20 return nb_spikes - def _get_spike_timestamps(self, block_index: int, seg_index: int, spike_channel_index: int, t_start: float|None, t_stop: float|None): + def _get_spike_timestamps(self, + block_index: int, + seg_index: int, + spike_channel_index: int, + t_start: float | None, + t_stop: float | None): # In our IO, timestamp are internally coded 'int64' and they # represent the index of the signals 10kHz # we are lucky: spikes have the same discharge in all segments!! @@ -343,8 +353,12 @@ def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype spike_times /= 10000. # because 10kHz return spike_times - def _get_spike_raw_waveforms(self, block_index: int, seg_index: int, spike_channel_index: int, - t_start: float|None, t_stop: float|None): + def _get_spike_raw_waveforms(self, + block_index: int, + seg_index: int, + spike_channel_index: int, + t_start: float | None, + t_stop: float | None): # this must return a 3D numpy array (nb_spike, nb_channel, nb_sample) # in the original dtype # this must be as fast as possible. @@ -380,7 +394,12 @@ def _event_count(self, block_index: int, seg_index: int, event_channel_index: in # epoch channel return 10 - def _get_event_timestamps(self, block_index: int, seg_index: int, event_channel_index: int, t_start: float|None, t_stop: float|None): + def _get_event_timestamps(self, + block_index: int, + seg_index: int, + event_channel_index: int, + t_start: float | None, + t_stop: float | None): # the main difference between spike channel and event channel # is that for event channels we have 3D numpy array (timestamp, durations, labels) where # durations must be None for 'event'