Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set fixed values and column attrs on autogen class #800

Merged
merged 13 commits into from
Oct 26, 2023
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
- Added `target_tables` attribute to `DynamicTable` to allow users to specify the target table of any predefined
`DynamicTableRegion` columns of a `DynamicTable` subclass. @rly [#971](https://github.com/hdmf-dev/hdmf/pull/971)

### Bug fixes
- Updated custom class generation to handle specs with fixed values and required names. @rly [#800](https://github.com/hdmf-dev/hdmf/pull/800)
- Fixed custom class generation of `DynamicTable` subtypes to set attributes corresponding to column names for correct write. @rly [#800](https://github.com/hdmf-dev/hdmf/pull/800)

## HDMF 3.10.0 (October 3, 2023)

Since version 3.9.1 should have been released as 3.10.0 but failed to release on PyPI and conda-forge, this release
Expand Down
36 changes: 30 additions & 6 deletions src/hdmf/build/classgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,19 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i
'doc': field_spec['doc']}
if cls._ischild(dtype) and issubclass(parent_cls, Container) and not isinstance(field_spec, LinkSpec):
fields_conf['child'] = True
# if getattr(field_spec, 'value', None) is not None: # TODO set the fixed value on the class?
# fields_conf['settable'] = False
fixed_value = getattr(field_spec, 'value', None)
if fixed_value is not None:
fields_conf['settable'] = False
if isinstance(field_spec, (BaseStorageSpec, LinkSpec)) and field_spec.data_type is not None:
# subgroups, datasets, and links with data types can have fixed names
fixed_name = getattr(field_spec, 'name', None)
if fixed_name is not None:
fields_conf['required_name'] = fixed_name
classdict.setdefault(parent_cls._fieldsname, list()).append(fields_conf)

if fixed_value is not None: # field has fixed value - do not create arg on __init__
return

docval_arg = dict(
name=attr_name,
doc=field_spec.doc,
Expand Down Expand Up @@ -285,17 +294,27 @@ def post_process(cls, classdict, bases, docval_args, spec):
# set default name in docval args if provided
cls._set_default_name(docval_args, spec.default_name)

@classmethod
def _get_attrs_not_to_set_init(cls, classdict, parent_docval_args):
return parent_docval_args

@classmethod
def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name):
# get docval arg names from superclass
base = bases[0]
parent_docval_args = set(arg['name'] for arg in get_docval(base.__init__))
new_args = list()
attrs_to_set = list()
fixed_value_attrs_to_set = list()
attrs_not_to_set = cls._get_attrs_not_to_set_init(classdict, parent_docval_args)
for attr_name, field_spec in not_inherited_fields.items():
# store arguments for fields that are not in the superclass and not in the superclass __init__ docval
# so that they are set after calling base.__init__
if attr_name not in parent_docval_args:
new_args.append(attr_name)
# except for fields that have fixed values -- these are set at the class level
fixed_value = getattr(field_spec, 'value', None)
if fixed_value is not None:
fixed_value_attrs_to_set.append(attr_name)
elif attr_name not in attrs_not_to_set:
attrs_to_set.append(attr_name)

@docval(*docval_args, allow_positional=AllowPositional.WARNING)
def __init__(self, **kwargs):
Expand All @@ -305,7 +324,7 @@ def __init__(self, **kwargs):
# remove arguments from kwargs that correspond to fields that are new (not inherited)
# set these arguments after calling base.__init__
new_kwargs = dict()
for f in new_args:
for f in attrs_to_set:
new_kwargs[f] = popargs(f, kwargs) if f in kwargs else None

# NOTE: the docval of some constructors do not include all of the fields. the constructor may set
Expand All @@ -319,6 +338,11 @@ def __init__(self, **kwargs):
for f, arg_val in new_kwargs.items():
setattr(self, f, arg_val)

# set the fields that have fixed values using the fields dict directly
# because the setters do not allow setting the value
for f in fixed_value_attrs_to_set:
self.fields[f] = getattr(not_inherited_fields[f], 'value')

classdict['__init__'] = __init__


Expand Down
9 changes: 9 additions & 0 deletions src/hdmf/common/io/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,12 @@ def post_process(cls, classdict, bases, docval_args, spec):
columns = classdict.get('__columns__')
if columns is not None:
classdict['__columns__'] = tuple(columns)

@classmethod
def _get_attrs_not_to_set_init(cls, classdict, parent_docval_args):
# exclude columns from the args that are set in __init__
attrs_not_to_set = parent_docval_args.copy()
if "__columns__" in classdict:
column_names = [column_conf["name"] for column_conf in classdict["__columns__"]]
attrs_not_to_set.update(column_names)
return attrs_not_to_set
5 changes: 5 additions & 0 deletions src/hdmf/spec/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,11 @@ def data_type_inc(self):
''' The data type of target specification '''
return self.get(_target_type_key)

@property
def data_type(self):
''' The data type of target specification '''
return self.get(_target_type_key)

def is_many(self):
return self.quantity not in (1, ZERO_OR_ONE)

Expand Down
152 changes: 151 additions & 1 deletion tests/unit/build_tests/test_classgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,156 @@ def test_multi_container_spec_one_or_more_ok(self):
assert len(multi.bars) == 1


class TestDynamicContainerFixedValue(TestCase):

def setUp(self):
self.baz_spec = GroupSpec(
doc='A test group specification with a data type',
data_type_def='Baz',
attributes=[AttributeSpec(name='attr1', doc='a string attribute', dtype='text', value="fixed")]
)
self.type_map = create_test_type_map([], {}) # empty typemap
self.spec_catalog = self.type_map.namespace_catalog.get_namespace(CORE_NAMESPACE).catalog
self.spec_catalog.register_spec(self.baz_spec, 'extension.yaml')

def test_init_docval(self):
cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class
expected_args = {'name'} # 'attr1' should not be included
received_args = set()
for x in get_docval(cls.__init__):
received_args.add(x['name'])
self.assertSetEqual(expected_args, received_args)

def test_init_fields(self):
cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class
self.assertEqual(cls.get_fields_conf(), ({'name': 'attr1', 'doc': 'a string attribute', 'settable': False},))

def test_init_object(self):
cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class
obj = cls(name="test")
self.assertEqual(obj.attr1, "fixed")

def test_set_value(self):
cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class
obj = cls(name="test")
with self.assertRaises(AttributeError):
obj.attr1 = "new"


class TestDynamicContainerIncludingFixedName(TestCase):

def setUp(self):
self.baz_spec1 = GroupSpec(
doc='A test group specification with a data type',
data_type_def='Baz1',
)
self.baz_spec2 = GroupSpec(
doc='A test dataset specification with a data type',
data_type_def='Baz2',
)
self.baz_spec3 = GroupSpec(
doc='A test group specification with a data type',
data_type_def='Baz3',
groups=[
GroupSpec(
doc='A composition inside with a fixed name',
name="my_baz1",
data_type_inc='Baz1'
),
],
datasets=[
DatasetSpec(
doc='A composition inside with a fixed name',
name="my_baz2",
data_type_inc='Baz2'
),
],
links=[
LinkSpec(
doc='A composition inside with a fixed name',
name="my_baz1_link",
target_type='Baz1'
),
],
)
self.type_map = create_test_type_map([], {}) # empty typemap
self.spec_catalog = self.type_map.namespace_catalog.get_namespace(CORE_NAMESPACE).catalog
self.spec_catalog.register_spec(self.baz_spec1, 'extension.yaml')
self.spec_catalog.register_spec(self.baz_spec2, 'extension.yaml')
self.spec_catalog.register_spec(self.baz_spec3, 'extension.yaml')

def test_gen_parent_class(self):
baz1_cls = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) # generate the class
baz2_cls = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE)
baz3_cls = self.type_map.get_dt_container_cls('Baz3', CORE_NAMESPACE)
self.assertEqual(get_docval(baz3_cls.__init__), (
{'name': 'name', 'type': str, 'doc': 'the name of this container'},
{'name': 'my_baz1', 'doc': 'A composition inside with a fixed name', 'type': baz1_cls},
{'name': 'my_baz2', 'doc': 'A composition inside with a fixed name', 'type': baz2_cls},
{'name': 'my_baz1_link', 'doc': 'A composition inside with a fixed name', 'type': baz1_cls},
))

