diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f3986421..580c4fca6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ - Added `target_tables` attribute to `DynamicTable` to allow users to specify the target table of any predefined `DynamicTableRegion` columns of a `DynamicTable` subclass. @rly [#971](https://github.com/hdmf-dev/hdmf/pull/971) +### Bug fixes +- Updated custom class generation to handle specs with fixed values and required names. @rly [#800](https://github.com/hdmf-dev/hdmf/pull/800) +- Fixed custom class generation of `DynamicTable` subtypes to set attributes corresponding to column names for correct write. @rly [#800](https://github.com/hdmf-dev/hdmf/pull/800) + ## HDMF 3.10.0 (October 3, 2023) Since version 3.9.1 should have been released as 3.10.0 but failed to release on PyPI and conda-forge, this release diff --git a/src/hdmf/build/classgenerator.py b/src/hdmf/build/classgenerator.py index 3ec93e659..6a31f4cec 100644 --- a/src/hdmf/build/classgenerator.py +++ b/src/hdmf/build/classgenerator.py @@ -222,10 +222,19 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i 'doc': field_spec['doc']} if cls._ischild(dtype) and issubclass(parent_cls, Container) and not isinstance(field_spec, LinkSpec): fields_conf['child'] = True - # if getattr(field_spec, 'value', None) is not None: # TODO set the fixed value on the class? - # fields_conf['settable'] = False + fixed_value = getattr(field_spec, 'value', None) + if fixed_value is not None: + fields_conf['settable'] = False + if isinstance(field_spec, (BaseStorageSpec, LinkSpec)) and field_spec.data_type is not None: + # subgroups, datasets, and links with data types can have fixed names + fixed_name = getattr(field_spec, 'name', None) + if fixed_name is not None: + fields_conf['required_name'] = fixed_name classdict.setdefault(parent_cls._fieldsname, list()).append(fields_conf) + if fixed_value is not None: # field has fixed value - do not create arg on __init__ + return + docval_arg = dict( name=attr_name, doc=field_spec.doc, @@ -285,17 +294,27 @@ def post_process(cls, classdict, bases, docval_args, spec): # set default name in docval args if provided cls._set_default_name(docval_args, spec.default_name) + @classmethod + def _get_attrs_not_to_set_init(cls, classdict, parent_docval_args): + return parent_docval_args + @classmethod def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): # get docval arg names from superclass base = bases[0] parent_docval_args = set(arg['name'] for arg in get_docval(base.__init__)) - new_args = list() + attrs_to_set = list() + fixed_value_attrs_to_set = list() + attrs_not_to_set = cls._get_attrs_not_to_set_init(classdict, parent_docval_args) for attr_name, field_spec in not_inherited_fields.items(): # store arguments for fields that are not in the superclass and not in the superclass __init__ docval # so that they are set after calling base.__init__ - if attr_name not in parent_docval_args: - new_args.append(attr_name) + # except for fields that have fixed values -- these are set at the class level + fixed_value = getattr(field_spec, 'value', None) + if fixed_value is not None: + fixed_value_attrs_to_set.append(attr_name) + elif attr_name not in attrs_not_to_set: + attrs_to_set.append(attr_name) @docval(*docval_args, allow_positional=AllowPositional.WARNING) def __init__(self, **kwargs): @@ -305,7 +324,7 @@ def __init__(self, **kwargs): # remove arguments from kwargs that correspond to fields that are new (not inherited) # set these arguments after calling base.__init__ new_kwargs = dict() - for f in new_args: + for f in attrs_to_set: new_kwargs[f] = popargs(f, kwargs) if f in kwargs else None # NOTE: the docval of some constructors do not include all of the fields. the constructor may set @@ -319,6 +338,11 @@ def __init__(self, **kwargs): for f, arg_val in new_kwargs.items(): setattr(self, f, arg_val) + # set the fields that have fixed values using the fields dict directly + # because the setters do not allow setting the value + for f in fixed_value_attrs_to_set: + self.fields[f] = getattr(not_inherited_fields[f], 'value') + classdict['__init__'] = __init__ diff --git a/src/hdmf/common/io/table.py b/src/hdmf/common/io/table.py index 446c613e0..50395ba24 100644 --- a/src/hdmf/common/io/table.py +++ b/src/hdmf/common/io/table.py @@ -111,3 +111,12 @@ def post_process(cls, classdict, bases, docval_args, spec): columns = classdict.get('__columns__') if columns is not None: classdict['__columns__'] = tuple(columns) + + @classmethod + def _get_attrs_not_to_set_init(cls, classdict, parent_docval_args): + # exclude columns from the args that are set in __init__ + attrs_not_to_set = parent_docval_args.copy() + if "__columns__" in classdict: + column_names = [column_conf["name"] for column_conf in classdict["__columns__"]] + attrs_not_to_set.update(column_names) + return attrs_not_to_set diff --git a/src/hdmf/spec/spec.py b/src/hdmf/spec/spec.py index cdc041c7b..f383fd34a 100644 --- a/src/hdmf/spec/spec.py +++ b/src/hdmf/spec/spec.py @@ -816,6 +816,11 @@ def data_type_inc(self): ''' The data type of target specification ''' return self.get(_target_type_key) + @property + def data_type(self): + ''' The data type of target specification ''' + return self.get(_target_type_key) + def is_many(self): return self.quantity not in (1, ZERO_OR_ONE) diff --git a/tests/unit/build_tests/test_classgenerator.py b/tests/unit/build_tests/test_classgenerator.py index 3bc0bf7f9..5635b12d1 100644 --- a/tests/unit/build_tests/test_classgenerator.py +++ b/tests/unit/build_tests/test_classgenerator.py @@ -431,6 +431,156 @@ def test_multi_container_spec_one_or_more_ok(self): assert len(multi.bars) == 1 +class TestDynamicContainerFixedValue(TestCase): + + def setUp(self): + self.baz_spec = GroupSpec( + doc='A test group specification with a data type', + data_type_def='Baz', + attributes=[AttributeSpec(name='attr1', doc='a string attribute', dtype='text', value="fixed")] + ) + self.type_map = create_test_type_map([], {}) # empty typemap + self.spec_catalog = self.type_map.namespace_catalog.get_namespace(CORE_NAMESPACE).catalog + self.spec_catalog.register_spec(self.baz_spec, 'extension.yaml') + + def test_init_docval(self): + cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class + expected_args = {'name'} # 'attr1' should not be included + received_args = set() + for x in get_docval(cls.__init__): + received_args.add(x['name']) + self.assertSetEqual(expected_args, received_args) + + def test_init_fields(self): + cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class + self.assertEqual(cls.get_fields_conf(), ({'name': 'attr1', 'doc': 'a string attribute', 'settable': False},)) + + def test_init_object(self): + cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class + obj = cls(name="test") + self.assertEqual(obj.attr1, "fixed") + + def test_set_value(self): + cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class + obj = cls(name="test") + with self.assertRaises(AttributeError): + obj.attr1 = "new" + + +class TestDynamicContainerIncludingFixedName(TestCase): + + def setUp(self): + self.baz_spec1 = GroupSpec( + doc='A test group specification with a data type', + data_type_def='Baz1', + ) + self.baz_spec2 = GroupSpec( + doc='A test dataset specification with a data type', + data_type_def='Baz2', + ) + self.baz_spec3 = GroupSpec( + doc='A test group specification with a data type', + data_type_def='Baz3', + groups=[ + GroupSpec( + doc='A composition inside with a fixed name', + name="my_baz1", + data_type_inc='Baz1' + ), + ], + datasets=[ + DatasetSpec( + doc='A composition inside with a fixed name', + name="my_baz2", + data_type_inc='Baz2' + ), + ], + links=[ + LinkSpec( + doc='A composition inside with a fixed name', + name="my_baz1_link", + target_type='Baz1' + ), + ], + ) + self.type_map = create_test_type_map([], {}) # empty typemap + self.spec_catalog = self.type_map.namespace_catalog.get_namespace(CORE_NAMESPACE).catalog + self.spec_catalog.register_spec(self.baz_spec1, 'extension.yaml') + self.spec_catalog.register_spec(self.baz_spec2, 'extension.yaml') + self.spec_catalog.register_spec(self.baz_spec3, 'extension.yaml') + + def test_gen_parent_class(self): + baz1_cls = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) # generate the class + baz2_cls = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE) + baz3_cls = self.type_map.get_dt_container_cls('Baz3', CORE_NAMESPACE) + self.assertEqual(get_docval(baz3_cls.__init__), ( + {'name': 'name', 'type': str, 'doc': 'the name of this container'}, + {'name': 'my_baz1', 'doc': 'A composition inside with a fixed name', 'type': baz1_cls}, + {'name': 'my_baz2', 'doc': 'A composition inside with a fixed name', 'type': baz2_cls}, + {'name': 'my_baz1_link', 'doc': 'A composition inside with a fixed name', 'type': baz1_cls}, + )) + + def test_init_fields(self): + cls = self.type_map.get_dt_container_cls('Baz3', CORE_NAMESPACE) # generate the class + self.assertEqual(cls.get_fields_conf(), ( + { + 'name': 'my_baz1', + 'doc': 'A composition inside with a fixed name', + 'child': True, + 'required_name': 'my_baz1' + }, + { + 'name': 'my_baz2', + 'doc': 'A composition inside with a fixed name', + 'child': True, + 'required_name': 'my_baz2' + }, + { + 'name': 'my_baz1_link', + 'doc': 'A composition inside with a fixed name', + 'required_name': 'my_baz1_link' + }, + )) + + def test_set_field(self): + baz1_cls = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) # generate the class + baz2_cls = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE) + baz3_cls = self.type_map.get_dt_container_cls('Baz3', CORE_NAMESPACE) + baz1 = baz1_cls(name="my_baz1") + baz2 = baz2_cls(name="my_baz2") + baz1_link = baz1_cls(name="my_baz1_link") + baz3 = baz3_cls(name="test", my_baz1=baz1, my_baz2=baz2, my_baz1_link=baz1_link) + self.assertEqual(baz3.my_baz1, baz1) + self.assertEqual(baz3.my_baz2, baz2) + self.assertEqual(baz3.my_baz1_link, baz1_link) + + def test_set_field_bad(self): + baz1_cls = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) # generate the class + baz2_cls = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE) + baz3_cls = self.type_map.get_dt_container_cls('Baz3', CORE_NAMESPACE) + + baz1 = baz1_cls(name="test") + baz2 = baz2_cls(name="my_baz2") + baz1_link = baz1_cls(name="my_baz1_link") + msg = "Field 'my_baz1' on Baz3 must be named 'my_baz1'." + with self.assertRaisesWith(ValueError, msg): + baz3_cls(name="test", my_baz1=baz1, my_baz2=baz2, my_baz1_link=baz1_link) + + baz1 = baz1_cls(name="my_baz1") + baz2 = baz2_cls(name="test") + baz1_link = baz1_cls(name="my_baz1_link") + msg = "Field 'my_baz2' on Baz3 must be named 'my_baz2'." + with self.assertRaisesWith(ValueError, msg): + baz3_cls(name="test", my_baz1=baz1, my_baz2=baz2, my_baz1_link=baz1_link) + + baz1 = baz1_cls(name="my_baz1") + baz2 = baz2_cls(name="my_baz2") + baz1_link = baz1_cls(name="test") + msg = "Field 'my_baz1_link' on Baz3 must be named 'my_baz1_link'." + with self.assertRaisesWith(ValueError, msg): + baz3_cls(name="test", my_baz1=baz1, my_baz2=baz2, my_baz1_link=baz1_link) + + class TestGetClassSeparateNamespace(TestCase): def setUp(self): @@ -899,7 +1049,7 @@ def test_process_field_spec_link(self): spec=GroupSpec('dummy', 'doc') ) - expected = {'__fields__': [{'name': 'attr3', 'doc': 'a link'}]} + expected = {'__fields__': [{'name': 'attr3', 'doc': 'a link', 'required_name': 'attr3'}]} self.assertDictEqual(classdict, expected) def test_post_process_fixed_name(self): diff --git a/tests/unit/common/test_generate_table.py b/tests/unit/common/test_generate_table.py index 8d76e651d..7f7d7da40 100644 --- a/tests/unit/common/test_generate_table.py +++ b/tests/unit/common/test_generate_table.py @@ -228,6 +228,13 @@ def test_dynamic_table_region_non_dtr_target(self): self.TestDTRTable(name='test_dtr_table', description='my table', target_tables={'optional_col3': test_table}) + def test_attribute(self): + test_table = self.TestTable(name='test_table', description='my test table') + assert test_table.my_col is not None + assert test_table.indexed_col is not None + assert test_table.my_col is test_table['my_col'] + assert test_table.indexed_col is test_table['indexed_col'].target + def test_roundtrip(self): # NOTE this does not use H5RoundTripMixin because this requires custom validation test_table = self.TestTable(name='test_table', description='my test table')