diff --git a/.github/workflows/publish-to-pypi-test.yml b/.github/workflows/publish-to-pypi-test.yml index 5a8140e59..71e7b9409 100644 --- a/.github/workflows/publish-to-pypi-test.yml +++ b/.github/workflows/publish-to-pypi-test.yml @@ -20,14 +20,19 @@ jobs: python -m pip install --upgrade pip pip install setuptools wheel twine build pip install . + - name: Get the tag version + id: get-version + run: | + echo ${GITHUB_REF#refs/tags/} + echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT - name: Test version/tag correspondence id: version-check run: | neo_version=$(python -c "import neo; print(neo.__version__)") + tag_version=${{ steps.get-version.outputs.TAG }} echo $neo_version - TAG=${{ github.event.release.tag_name }} - echo $TAG - if [[ $TAG == $neo_version ]]; then + echo $tag_version + if [[ $tag_version == $neo_version ]]; then echo "VERSION_TAG_MATCH=true" >> $GITHUB_OUTPUT echo "Version matches tag, proceeding with release to Test PyPI" else diff --git a/neo/core/baseneo.py b/neo/core/baseneo.py index 40e736567..060a8b6c5 100644 --- a/neo/core/baseneo.py +++ b/neo/core/baseneo.py @@ -154,7 +154,11 @@ def _container_name(class_name): referenced by `block.segments`. The attribute name `segments` is obtained by calling `_container_name_plural("Segment")`. 
""" - return _reference_name(class_name) + 's' + if "RegionOfInterest" in class_name: + # this is a hack, pending a more principled way to handle this + return "regionsofinterest" + else: + return _reference_name(class_name) + 's' class BaseNeo: diff --git a/neo/core/block.py b/neo/core/block.py index baa6f64f4..93a0f4b1c 100644 --- a/neo/core/block.py +++ b/neo/core/block.py @@ -12,7 +12,6 @@ from neo.core.container import Container, unique_objs from neo.core.group import Group from neo.core.objectlist import ObjectList -from neo.core.regionofinterest import RegionOfInterest from neo.core.segment import Segment @@ -91,7 +90,6 @@ def __init__(self, name=None, description=None, file_origin=None, self.index = index self._segments = ObjectList(Segment, parent=self) self._groups = ObjectList(Group, parent=self) - self._regionsofinterest = ObjectList(RegionOfInterest, parent=self) segments = property( fget=lambda self: self._get_object_list("_segments"), @@ -105,12 +103,6 @@ def __init__(self, name=None, description=None, file_origin=None, doc="list of Groups contained in this block" ) - regionsofinterest = property( - fget=lambda self: self._get_object_list("_regionsofinterest"), - fset=lambda self, value: self._set_object_list("_regionsofinterest", value), - doc="list of RegionOfInterest objects contained in this block" - ) - @property def data_children_recur(self): ''' diff --git a/neo/core/group.py b/neo/core/group.py index f4a34273e..06d3304d8 100644 --- a/neo/core/group.py +++ b/neo/core/group.py @@ -18,6 +18,7 @@ from neo.core.segment import Segment from neo.core.spiketrainlist import SpikeTrainList from neo.core.view import ChannelView +from neo.core.regionofinterest import RegionOfInterest class Group(Container): @@ -49,7 +50,8 @@ class Group(Container): """ _data_child_objects = ( 'AnalogSignal', 'IrregularlySampledSignal', 'SpikeTrain', - 'Event', 'Epoch', 'ChannelView', 'ImageSequence' + 'Event', 'Epoch', 'ChannelView', 'ImageSequence', 
'CircularRegionOfInterest', + 'RectangularRegionOfInterest', 'PolygonRegionOfInterest' ) _container_child_objects = ('Group',) _parent_objects = ('Block',) @@ -69,6 +71,7 @@ def __init__(self, objects=None, name=None, description=None, file_origin=None, self._epochs = ObjectList(Epoch) self._channelviews = ObjectList(ChannelView) self._imagesequences = ObjectList(ImageSequence) + self._regionsofinterest = ObjectList(RegionOfInterest) self._segments = ObjectList(Segment) # to remove? self._groups = ObjectList(Group) @@ -119,6 +122,12 @@ def __init__(self, objects=None, name=None, description=None, file_origin=None, doc="list of ImageSequences contained in this group" ) + regionsofinterest = property( + fget=lambda self: self._get_object_list("_regionsofinterest"), + fset=lambda self, value: self._set_object_list("_regionsofinterest", value), + doc="list of RegionOfInterest objects contained in this group" + ) + spiketrains = property( fget=lambda self: self._get_object_list("_spiketrains"), fset=lambda self, value: self._set_object_list("_spiketrains", value), diff --git a/neo/core/imagesequence.py b/neo/core/imagesequence.py index 2a84e8d02..3a877ba9f 100644 --- a/neo/core/imagesequence.py +++ b/neo/core/imagesequence.py @@ -97,7 +97,7 @@ class ImageSequence(BaseSignal): ) _recommended_attrs = BaseNeo._recommended_attrs - def __new__(cls, image_data, units=None, dtype=None, copy=True, t_start=0 * pq.s, + def __new__(cls, image_data, units=pq.dimensionless, dtype=None, copy=True, t_start=0 * pq.s, spatial_scale=None, frame_duration=None, sampling_rate=None, name=None, description=None, file_origin=None, **annotations): @@ -127,7 +127,7 @@ def __new__(cls, image_data, units=None, dtype=None, copy=True, t_start=0 * pq.s return obj - def __init__(self, image_data, units=None, dtype=None, copy=True, t_start=0 * pq.s, + def __init__(self, image_data, units=pq.dimensionless, dtype=None, copy=True, t_start=0 * pq.s, spatial_scale=None, frame_duration=None, 
sampling_rate=None, name=None, description=None, file_origin=None, **annotations): @@ -142,7 +142,7 @@ def __array_finalize__spec(self, obj): self.sampling_rate = getattr(obj, "sampling_rate", None) self.spatial_scale = getattr(obj, "spatial_scale", None) - self.units = getattr(obj, "units", None) + self.units = getattr(obj, "units", pq.dimensionless) self._t_start = getattr(obj, "_t_start", 0 * pq.s) return obj diff --git a/neo/core/regionofinterest.py b/neo/core/regionofinterest.py index cdf463653..458fb7067 100644 --- a/neo/core/regionofinterest.py +++ b/neo/core/regionofinterest.py @@ -1,11 +1,32 @@ from math import floor, ceil from neo.core.baseneo import BaseNeo +from neo.core.imagesequence import ImageSequence class RegionOfInterest(BaseNeo): """Abstract base class""" - pass + + _parent_objects = ('Group',) + _parent_attrs = ('group',) + _necessary_attrs = ( + ('obj', ('ImageSequence', ), 1), + ) + + def __init__(self, image_sequence, name=None, description=None, file_origin=None, **annotations): + super().__init__(name=name, description=description, + file_origin=file_origin, **annotations) + + if not (isinstance(image_sequence, ImageSequence) or ( + hasattr(image_sequence, "proxy_for") and issubclass(image_sequence.proxy_for, ImageSequence))): + raise ValueError("Can only take a RegionOfInterest of an ImageSequence") + self.image_sequence = image_sequence + + def resolve(self): + """ + Return a signal from within this region of the underlying ImageSequence. 
+ """ + return self.image_sequence.signal_from_region(self) class CircularRegionOfInterest(RegionOfInterest): @@ -23,8 +44,9 @@ class CircularRegionOfInterest(RegionOfInterest): Radius of the ROI in pixels """ - def __init__(self, x, y, radius): - + def __init__(self, image_sequence, x, y, radius, name=None, description=None, + file_origin=None, **annotations): + super().__init__(image_sequence, name, description, file_origin, **annotations) self.y = y self.x = x self.radius = radius @@ -72,7 +94,9 @@ class RectangularRegionOfInterest(RegionOfInterest): Height (y-direction) of the ROI in pixels """ - def __init__(self, x, y, width, height): + def __init__(self, image_sequence, x, y, width, height, name=None, description=None, + file_origin=None, **annotations): + super().__init__(image_sequence, name, description, file_origin, **annotations) self.x = x self.y = y self.width = width @@ -115,7 +139,9 @@ class PolygonRegionOfInterest(RegionOfInterest): of the vertices of the polygon """ - def __init__(self, *vertices): + def __init__(self, image_sequence, *vertices, name=None, description=None, + file_origin=None, **annotations): + super().__init__(image_sequence, name, description, file_origin, **annotations) self.vertices = vertices def polygon_ray_casting(self, bounding_points, bounding_box_positions): diff --git a/neo/io/baseio.py b/neo/io/baseio.py index acce7de6f..2ce5b07bf 100644 --- a/neo/io/baseio.py +++ b/neo/io/baseio.py @@ -10,6 +10,8 @@ If you want a model for developing a new IO start from exampleIO. """ +from __future__ import annotations +from pathlib import Path try: from collections.abc import Sequence @@ -96,7 +98,7 @@ class BaseIO: mode = 'file' # or 'fake' or 'dir' or 'database' - def __init__(self, filename=None, **kargs): + def __init__(self, filename: str | Path = None, **kargs): self.filename = str(filename) # create a logger for the IO class fullname = self.__class__.__module__ + '.' 
+ self.__class__.__name__ @@ -111,7 +113,7 @@ def __init__(self, filename=None, **kargs): corelogger.addHandler(logging_handler) ######## General read/write methods ####################### - def read(self, lazy=False, **kargs): + def read(self, lazy: bool = False, **kargs): """ Return all data from the file as a list of Blocks """ diff --git a/neo/rawio/baserawio.py b/neo/rawio/baserawio.py index c9646bbd8..08fb073b6 100644 --- a/neo/rawio/baserawio.py +++ b/neo/rawio/baserawio.py @@ -67,6 +67,7 @@ constructions of a RawIO for a given set of data. """ +from __future__ import annotations import logging import numpy as np @@ -133,7 +134,7 @@ class BaseRawIO: rawmode = None # one key from possible_raw_modes - def __init__(self, use_cache=False, cache_path='same_as_resource', **kargs): + def __init__(self, use_cache: bool = False, cache_path: str = 'same_as_resource', **kargs): """ :TODO: Why multi-file would have a single filename is confusing here - shouldn't the name of this argument be filenames_list or filenames_base or similar? @@ -369,7 +370,7 @@ def block_count(self): """return number of blocks""" return self.header['nb_block'] - def segment_count(self, block_index): + def segment_count(self, block_index: int): """return number of segments for a given block""" return self.header['nb_segment'][block_index] @@ -379,7 +380,7 @@ def signal_streams_count(self): """ return len(self.header['signal_streams']) - def signal_channels_count(self, stream_index): + def signal_channels_count(self, stream_index: int): """Return the number of signal channels for a given stream. This number is the same for all Blocks and Segments. """ @@ -400,7 +401,7 @@ def event_channels_count(self): """ return len(self.header['event_channels']) - def segment_t_start(self, block_index, seg_index): + def segment_t_start(self, block_index: int, seg_index: int): """Global t_start of a Segment in s. Shared by all objects except for AnalogSignal. 
""" @@ -445,7 +446,7 @@ def _check_stream_signal_channel_characteristics(self): self._several_channel_groups = signal_streams.size > 1 - def channel_name_to_index(self, stream_index, channel_names): + def channel_name_to_index(self, stream_index: int, channel_names: list[str]): """ Inside a stream, transform channel_names to channel_indexes. Based on self.header['signal_channels'] @@ -459,7 +460,7 @@ def channel_name_to_index(self, stream_index, channel_names): channel_indexes = np.array([chan_names.index(name) for name in channel_names]) return channel_indexes - def channel_id_to_index(self, stream_index, channel_ids): + def channel_id_to_index(self, stream_index: int, channel_ids: list[str]): """ Inside a stream, transform channel_ids to channel_indexes. Based on self.header['signal_channels'] @@ -473,7 +474,11 @@ def channel_id_to_index(self, stream_index, channel_ids): channel_indexes = np.array([chan_ids.index(chan_id) for chan_id in channel_ids]) return channel_indexes - def _get_channel_indexes(self, stream_index, channel_indexes, channel_names, channel_ids): + def _get_channel_indexes(self, + stream_index: int, + channel_indexes: list[int] | None, + channel_names: list[str] | None, + channel_ids: list[str] | None): """ Select channel_indexes for a stream based on channel_indexes/channel_names/channel_ids depending which is not None. 
@@ -484,7 +489,7 @@ def _get_channel_indexes(self, stream_index, channel_indexes, channel_names, cha channel_indexes = self.channel_id_to_index(stream_index, channel_ids) return channel_indexes - def _get_stream_index_from_arg(self, stream_index_arg): + def _get_stream_index_from_arg(self, stream_index_arg: int | None): if stream_index_arg is None: assert self.header['signal_streams'].size == 1 stream_index = 0 @@ -493,7 +498,7 @@ def _get_stream_index_from_arg(self, stream_index_arg): stream_index = stream_index_arg return stream_index - def get_signal_size(self, block_index, seg_index, stream_index=None): + def get_signal_size(self, block_index: int, seg_index: int, stream_index: int | None = None): """ Retrieve the length of a single section of the channels in a stream. :param block_index: @@ -504,7 +509,10 @@ def get_signal_size(self, block_index, seg_index, stream_index=None): stream_index = self._get_stream_index_from_arg(stream_index) return self._get_signal_size(block_index, seg_index, stream_index) - def get_signal_t_start(self, block_index, seg_index, stream_index=None): + def get_signal_t_start(self, + block_index: int, + seg_index: int, + stream_index: int | None = None): """ Retrieve the t_start of a single section of the channels in a stream. :param block_index: @@ -515,7 +523,7 @@ def get_signal_t_start(self, block_index, seg_index, stream_index=None): stream_index = self._get_stream_index_from_arg(stream_index) return self._get_signal_t_start(block_index, seg_index, stream_index) - def get_signal_sampling_rate(self, stream_index=None): + def get_signal_sampling_rate(self, stream_index: int | None = None): """ Retrieve sampling rate for a stream and all channels in that stream. 
:param stream_index: @@ -528,9 +536,16 @@ def get_signal_sampling_rate(self, stream_index=None): sr = signal_channels[0]['sampling_rate'] return float(sr) - def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_stop=None, - stream_index=None, channel_indexes=None, channel_names=None, - channel_ids=None, prefer_slice=False): + def get_analogsignal_chunk(self, + block_index: int = 0, + seg_index: int = 0, + i_start: int | None = None, + i_stop: int | None = None, + stream_index: int | None = None, + channel_indexes: list[int] | None = None, + channel_names: list[str] | None = None, + channel_ids: list[str] | None = None, + prefer_slice: bool = False): """ Return a chunk of raw signal as a Numpy array. columns correspond to samples from a section of a single channel of recording. The channels are chosen either by channel_names, @@ -587,8 +602,13 @@ def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_sto return raw_chunk - def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', stream_index=None, - channel_indexes=None, channel_names=None, channel_ids=None): + def rescale_signal_raw_to_float(self, + raw_signal: np.ndarray, + dtype: np.dtype = 'float32', + stream_index: int | None = None, + channel_indexes: list[int] | None = None, + channel_names: list[str] | None = None, + channel_ids: list[str] | None = None): """ Rescale a chunk of raw signals which are provided as a Numpy array. These are normally returned by a call to get_analogsignal_chunk. 
The channels are specified either by @@ -627,11 +647,15 @@ def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', stream_index= return float_signal # spiketrain and unit zone - def spike_count(self, block_index=0, seg_index=0, spike_channel_index=0): + def spike_count(self, block_index: int = 0, seg_index: int = 0, spike_channel_index: int = 0): return self._spike_count(block_index, seg_index, spike_channel_index) - def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0, - t_start=None, t_stop=None): + def get_spike_timestamps(self, + block_index: int = 0, + seg_index: int = 0, + spike_channel_index: int = 0, + t_start: float | None = None, + t_stop: float | None = None): """ The timestamp datatype is as close to the format itself. Sometimes float/int32/int64. Sometimes it is the index on the signal but not always. @@ -643,21 +667,25 @@ def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0 spike_channel_index, t_start, t_stop) return timestamp - def rescale_spike_timestamp(self, spike_timestamps, dtype='float64'): + def rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype = 'float64'): """ Rescale spike timestamps to seconds. 
""" return self._rescale_spike_timestamp(spike_timestamps, dtype) # spiketrain waveform zone - def get_spike_raw_waveforms(self, block_index=0, seg_index=0, spike_channel_index=0, - t_start=None, t_stop=None): + def get_spike_raw_waveforms(self, + block_index: int = 0, + seg_index: int = 0, + spike_channel_index: int = 0, + t_start: float | None = None, + t_stop: float | None = None): wf = self._get_spike_raw_waveforms(block_index, seg_index, spike_channel_index, t_start, t_stop) return wf - def rescale_waveforms_to_float(self, raw_waveforms, dtype='float32', - spike_channel_index=0): + def rescale_waveforms_to_float(self, raw_waveforms: np.ndarray, dtype: np.dtype = 'float32', + spike_channel_index: int = 0): wf_gain = self.header['spike_channels']['wf_gain'][spike_channel_index] wf_offset = self.header['spike_channels']['wf_offset'][spike_channel_index] @@ -671,11 +699,15 @@ def rescale_waveforms_to_float(self, raw_waveforms, dtype='float32', return float_waveforms # event and epoch zone - def event_count(self, block_index=0, seg_index=0, event_channel_index=0): + def event_count(self, block_index: int = 0, seg_index: int = 0, event_channel_index: int = 0): return self._event_count(block_index, seg_index, event_channel_index) - def get_event_timestamps(self, block_index=0, seg_index=0, event_channel_index=0, - t_start=None, t_stop=None): + def get_event_timestamps(self, + block_index: int = 0, + seg_index: int = 0, + event_channel_index: int = 0, + t_start: float | None = None, + t_stop: float | None = None): """ The timestamp datatype is as close to the format itself. Sometimes float/int32/int64. Sometimes it is the index on the signal but not always. 
@@ -693,21 +725,23 @@ def get_event_timestamps(self, block_index=0, seg_index=0, event_channel_index=0 block_index, seg_index, event_channel_index, t_start, t_stop) return timestamp, durations, labels - def rescale_event_timestamp(self, event_timestamps, dtype='float64', - event_channel_index=0): + def rescale_event_timestamp(self, + event_timestamps: np.ndarray, + dtype: np.dtype = 'float64', + event_channel_index: int = 0): """ Rescale event timestamps to seconds. """ return self._rescale_event_timestamp(event_timestamps, dtype, event_channel_index) - def rescale_epoch_duration(self, raw_duration, dtype='float64', - event_channel_index=0): + def rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype = 'float64', + event_channel_index: int = 0): """ Rescale epoch raw duration to seconds. """ return self._rescale_epoch_duration(raw_duration, dtype, event_channel_index) - def setup_cache(self, cache_path, **init_kargs): + def setup_cache(self, cache_path: str, **init_kargs): try: import joblib except ImportError: @@ -735,7 +769,7 @@ def setup_cache(self, cache_path, **init_kargs): dirname = os.path.dirname(resource_name) else: assert os.path.exists(cache_path), \ - 'cache_path do not exists use "home" or "same_as_resource" to make this auto' + 'cache_path does not exist, use "home" or "same_as_resource" to make this auto' # the hash of the resource (dir of file) is done with filename+datetime # TODO make something more sophisticated when rawmode='one-dir' that use all @@ -776,15 +810,15 @@ def _parse_header(self): def _source_name(self): raise (NotImplementedError) - def _segment_t_start(self, block_index, seg_index): + def _segment_t_start(self, block_index: int, seg_index: int): raise (NotImplementedError) - def _segment_t_stop(self, block_index, seg_index): + def _segment_t_stop(self, block_index: int, seg_index: int): raise (NotImplementedError) ### # signal and channel zone - def _get_signal_size(self, block_index, 
seg_index, stream_index): + def _get_signal_size(self, block_index: int, seg_index: int, stream_index: int): """ Return the size of a set of AnalogSignals indexed by channel_indexes. @@ -792,7 +826,7 @@ def _get_signal_size(self, block_index, seg_index, stream_index): """ raise (NotImplementedError) - def _get_signal_t_start(self, block_index, seg_index, stream_index): + def _get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int): """ Return the t_start of a set of AnalogSignals indexed by channel_indexes. @@ -800,8 +834,13 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index): """ raise (NotImplementedError) - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, - stream_index, channel_indexes): + def _get_analogsignal_chunk(self, + block_index: int, + seg_index: int, + i_start: int | None, + i_stop: int | None, + stream_index: int, + channel_indexes: list[int] | None): """ Return the samples from a set of AnalogSignals indexed by stream_index and channel_indexes (local index inner stream). 
@@ -815,38 +854,51 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, ### # spiketrain and unit zone - def _spike_count(self, block_index, seg_index, spike_channel_index): + def _spike_count(self, block_index: int, seg_index: int, spike_channel_index: int): raise (NotImplementedError) - def _get_spike_timestamps(self, block_index, seg_index, - spike_channel_index, t_start, t_stop): + def _get_spike_timestamps(self, + block_index: int, + seg_index: int, + spike_channel_index: int, + t_start: float | None, + t_stop: float | None): raise (NotImplementedError) - def _rescale_spike_timestamp(self, spike_timestamps, dtype): + def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype): raise (NotImplementedError) ### # spike waveforms zone - def _get_spike_raw_waveforms(self, block_index, seg_index, - spike_channel_index, t_start, t_stop): + def _get_spike_raw_waveforms(self, + block_index: int, + seg_index: int, + spike_channel_index: int, + t_start: float | None, + t_stop: float | None): raise (NotImplementedError) ### # event and epoch zone - def _event_count(self, block_index, seg_index, event_channel_index): + def _event_count(self, block_index: int, seg_index: int, event_channel_index: int): raise (NotImplementedError) - def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): + def _get_event_timestamps(self, + block_index: int, + seg_index: int, + event_channel_index: int, + t_start: float | None, + t_stop: float | None): raise (NotImplementedError) - def _rescale_event_timestamp(self, event_timestamps, dtype): + def _rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype): raise (NotImplementedError) - def _rescale_epoch_duration(self, raw_duration, dtype): + def _rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype): raise (NotImplementedError) -def pprint_vector(vector, lim=8): +def pprint_vector(vector, lim: int = 8): vector = 
np.asarray(vector) assert vector.ndim == 1 if len(vector) > lim: diff --git a/neo/rawio/examplerawio.py b/neo/rawio/examplerawio.py index 26b4572b7..27710b6fb 100644 --- a/neo/rawio/examplerawio.py +++ b/neo/rawio/examplerawio.py @@ -34,11 +34,14 @@ * copy/paste from neo/test/iotest/test_exampleio.py """ +from __future__ import annotations + +import numpy as np +from pathlib import Path from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype, _spike_channel_dtype, _event_channel_dtype) -import numpy as np class ExampleRawIO(BaseRawIO): @@ -82,7 +85,7 @@ class ExampleRawIO(BaseRawIO): extensions = ['fake'] rawmode = 'one-file' - def __init__(self, filename=''): + def __init__(self, filename: str | Path = ''): BaseRawIO.__init__(self) # note that this filename is ued in self._source_name self.filename = filename @@ -228,19 +231,19 @@ def _parse_header(self): elif c == 1: event_an['nickname'] = 'MrEpoch 1' - def _segment_t_start(self, block_index, seg_index): + def _segment_t_start(self, block_index: int, seg_index: int): # this must return a float scaled in seconds # this t_start will be shared by all objects in the segment # except AnalogSignal all_starts = [[0., 15.], [0., 20., 60.]] return all_starts[block_index][seg_index] - def _segment_t_stop(self, block_index, seg_index): + def _segment_t_stop(self, block_index: int, seg_index: int): # this must return a float scaled in seconds all_stops = [[10., 25.], [10., 30., 70.]] return all_stops[block_index][seg_index] - def _get_signal_size(self, block_index, seg_index, stream_index): + def _get_signal_size(self, block_index: int, seg_index: int, stream_index: int): # We generate fake data in which the two stream signals have the same shape # across all segments (10.0 seconds) # This is not the case for real data, instead you should return the signal @@ -251,7 +254,7 @@ def _get_signal_size(self, block_index, seg_index, stream_index): # except for the case of several sampling rates. 
return 100000 - def _get_signal_t_start(self, block_index, seg_index, stream_index): + def _get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int): # This give the t_start of a signal. # Very often this is equal to _segment_t_start but not # always. @@ -264,8 +267,13 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index): # this is not always the case return self._segment_t_start(block_index, seg_index) - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, - stream_index, channel_indexes): + def _get_analogsignal_chunk(self, + block_index: int, + seg_index: int, + i_start: int | None, + i_stop: int | None, + stream_index: int, + channel_indexes: np.ndarray | list | slice | None): # this must return a signal chunk in a signal stream # limited with i_start/i_stop (can be None) # channel_indexes can be None (=all channel in the stream) or a list or numpy.array @@ -305,14 +313,19 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, raw_signals = np.zeros((i_stop - i_start, nb_chan), dtype='int16') return raw_signals - def _spike_count(self, block_index, seg_index, spike_channel_index): + def _spike_count(self, block_index: int, seg_index: int, spike_channel_index: int): # Must return the nb of spikes for given (block_index, seg_index, spike_channel_index) # we are lucky: our units have all the same nb of spikes!! # it is not always the case nb_spikes = 20 return nb_spikes - def _get_spike_timestamps(self, block_index, seg_index, spike_channel_index, t_start, t_stop): + def _get_spike_timestamps(self, + block_index: int, + seg_index: int, + spike_channel_index: int, + t_start: float | None, + t_stop: float | None): # In our IO, timestamp are internally coded 'int64' and they # represent the index of the signals 10kHz # we are lucky: spikes have the same discharge in all segments!! 
@@ -333,15 +346,19 @@ def _get_spike_timestamps(self, block_index, seg_index, spike_channel_index, t_s return spike_timestamps - def _rescale_spike_timestamp(self, spike_timestamps, dtype): + def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype): # must rescale to seconds, a particular spike_timestamps # with a fixed dtype so the user can choose the precision they want. spike_times = spike_timestamps.astype(dtype) spike_times /= 10000. # because 10kHz return spike_times - def _get_spike_raw_waveforms(self, block_index, seg_index, spike_channel_index, - t_start, t_stop): + def _get_spike_raw_waveforms(self, + block_index: int, + seg_index: int, + spike_channel_index: int, + t_start: float | None, + t_stop: float | None): # this must return a 3D numpy array (nb_spike, nb_channel, nb_sample) # in the original dtype # this must be as fast as possible. @@ -367,7 +384,7 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, spike_channel_index, waveforms = waveforms.reshape(nb_spike, 1, 50) return waveforms - def _event_count(self, block_index, seg_index, event_channel_index): + def _event_count(self, block_index: int, seg_index: int, event_channel_index: int): # event and spike are very similar # we have 2 event channels if event_channel_index == 0: @@ -377,7 +394,12 @@ def _event_count(self, block_index, seg_index, event_channel_index): # epoch channel return 10 - def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): + def _get_event_timestamps(self, + block_index: int, + seg_index: int, + event_channel_index: int, + t_start: float | None, + t_stop: float | None): # the main difference between spike channel and event channel # is that for event channels we have 3D numpy array (timestamp, durations, labels) where # durations must be None for 'event' @@ -408,7 +430,10 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s return timestamp, durations, labels - def 
_rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index): + def _rescale_event_timestamp(self, + event_timestamps: np.ndarray, + dtype: np.dtype, + event_channel_index: int): # must rescale to seconds for a particular event_timestamps # with a fixed dtype so the user can choose the precision they want. @@ -416,7 +441,10 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index) event_times = event_timestamps.astype(dtype) return event_times - def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index): + def _rescale_epoch_duration(self, + raw_duration: np.ndarray, + dtype: np.dtype, + event_channel_index: int): # really easy here because in our case it is already in seconds durations = raw_duration.astype(dtype) return durations diff --git a/neo/rawio/neuralynxrawio/ncssections.py b/neo/rawio/neuralynxrawio/ncssections.py index 6f8eacad8..b372478c4 100644 --- a/neo/rawio/neuralynxrawio/ncssections.py +++ b/neo/rawio/neuralynxrawio/ncssections.py @@ -438,7 +438,7 @@ def build_for_ncs_file(ncsMemMap, nlxHdr): # digital lynx style with fractional frequency and micros per samp determined from # block times - elif acqType == "DIGITALLYNX" or acqType == "DIGITALLYNXSX" or acqType == 'CHEETAH64' or acqType == 'RAWDATAFILE': + elif acqType in ["DIGITALLYNX", "DIGITALLYNXSX", 'CHEETAH64', 'CHEETAH560', 'RAWDATAFILE']: nomFreq = nlxHdr['sampling_rate'] nb = NcsSectionsFactory._buildForMaxGap(ncsMemMap, nomFreq) diff --git a/neo/rawio/neuralynxrawio/neuralynxrawio.py b/neo/rawio/neuralynxrawio/neuralynxrawio.py index 8f46bb9d0..1b5b91a2f 100644 --- a/neo/rawio/neuralynxrawio/neuralynxrawio.py +++ b/neo/rawio/neuralynxrawio/neuralynxrawio.py @@ -850,20 +850,21 @@ def get_nse_or_ntt_dtype(info, ext): """ dtype = [('timestamp', 'uint64'), ('channel_id', 'uint32'), ('unit_id', 'uint32')] - # count feature - nb_feature = 0 - for k in info.keys(): - if k.startswith('Feature '): - nb_feature += 1 + # for purpose of 
dtypes, features in the file are always fixed 8 presently, + # whether mentioned in the header or not. Features may not be listed in the header + # if no feature names are assigned in Neuralynx software. + nb_feature = 8 dtype += [('features', 'int32', (nb_feature,))] - # count sample + # Number of samples are fixed in the file at 32 for .nse 32 * 4 for .ntt. + # WaveformLength may or may not be listed in the file depending on settings + # in the Neuralynx software, so don't try retrieving it. if ext == 'nse': - nb_sample = info['WaveformLength'] + nb_sample = 32 dtype += [('samples', 'int16', (nb_sample,))] elif ext == 'ntt': - nb_sample = info['WaveformLength'] - nb_chan = 4 # check this if not tetrode + nb_sample = 32 + nb_chan = 4 dtype += [('samples', 'int16', (nb_sample, nb_chan))] return dtype diff --git a/neo/rawio/neuralynxrawio/nlxheader.py b/neo/rawio/neuralynxrawio/nlxheader.py index 0f3980390..9bb5fba79 100644 --- a/neo/rawio/neuralynxrawio/nlxheader.py +++ b/neo/rawio/neuralynxrawio/nlxheader.py @@ -97,6 +97,12 @@ def _to_bool(txt): r' At Time: (?P