def test_init_fields(self):
cls = self.type_map.get_dt_container_cls('Baz3', CORE_NAMESPACE) # generate the class
self.assertEqual(cls.get_fields_conf(), (
{
'name': 'my_baz1',
'doc': 'A composition inside with a fixed name',
'child': True,
'required_name': 'my_baz1'
},
{
'name': 'my_baz2',
'doc': 'A composition inside with a fixed name',
'child': True,
'required_name': 'my_baz2'
},
{
'name': 'my_baz1_link',
'doc': 'A composition inside with a fixed name',
'required_name': 'my_baz1_link'
},
))

def test_set_field(self):
baz1_cls = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) # generate the class
baz2_cls = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE)
baz3_cls = self.type_map.get_dt_container_cls('Baz3', CORE_NAMESPACE)
baz1 = baz1_cls(name="my_baz1")
baz2 = baz2_cls(name="my_baz2")
baz1_link = baz1_cls(name="my_baz1_link")
baz3 = baz3_cls(name="test", my_baz1=baz1, my_baz2=baz2, my_baz1_link=baz1_link)
self.assertEqual(baz3.my_baz1, baz1)
self.assertEqual(baz3.my_baz2, baz2)
self.assertEqual(baz3.my_baz1_link, baz1_link)

def test_set_field_bad(self):
baz1_cls = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) # generate the class
baz2_cls = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE)
baz3_cls = self.type_map.get_dt_container_cls('Baz3', CORE_NAMESPACE)

