Validate that int data is used for ElementIdentifiers on init, append, extend (#1009)

* Validate elementids is int on set and modify

* Update changelog
rly authored Dec 8, 2023
1 parent af13e72 commit 97260bc
Showing 7 changed files with 77 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@
### Bug fixes
- Fixed issue with custom class generation when a spec has a `name`. @rly [#1006](https://github.com/hdmf-dev/hdmf/pull/1006)

- Fixed issue where `ElementIdentifiers` data could be set to non-integer values. @rly [#1009](https://github.com/hdmf-dev/hdmf/pull/1009)

## HDMF 3.11.0 (October 30, 2023)

### Enhancements
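For illustration only (not part of the commit), a minimal sketch of the behavior this fix enforces, assuming the usual `hdmf.common` import path; the error message comes from the new validation code in table.py below:

from hdmf.common import ElementIdentifiers

# Integer identifiers are accepted as before
ids = ElementIdentifiers(name='ids', data=[0, 1, 2])

# Non-integer identifiers now fail fast at construction time
try:
    ElementIdentifiers(name='ids', data=[1.0, 2.0])
except ValueError as err:
    print(err)  # ElementIdentifiers must contain integers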
20 changes: 17 additions & 3 deletions src/hdmf/common/table.py
@@ -15,7 +15,7 @@
from . import register_class, EXP_NAMESPACE
from ..container import Container, Data
from ..data_utils import DataIO, AbstractDataChunkIterator
-from ..utils import docval, getargs, ExtenderMeta, popargs, pystr, AllowPositional
+from ..utils import docval, getargs, ExtenderMeta, popargs, pystr, AllowPositional, check_type
from ..term_set import TermSetWrapper


@@ -211,8 +211,8 @@ class ElementIdentifiers(Data):
"""

@docval({'name': 'name', 'type': str, 'doc': 'the name of this ElementIdentifiers'},
-            {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a 1D dataset containing identifiers',
-             'default': list()},
+            {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a 1D dataset containing integer identifiers',
+             'default': list(), 'shape': (None,)},
allow_positional=AllowPositional.WARNING)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -237,6 +237,20 @@ def __eq__(self, other):
# Find all matching locations
return np.in1d(self.data, search_ids).nonzero()[0]

def _validate_new_data(self, data):
# NOTE this may not cover all the many AbstractDataChunkIterator edge cases
if (isinstance(data, AbstractDataChunkIterator) or
(hasattr(data, "data") and isinstance(data.data, AbstractDataChunkIterator))):
if not np.issubdtype(data.dtype, np.integer):
raise ValueError("ElementIdentifiers must contain integers")
elif hasattr(data, "__len__") and len(data):
self._validate_new_data_element(data[0])

def _validate_new_data_element(self, arg):
if not check_type(arg, int):
raise ValueError("ElementIdentifiers must contain integers")
super()._validate_new_data_element(arg)


@register_class('DynamicTable')
class DynamicTable(Container):
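Because `Data.append` and `Data.extend` (container.py below) now call these hooks, the same integer check also applies when identifiers are modified after construction. A minimal sketch of that behavior, again for illustration only:

from hdmf.common import ElementIdentifiers

ids = ElementIdentifiers(name='ids', data=[0, 1, 2])
ids.append(3)       # ok: goes through _validate_new_data_element
ids.extend([4, 5])  # ok: goes through _validate_new_data
ids.append("6")     # raises ValueError: ElementIdentifiers must contain integers

Note that the array-level check only inspects the dtype of a chunk iterator or the first element of a sized container, so it is a fast sanity check rather than a full scan.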
18 changes: 18 additions & 0 deletions src/hdmf/container.py
@@ -763,6 +763,8 @@ class Data(AbstractContainer):
def __init__(self, **kwargs):
data = popargs('data', kwargs)
super().__init__(**kwargs)

self._validate_new_data(data)
self.__data = data

@property
@@ -822,6 +824,7 @@ def get(self, args):
return self.data[args]

def append(self, arg):
self._validate_new_data_element(arg)
self.__data = append_data(self.__data, arg)

def extend(self, arg):
@@ -831,8 +834,23 @@ def extend(self, arg):
:param arg: The iterable to add to the end of this VectorData
"""
self._validate_new_data(arg)
self.__data = extend_data(self.__data, arg)

def _validate_new_data(self, data):
"""Function to validate a new array that will be set or added to data. Raises an error if the data is invalid.
Subclasses should override this function to perform class-specific validation.
"""
pass

def _validate_new_data_element(self, arg):
"""Function to validate a new value that will be added to the data. Raises an error if the data is invalid.
Subclasses should override this function to perform class-specific validation.
"""
pass


class DataRegion(Data):

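The two new no-op hooks give `Data` subclasses one place to plug in validation without re-implementing `__init__`, `append`, or `extend`. A hypothetical subclass (name and rule invented for illustration; not part of hdmf) might look like:

from hdmf.container import Data


class NonNegativeIds(Data):
    """Hypothetical Data subclass used only to illustrate the new hooks."""

    def _validate_new_data(self, data):
        # called from __init__() and extend() with the whole array
        if hasattr(data, "__len__") and len(data):
            self._validate_new_data_element(data[0])

    def _validate_new_data_element(self, arg):
        # called from append() with a single new value
        if arg < 0:
            raise ValueError("NonNegativeIds must contain non-negative values")
        super()._validate_new_data_element(arg)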
2 changes: 2 additions & 0 deletions src/hdmf/data_utils.py
@@ -1061,6 +1061,8 @@ def __len__(self):
return self.__shape[0]
if not self.valid:
raise InvalidDataIOError("Cannot get length of data. Data is not valid.")
if isinstance(self.data, AbstractDataChunkIterator):
return self.data.maxshape[0]
return len(self.data)

def __bool__(self):
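The added branch lets `len()` work on a `DataIO` that wraps a chunk iterator by reporting the iterator's `maxshape[0]` instead of calling `len()` on an object that may not support it. A minimal sketch, assuming `DataIO` accepts an `AbstractDataChunkIterator` through the `array_data` docval macro:

import numpy as np
from hdmf.data_utils import DataChunkIterator, DataIO

dci = DataChunkIterator(data=np.arange(30), buffer_size=5)
wrapped = DataIO(data=dci)
print(len(wrapped))  # 30, taken from dci.maxshape[0] rather than len(dci)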
12 changes: 6 additions & 6 deletions src/hdmf/utils.py
@@ -67,7 +67,7 @@ def get_docval_macro(key=None):
return tuple(__macros[key])


-def __type_okay(value, argtype, allow_none=False):
+def check_type(value, argtype, allow_none=False):
"""Check a value against a type
The difference between this function and :py:func:`isinstance` is that
@@ -87,7 +87,7 @@ def __type_okay(value, argtype, allow_none=False):
return allow_none
if isinstance(argtype, str):
if argtype in __macros:
-            return __type_okay(value, __macros[argtype], allow_none=allow_none)
+            return check_type(value, __macros[argtype], allow_none=allow_none)
elif argtype == 'uint':
return __is_uint(value)
elif argtype == 'int':
@@ -106,7 +106,7 @@
return __is_bool(value)
return isinstance(value, argtype)
elif isinstance(argtype, tuple) or isinstance(argtype, list):
-        return any(__type_okay(value, i) for i in argtype)
+        return any(check_type(value, i) for i in argtype)
else: # argtype is None
return True

@@ -279,7 +279,7 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True,
# we can use this to unwrap the dataset/attribute to use the "item" for docval to validate the type.
argval = argval.value
if enforce_type:
-            if not __type_okay(argval, arg['type']):
+            if not check_type(argval, arg['type']):
if argval is None:
fmt_val = (argname, __format_type(arg['type']))
type_errors.append("None is not allowed for '%s' (expected '%s', not None)" % fmt_val)
@@ -336,7 +336,7 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True,
# we can use this to unwrap the dataset/attribute to use the "item" for docval to validate the type.
argval = argval.value
if enforce_type:
-            if not __type_okay(argval, arg['type'], arg['default'] is None or arg.get('allow_none', False)):
+            if not check_type(argval, arg['type'], arg['default'] is None or arg.get('allow_none', False)):
if argval is None and arg['default'] is None:
fmt_val = (argname, __format_type(arg['type']))
type_errors.append("None is not allowed for '%s' (expected '%s', not None)" % fmt_val)
@@ -613,7 +613,7 @@ def dec(func):
msg = 'docval for {}: enum checking cannot be used with arg type {}'.format(a['name'], a['type'])
raise Exception(msg)
# check that enum allowed values are allowed by arg type
-        if any([not __type_okay(x, a['type']) for x in a['enum']]):
+        if any([not check_type(x, a['type']) for x in a['enum']]):
msg = ('docval for {}: enum values are of types not allowed by arg type (got {}, '
'expected {})'.format(a['name'], [type(x) for x in a['enum']], a['type']))
raise Exception(msg)
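With the helper now public as `check_type`, other modules (such as table.py above) can reuse docval's type rules directly. A short sketch of its behavior, based on the branches shown in this diff:

from hdmf.utils import check_type

check_type(5, 'int')                    # True: string names resolve docval type rules
check_type(5.0, 'int')                  # False
check_type("5", int)                    # False: plain Python types also work
check_type(None, int, allow_none=True)  # True
check_type(5, (float, int))             # True: a tuple means "any of these types"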
32 changes: 32 additions & 0 deletions tests/unit/common/test_table.py
@@ -1392,6 +1392,38 @@ def test_identifier_search_with_bad_ids(self):
_ = (self.e == 'test')


class TestBadElementIdentifiers(TestCase):

def test_bad_dtype(self):
with self.assertRaisesWith(ValueError, "ElementIdentifiers must contain integers"):
ElementIdentifiers(name='ids', data=["1", "2"])

with self.assertRaisesWith(ValueError, "ElementIdentifiers must contain integers"):
ElementIdentifiers(name='ids', data=np.array(["1", "2"]))

with self.assertRaisesWith(ValueError, "ElementIdentifiers must contain integers"):
ElementIdentifiers(name='ids', data=[1.0, 2.0])

def test_dci_int_ok(self):
a = np.arange(30)
dci = DataChunkIterator(data=a, buffer_size=1)
e = ElementIdentifiers(name='ids', data=dci) # test that no error is raised
self.assertIs(e.data, dci)

def test_dci_float_bad(self):
a = np.arange(30.0)
dci = DataChunkIterator(data=a, buffer_size=1)
with self.assertRaisesWith(ValueError, "ElementIdentifiers must contain integers"):
ElementIdentifiers(name='ids', data=dci)

def test_dataio_dci_ok(self):
a = np.arange(30)
dci = DataChunkIterator(data=a, buffer_size=1)
dio = H5DataIO(dci)
e = ElementIdentifiers(name='ids', data=dio) # test that no error is raised
self.assertIs(e.data, dio)


class SubTable(DynamicTable):

__columns__ = (
6 changes: 0 additions & 6 deletions tests/unit/utils_test/test_core_DataIO.py
@@ -8,12 +8,6 @@

class DataIOTests(TestCase):

-    def setUp(self):
-        pass
-
-    def tearDown(self):
-        pass

def test_copy(self):
obj = DataIO(data=[1., 2., 3.])
obj_copy = copy(obj)
