From b76672ad2c057999c51f9144e03a88aafc3713cd Mon Sep 17 00:00:00 2001 From: Andrew Davison Date: Mon, 29 Jan 2024 21:18:55 +0100 Subject: [PATCH] now handles all Neo objects, including ChannelViews and RegionsOfInterest --- neo/core/baseneo.py | 4 +- neo/core/container.py | 11 +- neo/core/regionofinterest.py | 33 +++- neo/core/view.py | 10 +- neo/io/neomatlabio.py | 290 +++++++++++++++++++--------- neo/test/coretest/test_block.py | 5 + neo/test/generate_datasets.py | 260 ++++++++++++++----------- neo/test/iotest/test_neomatlabio.py | 31 ++- 8 files changed, 426 insertions(+), 218 deletions(-) diff --git a/neo/core/baseneo.py b/neo/core/baseneo.py index 060a8b6c5..3d2e4a4d9 100644 --- a/neo/core/baseneo.py +++ b/neo/core/baseneo.py @@ -180,7 +180,7 @@ class BaseNeo: class must have. The tuple can have 2-4 elements. The first element is the attribute name. The second element is the attribute type. - The third element is the number of dimensions + The third element is the number of dimensions (only for numpy arrays and quantities). The fourth element is the dtype of array (only for numpy arrays and quantities). @@ -253,6 +253,8 @@ class attributes. :_recommended_attrs: should append # Attributes that are used for pretty-printing _repr_pretty_attrs_keys_ = ("name", "description", "annotations") + is_view = False + def __init__(self, name=None, description=None, file_origin=None, **annotations): """ diff --git a/neo/core/container.py b/neo/core/container.py index 8c5ecd4fc..07836c747 100644 --- a/neo/core/container.py +++ b/neo/core/container.py @@ -251,16 +251,19 @@ def _data_child_containers(self): """ Containers for child objects that have data and have a single parent. """ - return tuple([_container_name(child) for child in - self._data_child_objects]) + # the following construction removes the duplicate 'regionsofinterest' + # while preserving the child order (which `set()` would not do) + # I don't know if preserving the author is important, but I'm playing it safe + return tuple({_container_name(child): None for child in + self._data_child_objects}.keys()) @property def _child_containers(self): """ Containers for child objects with a single parent. """ - return tuple([_container_name(child) for child in - self._child_objects]) + return tuple({_container_name(child): None for child in + self._child_objects}.keys()) @property def _single_children(self): diff --git a/neo/core/regionofinterest.py b/neo/core/regionofinterest.py index 458fb7067..3de9b624f 100644 --- a/neo/core/regionofinterest.py +++ b/neo/core/regionofinterest.py @@ -10,8 +10,9 @@ class RegionOfInterest(BaseNeo): _parent_objects = ('Group',) _parent_attrs = ('group',) _necessary_attrs = ( - ('obj', ('ImageSequence', ), 1), + ('image_sequence', ('ImageSequence', ), 1), ) + is_view = True def __init__(self, image_sequence, name=None, description=None, file_origin=None, **annotations): super().__init__(name=name, description=description, @@ -22,6 +23,16 @@ def __init__(self, image_sequence, name=None, description=None, file_origin=None raise ValueError("Can only take a RegionOfInterest of an ImageSequence") self.image_sequence = image_sequence + def _get_obj(self): + # for consistency with ChannelView + return self.image_sequence + + def _set_obj(self, value): + assert isinstance(value, ImageSequence) + self.image_sequence = value + + obj = property(fget=_get_obj, fset=_set_obj) + def resolve(self): """ Return a signal from within this region of the underlying ImageSequence. @@ -44,6 +55,13 @@ class CircularRegionOfInterest(RegionOfInterest): Radius of the ROI in pixels """ + _necessary_attrs = ( + ('image_sequence', ('ImageSequence', ), 1), + ('x', int), + ('y', int), + ('radius', int) + ) + def __init__(self, image_sequence, x, y, radius, name=None, description=None, file_origin=None, **annotations): super().__init__(image_sequence, name, description, file_origin, **annotations) @@ -94,6 +112,14 @@ class RectangularRegionOfInterest(RegionOfInterest): Height (y-direction) of the ROI in pixels """ + _necessary_attrs = ( + ('image_sequence', ('ImageSequence', ), 1), + ('x', int), + ('y', int), + ('width', int), + ('height', int) + ) + def __init__(self, image_sequence, x, y, width, height, name=None, description=None, file_origin=None, **annotations): super().__init__(image_sequence, name, description, file_origin, **annotations) @@ -139,6 +165,11 @@ class PolygonRegionOfInterest(RegionOfInterest): of the vertices of the polygon """ + _necessary_attrs = ( + ('image_sequence', ('ImageSequence', ), 1), + ('vertices', list), + ) + def __init__(self, image_sequence, *vertices, name=None, description=None, file_origin=None, **annotations): super().__init__(image_sequence, name, description, file_origin, **annotations) diff --git a/neo/core/view.py b/neo/core/view.py index 35e733a07..55c7d2a28 100644 --- a/neo/core/view.py +++ b/neo/core/view.py @@ -30,12 +30,14 @@ class ChannelView(BaseNeo): Note: Any other additional arguments are assumed to be user-specific metadata and stored in :attr:`annotations`. """ - _parent_objects = ('Segment',) - _parent_attrs = ('segment',) + _parent_objects = ('Group',) + _parent_attrs = ('group',) _necessary_attrs = ( + ('obj', ('AnalogSignal', 'IrregularlySampledSignal'), 1), ('index', np.ndarray, 1, np.dtype('i')), - ('obj', ('AnalogSignal', 'IrregularlySampledSignal'), 1) ) + is_view = True + # "mask" would be an alternative name, proposing "index" for # backwards-compatibility with ChannelIndex @@ -73,7 +75,7 @@ def shape(self): return (self.obj.shape[0], self.index.size) def _get_arr_ann_length(self): - return self.shape[-1] + return self.index.size def array_annotate(self, **array_annotations): self.array_annotations.update(array_annotations) diff --git a/neo/io/neomatlabio.py b/neo/io/neomatlabio.py index 756869efc..2d0fe46b8 100644 --- a/neo/io/neomatlabio.py +++ b/neo/io/neomatlabio.py @@ -12,6 +12,7 @@ Author: sgarcia, Robert Pröpper """ +from collections.abc import Mapping from datetime import datetime import re @@ -22,15 +23,41 @@ from neo.io.baseio import BaseIO -from neo.core import (Block, Segment, AnalogSignal, - IrregularlySampledSignal, Event, - Epoch, SpikeTrain, - Group, ImageSequence, - objectnames, class_by_name) +from neo.core import ( + Block, + Segment, + AnalogSignal, + IrregularlySampledSignal, + Event, + Epoch, + SpikeTrain, + Group, + ImageSequence, + ChannelView, + RectangularRegionOfInterest, + CircularRegionOfInterest, + PolygonRegionOfInterest, + objectnames, + class_by_name, +) +from neo.core.regionofinterest import RegionOfInterest +from neo.core.baseneo import _container_name + + +def get_classname_from_container_name(container_name, struct): + if container_name == "regionsofinterest": + if "radius" in struct._fieldnames: + return "CircularRegionOfInterest" + elif "vertices" in struct._fieldnames: + return "PolygonRegionOfInterest" + else: + return "RectangularRegionOfInterest" + else: + for classname in objectnames: + if _container_name(classname) == container_name: + return classname -classname_lower_to_upper = {} -for k in objectnames: - classname_lower_to_upper[k.lower()] = k +PY_NONE = "Py_None" class NeoMatlabIO(BaseIO): @@ -175,11 +202,25 @@ class NeoMatlabIO(BaseIO): w.write(blocks[0]) """ + is_readable = True is_writable = True - supported_objects = [Block, Segment, AnalogSignal, IrregularlySampledSignal, - Epoch, Event, SpikeTrain, Group, ImageSequence] + supported_objects = [ + Block, + Segment, + AnalogSignal, + IrregularlySampledSignal, + Epoch, + Event, + SpikeTrain, + Group, + ImageSequence, + ChannelView, + RectangularRegionOfInterest, + CircularRegionOfInterest, + PolygonRegionOfInterest, + ] readable_objects = [Block] writeable_objects = [Block] @@ -188,10 +229,10 @@ class NeoMatlabIO(BaseIO): read_params = {Block: []} write_params = {Block: []} - name = 'neomatlab' - extensions = ['mat'] + name = "neomatlab" + extensions = ["mat"] - mode = 'file' + mode = "file" def __init__(self, filename=None): """ @@ -202,13 +243,16 @@ def __init__(self, filename=None): """ import scipy - if Version(scipy.version.version) < Version('0.12.0'): - raise ImportError("your scipy version is too old to support " - + "MatlabIO, you need at least 0.12.0. " - + "You have %s" % scipy.version.version) + if Version(scipy.version.version) < Version("0.12.0"): + raise ImportError( + "your scipy version is too old to support " + + "MatlabIO, you need at least 0.12.0. " + + "You have %s" % scipy.version.version + ) BaseIO.__init__(self) self.filename = filename + self._refs = {} def read_block(self, lazy=False): """ @@ -216,17 +260,16 @@ def read_block(self, lazy=False): """ import scipy.io - assert not lazy, 'Does not support lazy' - d = scipy.io.loadmat(self.filename, struct_as_record=False, - squeeze_me=True, mat_dtype=True) - if 'block' not in d: - self.logger.exception('No block in ' + self.filename) + assert not lazy, "Does not support lazy" + + d = scipy.io.loadmat(self.filename, struct_as_record=False, squeeze_me=True, mat_dtype=True) + if "block" not in d: + self.logger.exception("No block in " + self.filename) return None - bl_struct = d['block'] - bl = self.create_ob_from_struct( - bl_struct, 'Block') + bl_struct = d["block"] + bl = self.create_ob_from_struct(bl_struct, "Block") self._resolve_references(bl) bl.check_relationships() return bl @@ -234,14 +277,15 @@ def read_block(self, lazy=False): def write_block(self, bl, **kargs): """ Arguments: - bl: the block to b saved + bl: the block to be saved """ import scipy.io + bl_struct = self.create_struct_from_obj(bl) for seg in bl.segments: seg_struct = self.create_struct_from_obj(seg) - bl_struct['segments'].append(seg_struct) + bl_struct["segments"].append(seg_struct) for container_name in seg._child_containers: for child_obj in getattr(seg, container_name): @@ -250,65 +294,95 @@ def write_block(self, bl, **kargs): for group in bl.groups: group_structure = self.create_struct_from_obj(group) - bl_struct['groups'].append(group_structure) + bl_struct["groups"].append(group_structure) for container_name in group._child_containers: for child_obj in getattr(group, container_name): - group_structure[container_name].append(id(child_obj)) + if isinstance(child_obj, (ChannelView, RegionOfInterest)): + child_struct = self.create_struct_from_view(child_obj) + group_structure[container_name].append(child_struct) + else: + group_structure[container_name].append(id(child_obj)) - scipy.io.savemat(self.filename, {'block': bl_struct}, oned_as='row') + if kargs.get("debug", False): + breakpoint() + scipy.io.savemat(self.filename, {"block": bl_struct}, oned_as="row") + + def _get_matlab_value(self, ob, attrname): + units = None + if hasattr(ob, "_quantity_attr") and ob._quantity_attr == attrname: + units = ob.dimensionality.string + value = ob.magnitude + else: + try: + value = getattr(ob, attrname) + except AttributeError: + value = ob[attrname] + if isinstance(value, pq.Quantity): + units = value.dimensionality.string + value = value.magnitude + elif isinstance(value, datetime): + value = str(value) + elif isinstance(value, Mapping): + new_value = {} + for key in value: + subvalue, subunits = self._get_matlab_value(value, key) + if subvalue is not None: + new_value[key] = subvalue + if subunits: + new_value[f"{key}_units"] = subunits + elif attrname == "annotations": + # In general we don't send None to MATLAB + # but we make an exception for annotations. + # However, we have to save then retrieve some + # special value as actual `None` is ignored by default. + new_value[key] = PY_NONE + value = new_value + return value, units def create_struct_from_obj(self, ob): struct = {"neo_id": id(ob)} # relationship - for childname in getattr(ob, '_child_containers', []): - supported_containers = [subob.__name__.lower() + 's' for subob in - self.supported_objects] + for childname in getattr(ob, "_child_containers", []): + supported_containers = [_container_name(subob.__name__) for subob in self.supported_objects] if childname in supported_containers: struct[childname] = [] # attributes all_attrs = list(ob._all_attrs) - if hasattr(ob, 'annotations'): - all_attrs.append(('annotations', type(ob.annotations))) + if hasattr(ob, "annotations"): + all_attrs.append(("annotations", type(ob.annotations))) - for i, attr in enumerate(all_attrs): + for attr in all_attrs: attrname, attrtype = attr[0], attr[1] + attr_value, attr_units = self._get_matlab_value(ob, attrname) + if attr_value is not None: + struct[attrname] = attr_value + if attr_units: + struct[attrname + "_units"] = attr_units + return struct - if (hasattr(ob, '_quantity_attr') and - ob._quantity_attr == attrname): - struct[attrname] = ob.magnitude - struct[attrname + '_units'] = ob.dimensionality.string - continue - - if not (attrname in ob.annotations or hasattr(ob, attrname)): - continue - if getattr(ob, attrname) is None: - continue - - if attrtype == pq.Quantity: - # ndim = attr[2] - struct[attrname] = getattr(ob, attrname).magnitude - struct[attrname + '_units'] = getattr( - ob, attrname).dimensionality.string - elif attrtype == datetime: - struct[attrname] = str(getattr(ob, attrname)) - else: - struct[attrname] = getattr(ob, attrname) - + def create_struct_from_view(self, ob): + # for "view" objects (ChannelView and RegionOfInterest), we just store + # a reference to the object (AnalogSignal, ImageSequence) that the view + # points to + struct = self.create_struct_from_obj(ob) + obj_name = ob._necessary_attrs[0][0] # this is fragile, better to add an attribute _view_attr + viewed_obj = getattr(ob, obj_name) + struct[obj_name] = id(viewed_obj) + struct["viewed_classname"] = viewed_obj.__class__.__name__ return struct def create_ob_from_struct(self, struct, classname): cl = class_by_name[classname] # ~ if is_quantity: - if hasattr(cl, '_quantity_attr'): + if hasattr(cl, "_quantity_attr"): quantity_attr = cl._quantity_attr arr = getattr(struct, quantity_attr) # ~ data_complement = dict(units=str(struct.units)) - data_complement = dict(units=str( - getattr(struct, quantity_attr + '_units'))) + data_complement = dict(units=str(getattr(struct, quantity_attr + "_units"))) if "sampling_rate" in (at[0] for at in cl._necessary_attrs): # put fake value for now, put correct value later data_complement["sampling_rate"] = 0 * pq.kHz @@ -341,55 +415,65 @@ def create_ob_from_struct(self, struct, classname): ob = cl(times, arr, **data_complement) else: ob = cl(arr, **data_complement) + elif cl.is_view: + kwargs = {} + for i, attr in enumerate(cl._necessary_attrs): + value = getattr(struct, attr[0]) + if i == 0: + # this is a bit hacky, should really add an attribute _view_attr to ChannelView and RegionOfInterest + assert isinstance(value, int) # object id + kwargs[attr[0]] = _Ref(identifier=value, target_class_name=struct.viewed_classname) + else: + if attr[1] == np.ndarray and isinstance(value, int): + value = np.array([value]) + kwargs[attr[0]] = value + ob = cl(**kwargs) else: ob = cl() for attrname in struct._fieldnames: # check children - if attrname in getattr(ob, '_child_containers', []): + if attrname in getattr(ob, "_child_containers", []): child_struct = getattr(struct, attrname) - child_class_name = classname_lower_to_upper[attrname[:-1]] try: # try must only surround len() or other errors are captured child_len = len(child_struct) except TypeError: # strange scipy.io behavior: if len is 1 there is no len() + child_struct = [child_struct] + child_len = 1 + + for c in range(child_len): + child_class_name = get_classname_from_container_name(attrname, child_struct[c]) if classname == "Group": - child = _Ref(child_struct, child_class_name) + if child_class_name == ("ChannelView") or "RegionOfInterest" in child_class_name: + child = self.create_ob_from_struct(child_struct[c], child_class_name) + else: + child = _Ref(child_struct[c], child_class_name) else: - child = self.create_ob_from_struct( - child_struct, - child_class_name) + child = self.create_ob_from_struct(child_struct[c], child_class_name) getattr(ob, attrname.lower()).append(child) - else: - for c in range(child_len): - if classname == "Group": - child = _Ref(child_struct[c], child_class_name) - else: - child = self.create_ob_from_struct( - child_struct[c], - child_class_name) - getattr(ob, attrname.lower()).append(child) continue # attributes - if attrname.endswith('_units') or attrname == 'units': + if attrname.endswith("_units") or attrname == "units": # linked with another field continue - if hasattr(cl, '_quantity_attr') and cl._quantity_attr == attrname: + if hasattr(cl, "_quantity_attr") and cl._quantity_attr == attrname: continue - item = getattr(struct, attrname) + if cl.is_view and attrname in ("obj", "index", "image_sequence", "x", "y", "radius", "width", "height", "vertices"): + continue - attributes = cl._necessary_attrs + cl._recommended_attrs \ - + (('annotations', dict),) + item = getattr(struct, attrname) + attributes = cl._necessary_attrs + cl._recommended_attrs + (("annotations", dict),) dict_attributes = dict([(a[0], a[1:]) for a in attributes]) if attrname in dict_attributes: attrtype = dict_attributes[attrname][0] if attrtype == datetime: - m = r'(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+).(\d+)' + m = r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+).(\d+)" r = re.findall(m, str(item)) if len(r) == 1: item = datetime(*[int(e) for e in r[0]]) @@ -397,17 +481,26 @@ def create_ob_from_struct(self, struct, classname): item = None elif attrtype == np.ndarray: dt = dict_attributes[attrname][2] - item = item.astype(dt) + try: + item = item.astype(dt) + except AttributeError: + # it seems arrays of length 1 are stored as scalars + item = np.array([item], dtype=dt) elif attrtype == pq.Quantity: ndim = dict_attributes[attrname][1] - units = str(getattr(struct, attrname + '_units')) + units = str(getattr(struct, attrname + "_units")) if ndim == 0: item = pq.Quantity(item, units) else: item = pq.Quantity(item, units) elif attrtype == dict: - # FIXME: works but doesn't convert nested struct to dict - item = {fn: getattr(item, fn) for fn in item._fieldnames} + new_item = {} + for fn in item._fieldnames: + value = getattr(item, fn) + if value == PY_NONE: + value = None + new_item[fn] = value + item = new_item else: item = attrtype(item) @@ -427,17 +520,34 @@ def _resolve_references(self, bl): for grp in bl.groups: for container_name in grp._child_containers: container = getattr(grp, container_name) - for i, ref in enumerate(container): - assert isinstance(ref, _Ref) - container[i] = obj_lookup[ref.identifier] + for i, item in enumerate(container): + if isinstance(item, _Ref): + assert isinstance(item.identifier, (int, np.integer)) + # A reference to an object that already exists + container[i] = obj_lookup[item.identifier] + else: + # ChannelView and RegionOfInterest + assert item.is_view + assert isinstance(item.obj, _Ref) + item.obj = obj_lookup[item.obj.identifier] class _Ref: - def __init__(self, identifier, target_class_name): self.identifier = identifier - self.target_cls = class_by_name[target_class_name] + if target_class_name: + self.target_cls = class_by_name[target_class_name] + else: + self.target_cls = None @property def proxy_for(self): return self.target_cls + + @property + def data_children_recur(self): + return [] + + @property + def container_children_recur(self): + return [] diff --git a/neo/test/coretest/test_block.py b/neo/test/coretest/test_block.py index c82f3e99a..16b21d90c 100644 --- a/neo/test/coretest/test_block.py +++ b/neo/test/coretest/test_block.py @@ -70,12 +70,17 @@ def test__filter_none(self): targ.extend(seg.spiketrains) targ.extend(seg.imagesequences) chv_names = set([]) + roi_names = set([]) for grp in block.groups: for grp1 in grp.walk(): for chv in grp1.channelviews: if chv.name not in chv_names: targ.append(chv) chv_names.add(chv.name) + for roi in grp1.regionsofinterest: + if roi.name not in roi_names: + targ.append(roi) + roi_names.add(roi.name) res1 = block.filter() res2 = block.filter({}) diff --git a/neo/test/generate_datasets.py b/neo/test/generate_datasets.py index e44ed4c15..6bdc78728 100644 --- a/neo/test/generate_datasets.py +++ b/neo/test/generate_datasets.py @@ -1,6 +1,6 @@ -''' +""" Generate datasets for testing -''' +""" from datetime import datetime import random @@ -9,10 +9,19 @@ from numpy.random import rand import quantities as pq -from neo.core import (AnalogSignal, Block, Epoch, Event, IrregularlySampledSignal, Group, - Segment, SpikeTrain, ImageSequence, ChannelView, - CircularRegionOfInterest, RectangularRegionOfInterest, - PolygonRegionOfInterest) +from neo.core import ( + AnalogSignal, + Block, + Epoch, + Event, + IrregularlySampledSignal, + Group, + Segment, + SpikeTrain, + ImageSequence, + ChannelView, + CircularRegionOfInterest +) TEST_ANNOTATIONS = [1, 0, 1.5, "this is a test", datetime.fromtimestamp(424242424), None] @@ -28,12 +37,7 @@ def random_datetime(min_year=1990, max_year=datetime.now().year): def random_annotations(n=1): - annotation_generators = ( - random.random, - random_datetime, - random_string, - lambda: None - ) + annotation_generators = (random.random, random_datetime, random_string, lambda: None) annotations = {} for i in range(n): var_name = random_string(6) @@ -55,8 +59,8 @@ def random_signal(name=None, **annotations): name=name or random_string(), file_origin=random_string(), description=random_string(100), - array_annotations=None, # todo - **annotations + array_annotations=None, # todo + **annotations, ) return obj @@ -75,8 +79,8 @@ def random_irreg_signal(name=None, **annotations): name=name or random_string(), file_origin=random_string(), description=random_string(100), - array_annotations=None, # todo - **annotations + array_annotations=None, # todo + **annotations, ) return obj @@ -92,8 +96,8 @@ def random_event(name=None, **annotations): labels=labels, units="ms", name=name or random_string(), - array_annotations=None, # todo - **annotations + array_annotations=None, # todo + **annotations, ) return obj @@ -109,8 +113,8 @@ def random_epoch(): labels=labels, units="ms", name=random_string(), - array_annotations=None, # todo - **random_annotations(3) + array_annotations=None, # todo + **random_annotations(3), ) return obj @@ -126,8 +130,28 @@ def random_spiketrain(name=None, **annotations): t_stop=times[-1] + random.uniform(0, 5), units="ms", name=name or random_string(), - array_annotations=None, # todo - **annotations + array_annotations=None, # todo + **annotations, + ) + return obj + + +def random_image_sequence(name=None, **annotations): + pixels_i = random.randint(2, 7) + pixels_j = random.randint(2, 7) + seq_length = random.randint(20, 200) + if len(annotations) == 0: + annotations = random_annotations(5) + obj = ImageSequence( + np.random.uniform(size=(seq_length, pixels_i, pixels_j)), + t_start=random.uniform(0, 10) * pq.ms, + sampling_rate=random.uniform(0.1, 10) * pq.kHz, + spatial_scale=random.uniform(0.1, 10) * pq.um, + name=name or random_string(), + file_origin=random_string(), + description=random_string(100), + array_annotations=None, # todo + **annotations, ) return obj @@ -139,7 +163,7 @@ def random_segment(): file_origin=random_string(20), file_datetime=random_datetime(), rec_datetime=random_datetime(), - **random_annotations(4) + **random_annotations(4), ) n_sigs = random.randint(0, 5) for i in range(n_sigs): @@ -156,7 +180,10 @@ def random_segment(): n_spiketrains = random.randint(0, 20) for i in range(n_spiketrains): seg.spiketrains.append(random_spiketrain()) - # todo: add some ImageSequence and ROI objects + n_imgs = random.randint(0, 5) + for i in range(n_imgs): + seg.imagesequences.append(random_image_sequence()) + # todo: add some and ROI objects return seg @@ -169,9 +196,7 @@ def random_group(candidates): else: k = random.randint(1, len(candidates)) objects = random.sample(candidates, k) - obj = Group(objects=objects, - name=random_string(), - **random_annotations(5)) + obj = Group(objects=objects, name=random_string(), **random_annotations(5)) return obj @@ -180,17 +205,20 @@ def random_channelview(signal): if n_channels > 2: view_size = np.random.randint(1, n_channels - 1) index = np.random.choice(np.arange(signal.shape[1]), view_size, replace=False) - obj = ChannelView( - signal, - index, - name=random_string(), - **random_annotations(3) - ) + obj = ChannelView(signal, index, name=random_string(), **random_annotations(3)) return obj else: return None +def random_roi(imgseq): + x = np.random.randint(imgseq.shape[1]) + y = np.random.randint(imgseq.shape[2]) + radius = np.random.uniform() * imgseq.shape[1] + obj = CircularRegionOfInterest(imgseq, x, y, radius, name=random_string(), **random_annotations(3)) + return obj + + def random_block(): block = Block( name=random_string(10), @@ -198,7 +226,7 @@ def random_block(): file_origin=random_string(20), file_datetime=random_datetime(), rec_datetime=random_datetime(), - **random_annotations(6) + **random_annotations(6), ) n_seg = random.randint(0, 5) for i in range(n_seg): @@ -213,6 +241,12 @@ def random_block(): chv = random_channelview(child) if chv: views.append(chv) + elif isinstance(child, ImageSequence): + PROB_IMGSEQ_HAS_ROI = 0.5 + if np.random.random_sample() < PROB_IMGSEQ_HAS_ROI: + roi = random_roi(child) + if roi: + views.append(roi) children.extend(views) n_groups = random.randint(0, 5) for i in range(n_groups): @@ -224,116 +258,109 @@ def random_block(): def simple_block(): - block = Block( - name="test block", - species="rat", - brain_region="cortex" - ) + block = Block(name="test block", species="rat", brain_region="cortex") block.segments = [ - Segment(name="test segment #1", - cell_type="spiny stellate"), - Segment(name="test segment #2", - cell_type="pyramidal", - thing="amajig") + Segment(name="test segment #1", cell_type="spiny stellate"), + Segment(name="test segment #2", cell_type="pyramidal", thing="amajig"), ] - block.segments[0].analogsignals.extend(( - random_signal(name="signal #1 in segment #1", thing="wotsit"), - random_signal(name="signal #2 in segment #1", thing="frooble"), - )) - block.segments[1].analogsignals.extend(( - random_signal(name="signal #1 in segment #2", thing="amajig"), - )) - block.segments[1].irregularlysampledsignals.extend(( - random_irreg_signal(name="signal #1 in segment #2", thing="amajig"), - )) - block.segments[0].events.extend(( - random_event(name="event array #1 in segment #1", thing="frooble"), - )) - block.segments[1].events.extend(( - random_event(name="event array #1 in segment #2", thing="wotsit"), - )) - block.segments[0].spiketrains.extend(( - random_spiketrain(name="spiketrain #1 in segment #1", thing="frooble"), - random_spiketrain(name="spiketrain #2 in segment #1", thing="wotsit") - )) + block.segments[0].analogsignals.extend( + ( + random_signal(name="signal #1 in segment #1", thing="wotsit"), + random_signal(name="signal #2 in segment #1", thing="frooble"), + ) + ) + block.segments[1].analogsignals.extend((random_signal(name="signal #1 in segment #2", thing="amajig"),)) + block.segments[1].irregularlysampledsignals.extend( + (random_irreg_signal(name="signal #1 in segment #2", thing="amajig"),) + ) + block.segments[0].events.extend((random_event(name="event array #1 in segment #1", thing="frooble"),)) + block.segments[1].events.extend((random_event(name="event array #1 in segment #2", thing="wotsit"),)) + block.segments[0].spiketrains.extend( + ( + random_spiketrain(name="spiketrain #1 in segment #1", thing="frooble"), + random_spiketrain(name="spiketrain #2 in segment #1", thing="wotsit"), + ) + ) return block -def generate_one_simple_block(block_name='block_0', nb_segment=3, supported_objects=[], **kws): +def generate_one_simple_block(block_name="block_0", nb_segment=3, supported_objects=[], **kws): if supported_objects and Block not in supported_objects: - raise ValueError('Block must be in supported_objects') + raise ValueError("Block must be in supported_objects") bl = Block() # name = block_name) objects = supported_objects if Segment in objects: for s in range(nb_segment): - seg = generate_one_simple_segment(seg_name="seg" + str(s), supported_objects=objects, - **kws) + seg = generate_one_simple_segment(seg_name="seg" + str(s), supported_objects=objects, **kws) bl.segments.append(seg) bl.check_relationships() return bl -def generate_one_simple_segment(seg_name='segment 0', supported_objects=[], nb_analogsignal=4, - t_start=0. * pq.s, sampling_rate=10 * pq.kHz, duration=6. * pq.s, - - nb_spiketrain=6, spikerate_range=[.5 * pq.Hz, 12 * pq.Hz], - - event_types={'stim': ['a', 'b', 'c', 'd'], - 'enter_zone': ['one', 'two'], - 'color': ['black', 'yellow', 'green'], }, - event_size_range=[5, 20], - - epoch_types={'animal state': ['Sleep', 'Freeze', 'Escape'], - 'light': ['dark', 'lighted']}, - epoch_duration_range=[.5, 3.], - # this should be multiplied by pq.s, no? - - array_annotations={'valid': np.array([True, False]), - 'number': np.array(range(5))} - - ): +def generate_one_simple_segment( + seg_name="segment 0", + supported_objects=[], + nb_analogsignal=4, + t_start=0.0 * pq.s, + sampling_rate=10 * pq.kHz, + duration=6.0 * pq.s, + nb_spiketrain=6, + spikerate_range=[0.5 * pq.Hz, 12 * pq.Hz], + event_types={ + "stim": ["a", "b", "c", "d"], + "enter_zone": ["one", "two"], + "color": ["black", "yellow", "green"], + }, + event_size_range=[5, 20], + epoch_types={"animal state": ["Sleep", "Freeze", "Escape"], "light": ["dark", "lighted"]}, + epoch_duration_range=[0.5, 3.0], + # this should be multiplied by pq.s, no? + array_annotations={"valid": np.array([True, False]), "number": np.array(range(5))}, +): if supported_objects and Segment not in supported_objects: - raise ValueError('Segment must be in supported_objects') + raise ValueError("Segment must be in supported_objects") seg = Segment(name=seg_name) if AnalogSignal in supported_objects: for a in range(nb_analogsignal): - anasig = AnalogSignal(rand(int((sampling_rate * duration).simplified)), - sampling_rate=sampling_rate, - t_start=t_start, units=pq.mV, channel_index=a, - name='sig %d for segment %s' % (a, seg.name)) + anasig = AnalogSignal( + rand(int((sampling_rate * duration).simplified)), + sampling_rate=sampling_rate, + t_start=t_start, + units=pq.mV, + channel_index=a, + name="sig %d for segment %s" % (a, seg.name), + ) seg.analogsignals.append(anasig) if SpikeTrain in supported_objects: for s in range(nb_spiketrain): - spikerate = rand() * np.diff(spikerate_range) - spikerate += spikerate_range[0].magnitude - # spikedata = rand(int((spikerate*duration).simplified))*duration - # sptr = SpikeTrain(spikedata, - # t_start=t_start, t_stop=t_start+duration) - # #, name = 'spiketrain %d'%s) + spikerate = rand() * np.diff(spikerate_range)[0] + spikerate += spikerate_range[0].item() spikes = rand(int((spikerate * duration).simplified)) spikes.sort() # spikes are supposed to be an ascending sequence sptr = SpikeTrain(spikes * duration, t_start=t_start, t_stop=t_start + duration) - sptr.annotations['channel_index'] = s + sptr.annotations["channel_index"] = s # Randomly generate array_annotations from given options - arr_ann = {key: value[(rand(len(spikes)) * len(value)).astype('i')] for (key, value) in - array_annotations.items()} + arr_ann = { + key: value[(rand(len(spikes)) * len(value)).astype("i")] for (key, value) in array_annotations.items() + } sptr.array_annotate(**arr_ann) seg.spiketrains.append(sptr) if Event in supported_objects: for name, labels in event_types.items(): - evt_size = rand() * np.diff(event_size_range) + evt_size = rand() * np.diff(event_size_range)[0] evt_size += event_size_range[0] evt_size = int(evt_size) - labels = np.array(labels, dtype='U') - labels = labels[(rand(evt_size) * len(labels)).astype('i')] + labels = np.array(labels, dtype="U") + labels = labels[(rand(evt_size) * len(labels)).astype("i")] evt = Event(times=rand(evt_size) * duration, labels=labels) # Randomly generate array_annotations from given options - arr_ann = {key: value[(rand(evt_size) * len(value)).astype('i')] for (key, value) in - array_annotations.items()} + arr_ann = { + key: value[(rand(evt_size) * len(value)).astype("i")] for (key, value) in array_annotations.items() + } evt.array_annotate(**arr_ann) seg.events.append(evt) @@ -348,17 +375,20 @@ def generate_one_simple_segment(seg_name='segment 0', supported_objects=[], nb_a dur += epoch_duration_range[0] durations.append(dur) t = t + dur - labels = np.array(labels, dtype='U') - labels = labels[(rand(len(times)) * len(labels)).astype('i')] + labels = np.array(labels, dtype="U") + labels = labels[(rand(len(times)) * len(labels)).astype("i")] assert len(times) == len(durations) assert len(times) == len(labels) - epc = Epoch(times=pq.Quantity(times, units=pq.s), - durations=pq.Quantity(durations, units=pq.s), - labels=labels,) - assert epc.times.dtype == 'float' + epc = Epoch( + times=pq.Quantity(times, units=pq.s), + durations=pq.Quantity(durations, units=pq.s), + labels=labels, + ) + assert epc.times.dtype == "float" # Randomly generate array_annotations from given options - arr_ann = {key: value[(rand(len(times)) * len(value)).astype('i')] for (key, value) in - array_annotations.items()} + arr_ann = { + key: value[(rand(len(times)) * len(value)).astype("i")] for (key, value) in array_annotations.items() + } epc.array_annotate(**arr_ann) seg.epochs.append(epc) @@ -370,7 +400,7 @@ def generate_one_simple_segment(seg_name='segment 0', supported_objects=[], nb_a def generate_from_supported_objects(supported_objects): if not supported_objects: - raise ValueError('No objects specified') + raise ValueError("No objects specified") objects = supported_objects if Block in supported_objects: higher = generate_one_simple_block(supported_objects=objects) diff --git a/neo/test/iotest/test_neomatlabio.py b/neo/test/iotest/test_neomatlabio.py index 175755ab2..3dfb3042b 100644 --- a/neo/test/iotest/test_neomatlabio.py +++ b/neo/test/iotest/test_neomatlabio.py @@ -2,6 +2,7 @@ Tests of neo.io.neomatlabio """ +import os import unittest from numpy.testing import assert_array_equal import quantities as pq @@ -10,6 +11,8 @@ from neo.core.irregularlysampledsignal import IrregularlySampledSignal from neo import Block, Segment, SpikeTrain, ImageSequence, Group from neo.test.iotest.common_io_test import BaseTestIO +from neo.test.generate_datasets import random_block + from neo.io.neomatlabio import NeoMatlabIO try: @@ -29,7 +32,7 @@ def test_write_read_single_spike(self): block1 = Block(name="test_neomatlabio") seg = Segment('segment1') spiketrain1 = SpikeTrain([1] * pq.s, t_stop=10 * pq.s, sampling_rate=1 * pq.Hz) - spiketrain1.annotate(yep='yop') + spiketrain1.annotate(yep='yop', yip=None) sig1 = AnalogSignal([4, 5, 6] * pq.A, sampling_period=1 * pq.ms) irrsig1 = IrregularlySampledSignal([0, 1, 2] * pq.ms, [4, 5, 6] * pq.A) img_sequence_array = [[[column for column in range(2)] for _ in range(2)] @@ -49,7 +52,7 @@ def test_write_read_single_spike(self): # write block filename = self.get_local_path('matlabiotestfile.mat') io1 = self.ioclass(filename) - io1.write_block(block1) + io1.write_block(block1, debug=False) # read block io2 = self.ioclass(filename) @@ -71,13 +74,35 @@ def test_write_read_single_spike(self): # test annotations spiketrain2 = block2.segments[0].spiketrains[0] - assert 'yep' in spiketrain2.annotations assert spiketrain2.annotations['yep'] == 'yop' + assert spiketrain2.annotations['yip'] is None # test group retrieval group2 = block2.groups[0] assert_array_equal(group1.analogsignals[0], group2.analogsignals[0]) + def test_write_read_random_blocks(self): + for i in range(10): + # generate random block + block1 = random_block() + + # write block to file + filename_orig = self.get_local_path(f"matlabio_randomtest_orig_{i}.mat") + io1 = self.ioclass(filename_orig) + io1.write_block(block1, debug=False) + + # read block + io2 = self.ioclass(filename_orig) + block2 = io2.read_block() + + filename_roundtripped = self.get_local_path(f"matlabio_randomtest_roundtrip_{i}.mat") + io3 = self.ioclass(filename_roundtripped) + io3.write_block(block2) + + # the actual contents will differ since we're using Python object id as identifiers + # but at least the file size should be the same + assert os.stat(filename_orig).st_size == os.stat(filename_roundtripped).st_size + if __name__ == "__main__": unittest.main()