baz1 = baz1_cls(name="test")
baz2 = baz2_cls(name="my_baz2")
baz1_link = baz1_cls(name="my_baz1_link")
msg = "Field 'my_baz1' on Baz3 must be named 'my_baz1'."
with self.assertRaisesWith(ValueError, msg):
baz3_cls(name="test", my_baz1=baz1, my_baz2=baz2, my_baz1_link=baz1_link)

baz1 = baz1_cls(name="my_baz1")
baz2 = baz2_cls(name="test")
baz1_link = baz1_cls(name="my_baz1_link")
msg = "Field 'my_baz2' on Baz3 must be named 'my_baz2'."
with self.assertRaisesWith(ValueError, msg):
baz3_cls(name="test", my_baz1=baz1, my_baz2=baz2, my_baz1_link=baz1_link)

baz1 = baz1_cls(name="my_baz1")
baz2 = baz2_cls(name="my_baz2")
baz1_link = baz1_cls(name="test")
msg = "Field 'my_baz1_link' on Baz3 must be named 'my_baz1_link'."
with self.assertRaisesWith(ValueError, msg):
baz3_cls(name="test", my_baz1=baz1, my_baz2=baz2, my_baz1_link=baz1_link)


class TestGetClassSeparateNamespace(TestCase):

def setUp(self):
Expand Down Expand Up @@ -899,7 +1049,7 @@ def test_process_field_spec_link(self):
spec=GroupSpec('dummy', 'doc')
)

expected = {'__fields__': [{'name': 'attr3', 'doc': 'a link'}]}
expected = {'__fields__': [{'name': 'attr3', 'doc': 'a link', 'required_name': 'attr3'}]}
self.assertDictEqual(classdict, expected)

def test_post_process_fixed_name(self):
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/common/test_generate_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,13 @@ def test_dynamic_table_region_non_dtr_target(self):
self.TestDTRTable(name='test_dtr_table', description='my table',
target_tables={'optional_col3': test_table})

def test_attribute(self):
test_table = self.TestTable(name='test_table', description='my test table')
assert test_table.my_col is not None
assert test_table.indexed_col is not None
assert test_table.my_col is test_table['my_col']
assert test_table.indexed_col is test_table['indexed_col'].target

def test_roundtrip(self):
# NOTE this does not use H5RoundTripMixin because this requires custom validation
test_table = self.TestTable(name='test_table', description='my test table')
Expand Down