diff --git a/CHANGELOG.md b/CHANGELOG.md index 2331deda2..278d64838 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ ### Bug fixes - Fixed issue with custom class generation when a spec has a `name`. @rly [#1006](https://github.com/hdmf-dev/hdmf/pull/1006) +- Fixed issue where `ElementIdentifiers` data could be set to non-integer values. @rly [#1009](https://github.com/hdmf-dev/hdmf/pull/1009) + ## HDMF 3.11.0 (October 30, 2023) ### Enhancements diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 767bf4f2e..2e2b56979 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -15,7 +15,7 @@ from . import register_class, EXP_NAMESPACE from ..container import Container, Data from ..data_utils import DataIO, AbstractDataChunkIterator -from ..utils import docval, getargs, ExtenderMeta, popargs, pystr, AllowPositional +from ..utils import docval, getargs, ExtenderMeta, popargs, pystr, AllowPositional, check_type from ..term_set import TermSetWrapper @@ -211,8 +211,8 @@ class ElementIdentifiers(Data): """ @docval({'name': 'name', 'type': str, 'doc': 'the name of this ElementIdentifiers'}, - {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a 1D dataset containing identifiers', - 'default': list()}, + {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a 1D dataset containing integer identifiers', + 'default': list(), 'shape': (None,)}, allow_positional=AllowPositional.WARNING) def __init__(self, **kwargs): super().__init__(**kwargs) @@ -237,6 +237,20 @@ def __eq__(self, other): # Find all matching locations return np.in1d(self.data, search_ids).nonzero()[0] + def _validate_new_data(self, data): + # NOTE this may not cover all the many AbstractDataChunkIterator edge cases + if (isinstance(data, AbstractDataChunkIterator) or + (hasattr(data, "data") and isinstance(data.data, AbstractDataChunkIterator))): + if not np.issubdtype(data.dtype, np.integer): + raise ValueError("ElementIdentifiers must contain integers") + elif hasattr(data, "__len__") and len(data): + self._validate_new_data_element(data[0]) + + def _validate_new_data_element(self, arg): + if not check_type(arg, int): + raise ValueError("ElementIdentifiers must contain integers") + super()._validate_new_data_element(arg) + @register_class('DynamicTable') class DynamicTable(Container): diff --git a/src/hdmf/container.py b/src/hdmf/container.py index ccc7a3df7..8420805cb 100644 --- a/src/hdmf/container.py +++ b/src/hdmf/container.py @@ -763,6 +763,8 @@ class Data(AbstractContainer): def __init__(self, **kwargs): data = popargs('data', kwargs) super().__init__(**kwargs) + + self._validate_new_data(data) self.__data = data @property @@ -822,6 +824,7 @@ def get(self, args): return self.data[args] def append(self, arg): + self._validate_new_data_element(arg) self.__data = append_data(self.__data, arg) def extend(self, arg): @@ -831,8 +834,23 @@ def extend(self, arg): :param arg: The iterable to add to the end of this VectorData """ + self._validate_new_data(arg) self.__data = extend_data(self.__data, arg) + def _validate_new_data(self, data): + """Function to validate a new array that will be set or added to data. Raises an error if the data is invalid. + + Subclasses should override this function to perform class-specific validation. + """ + pass + + def _validate_new_data_element(self, arg): + """Function to validate a new value that will be added to the data. Raises an error if the data is invalid. + + Subclasses should override this function to perform class-specific validation. + """ + pass + class DataRegion(Data): diff --git a/src/hdmf/data_utils.py b/src/hdmf/data_utils.py index 941c3f8c7..f1eee655f 100644 --- a/src/hdmf/data_utils.py +++ b/src/hdmf/data_utils.py @@ -1061,6 +1061,8 @@ def __len__(self): return self.__shape[0] if not self.valid: raise InvalidDataIOError("Cannot get length of data. Data is not valid.") + if isinstance(self.data, AbstractDataChunkIterator): + return self.data.maxshape[0] return len(self.data) def __bool__(self): diff --git a/src/hdmf/utils.py b/src/hdmf/utils.py index e2686912a..fcf2fe6a5 100644 --- a/src/hdmf/utils.py +++ b/src/hdmf/utils.py @@ -67,7 +67,7 @@ def get_docval_macro(key=None): return tuple(__macros[key]) -def __type_okay(value, argtype, allow_none=False): +def check_type(value, argtype, allow_none=False): """Check a value against a type The difference between this function and :py:func:`isinstance` is that @@ -87,7 +87,7 @@ def __type_okay(value, argtype, allow_none=False): return allow_none if isinstance(argtype, str): if argtype in __macros: - return __type_okay(value, __macros[argtype], allow_none=allow_none) + return check_type(value, __macros[argtype], allow_none=allow_none) elif argtype == 'uint': return __is_uint(value) elif argtype == 'int': @@ -106,7 +106,7 @@ def __type_okay(value, argtype, allow_none=False): return __is_bool(value) return isinstance(value, argtype) elif isinstance(argtype, tuple) or isinstance(argtype, list): - return any(__type_okay(value, i) for i in argtype) + return any(check_type(value, i) for i in argtype) else: # argtype is None return True @@ -279,7 +279,7 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True, # we can use this to unwrap the dataset/attribute to use the "item" for docval to validate the type. argval = argval.value if enforce_type: - if not __type_okay(argval, arg['type']): + if not check_type(argval, arg['type']): if argval is None: fmt_val = (argname, __format_type(arg['type'])) type_errors.append("None is not allowed for '%s' (expected '%s', not None)" % fmt_val) @@ -336,7 +336,7 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True, # we can use this to unwrap the dataset/attribute to use the "item" for docval to validate the type. argval = argval.value if enforce_type: - if not __type_okay(argval, arg['type'], arg['default'] is None or arg.get('allow_none', False)): + if not check_type(argval, arg['type'], arg['default'] is None or arg.get('allow_none', False)): if argval is None and arg['default'] is None: fmt_val = (argname, __format_type(arg['type'])) type_errors.append("None is not allowed for '%s' (expected '%s', not None)" % fmt_val) @@ -613,7 +613,7 @@ def dec(func): msg = 'docval for {}: enum checking cannot be used with arg type {}'.format(a['name'], a['type']) raise Exception(msg) # check that enum allowed values are allowed by arg type - if any([not __type_okay(x, a['type']) for x in a['enum']]): + if any([not check_type(x, a['type']) for x in a['enum']]): msg = ('docval for {}: enum values are of types not allowed by arg type (got {}, ' 'expected {})'.format(a['name'], [type(x) for x in a['enum']], a['type'])) raise Exception(msg) diff --git a/tests/unit/common/test_table.py b/tests/unit/common/test_table.py index 3f358d22a..7246a8ba8 100644 --- a/tests/unit/common/test_table.py +++ b/tests/unit/common/test_table.py @@ -1392,6 +1392,38 @@ def test_identifier_search_with_bad_ids(self): _ = (self.e == 'test') +class TestBadElementIdentifiers(TestCase): + + def test_bad_dtype(self): + with self.assertRaisesWith(ValueError, "ElementIdentifiers must contain integers"): + ElementIdentifiers(name='ids', data=["1", "2"]) + + with self.assertRaisesWith(ValueError, "ElementIdentifiers must contain integers"): + ElementIdentifiers(name='ids', data=np.array(["1", "2"])) + + with self.assertRaisesWith(ValueError, "ElementIdentifiers must contain integers"): + ElementIdentifiers(name='ids', data=[1.0, 2.0]) + + def test_dci_int_ok(self): + a = np.arange(30) + dci = DataChunkIterator(data=a, buffer_size=1) + e = ElementIdentifiers(name='ids', data=dci) # test that no error is raised + self.assertIs(e.data, dci) + + def test_dci_float_bad(self): + a = np.arange(30.0) + dci = DataChunkIterator(data=a, buffer_size=1) + with self.assertRaisesWith(ValueError, "ElementIdentifiers must contain integers"): + ElementIdentifiers(name='ids', data=dci) + + def test_dataio_dci_ok(self): + a = np.arange(30) + dci = DataChunkIterator(data=a, buffer_size=1) + dio = H5DataIO(dci) + e = ElementIdentifiers(name='ids', data=dio) # test that no error is raised + self.assertIs(e.data, dio) + + class SubTable(DynamicTable): __columns__ = ( diff --git a/tests/unit/utils_test/test_core_DataIO.py b/tests/unit/utils_test/test_core_DataIO.py index 00941cb0e..778dd2617 100644 --- a/tests/unit/utils_test/test_core_DataIO.py +++ b/tests/unit/utils_test/test_core_DataIO.py @@ -8,12 +8,6 @@ class DataIOTests(TestCase): - def setUp(self): - pass - - def tearDown(self): - pass - def test_copy(self): obj = DataIO(data=[1., 2., 3.]) obj_copy = copy(obj)