From c14d37aed40225e02e4f290d9ab23d93d471de0b Mon Sep 17 00:00:00 2001 From: mavaylon1 Date: Wed, 10 Apr 2024 08:14:52 -0700 Subject: [PATCH] checkpoint --- src/hdmf/build/classgenerator.py | 11 +++++------ tests/unit/build_tests/test_classgenerator.py | 18 +----------------- 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/src/hdmf/build/classgenerator.py b/src/hdmf/build/classgenerator.py index c5de3b20d..e0b0b3b61 100644 --- a/src/hdmf/build/classgenerator.py +++ b/src/hdmf/build/classgenerator.py @@ -1,7 +1,6 @@ from copy import deepcopy from datetime import datetime, date from collections.abc import Callable -import types as tp import numpy as np @@ -88,11 +87,8 @@ def generate_class(self, **kwargs): + str(e) + " Please define that type before defining '%s'." % name) cls = ExtenderMeta(data_type, tuple(bases), classdict) + cls.post_init_method = post_init_method - if post_init_method is not None: - cls.post_init_method = tp.MethodType(post_init_method, cls) # set as bounded method - else: - cls.post_init_method = post_init_method # set to None return cls @@ -358,7 +354,6 @@ def __init__(self, **kwargs): if self.post_init_method is not None: self.post_init_method(**original_kwargs) - classdict['__init__'] = __init__ @@ -433,6 +428,7 @@ def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): def __init__(self, **kwargs): # store the values passed to init for each MCI attribute so that they can be added # after calling __init__ + original_kwargs = dict(kwargs) new_kwargs = list() for field_clsconf in classdict['__clsconf__']: attr_name = field_clsconf['attr'] @@ -460,5 +456,8 @@ def __init__(self, **kwargs): add_method = getattr(self, new_kwarg['add_method_name']) add_method(new_kwarg['value']) + if self.post_init_method is not None: + self.post_init_method(**original_kwargs) + # override __init__ classdict['__init__'] = __init__ diff --git a/tests/unit/build_tests/test_classgenerator.py b/tests/unit/build_tests/test_classgenerator.py index 29a3fe175..4838ee827 100644 --- a/tests/unit/build_tests/test_classgenerator.py +++ b/tests/unit/build_tests/test_classgenerator.py @@ -84,22 +84,6 @@ def test_no_generators(self): class TestPostInitGetClass(TestCase): def setUp(self): - # self.bar_spec = GroupSpec( - # doc='A test group specification with a data type', - # data_type_def='Bar', - # datasets=[ - # DatasetSpec( - # doc='a dataset', - # dtype='int', - # name='data', - # attributes=[AttributeSpec(name='attr1', doc='an integer attribute', dtype='int')] - # ) - # ]) - # specs = [self.bar_spec] - # containers = {'Bar': Bar} - # from hdmf.common import get_type_map - # self.type_map = get_type_map() - # self.spec_catalog = self.type_map.namespace_catalog.get_namespace(CORE_NAMESPACE).catalog spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', @@ -132,7 +116,7 @@ def post_init_method(self, **kwargs): cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE, post_init_method) with self.assertRaises(ValueError): - instance = cls(name='instance', attr1=9) + cls(name='instance', attr1=9) class TestDynamicContainer(TestCase):