From 1417d3a695bb2675084fe0246f58e74bf9ba819e Mon Sep 17 00:00:00 2001 From: mavaylon1 Date: Wed, 10 Apr 2024 08:48:56 -0700 Subject: [PATCH] tests --- src/hdmf/build/classgenerator.py | 2 +- tests/unit/build_tests/test_classgenerator.py | 56 ++++++++++++++++--- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/src/hdmf/build/classgenerator.py b/src/hdmf/build/classgenerator.py index e0b0b3b61..ac3bf1d42 100644 --- a/src/hdmf/build/classgenerator.py +++ b/src/hdmf/build/classgenerator.py @@ -322,7 +322,7 @@ def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): fixed_value_attrs_to_set.append(attr_name) elif attr_name not in attrs_not_to_set: attrs_to_set.append(attr_name) - + @docval(*docval_args, allow_positional=AllowPositional.WARNING) def __init__(self, **kwargs): original_kwargs = dict(kwargs) diff --git a/tests/unit/build_tests/test_classgenerator.py b/tests/unit/build_tests/test_classgenerator.py index 4838ee827..28e47f1ad 100644 --- a/tests/unit/build_tests/test_classgenerator.py +++ b/tests/unit/build_tests/test_classgenerator.py @@ -84,6 +84,14 @@ def test_no_generators(self): class TestPostInitGetClass(TestCase): def setUp(self): + def post_init_method(self, **kwargs): + attr1 = kwargs['attr1'] + if attr1<10: + msg = "attr1 should be >=10" + raise ValueError(msg) + self.post_init=post_init_method + + def test_post_init(self): spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', @@ -103,21 +111,51 @@ def setUp(self): ) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) - self.type_map = TypeMap(namespace_catalog) - + type_map = TypeMap(namespace_catalog) - def test_post_init(self): - def post_init_method(self, **kwargs): - attr1 = kwargs['attr1'] - if attr1<10: - msg = "attr1 should be >=10" - raise ValueError(msg) - cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE, post_init_method) + cls = type_map.get_dt_container_cls('Baz', CORE_NAMESPACE, self.post_init) with self.assertRaises(ValueError): cls(name='instance', attr1=9) + def test_multi_container_post_init(self): + bar_spec = GroupSpec( + doc='A test group specification with a data type', + data_type_def='Bar', + datasets=[ + DatasetSpec( + doc='a dataset', + dtype='int', + name='data', + attributes=[AttributeSpec(name='attr2', doc='an integer attribute', dtype='int')] + ) + ], + attributes=[AttributeSpec(name='attr1', doc='a string attribute', dtype='text')]) + + multi_spec = GroupSpec(doc='A test extension that contains a multi', + data_type_def='Multi', + groups=[GroupSpec(data_type_inc=bar_spec, doc='test multi', quantity='*')], + attributes=[AttributeSpec(name='attr1', doc='a float attribute', dtype='float')]) + + spec_catalog = SpecCatalog() + spec_catalog.register_spec(bar_spec, 'test.yaml') + spec_catalog.register_spec(multi_spec, 'test.yaml') + namespace = SpecNamespace( + doc='a test namespace', + name=CORE_NAMESPACE, + schema=[{'source': 'test.yaml'}], + version='0.1.0', + catalog=spec_catalog + ) + namespace_catalog = NamespaceCatalog() + namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) + type_map = TypeMap(namespace_catalog) + Multi = type_map.get_dt_container_cls('Multi', CORE_NAMESPACE, self.post_init) + + with self.assertRaises(ValueError): + Multi(name='instance', attr1=9.1) + class TestDynamicContainer(TestCase): def setUp(self):