Skip to content

Commit

Permalink
Update DynamicTableRegion.get_linked_tables to return named tuples (#660
Browse files Browse the repository at this point in the history
)

* Update DynamicTableRegion.get_linked_tables to return named tuples rather than dicts
* Update changelog

Co-authored-by: Ryan Ly <[email protected]>
  • Loading branch information
oruebel and rly authored Jul 29, 2021
1 parent 44e0ba7 commit df31c59
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 37 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# HDMF Changelog

## HDMF 3.1.1 (July 29, 2021)

### Fixes
- Updated the new ``DynamicTableRegion.get_linked_tables`` function (added in 3.1.0) to return lists of ``typing.NamedTuple``
objects rather than lists of dicts. @oruebel (#660)

## HDMF 3.1.0 (July 29, 2021)

### New features
Expand Down
13 changes: 9 additions & 4 deletions src/hdmf/common/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import re
from collections import OrderedDict
from typing import NamedTuple, Union
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -1009,11 +1010,15 @@ def get_linked_tables(self, **kwargs):
from this table via foreign DynamicTableColumns included in this table or in any table that
can be reached through DynamicTableRegion columns
Returns: List of dicts with the following keys:
Returns: List of NamedTuple objects with:
* 'source_table' : The source table containing the DynamicTableRegion column
* 'source_column' : The relevant DynamicTableRegion column in the 'source_table'
* 'target_table' : The target DynamicTable; same as source_column.table.
"""
link_type = NamedTuple('DynamicTableLink',
[('source_table', DynamicTable),
('source_column', Union[DynamicTableRegion, VectorIndex]),
('target_table', DynamicTable)])
curr_tables = [self, ] # Set of tables
other_tables = getargs('other_tables', kwargs)
if other_tables is not None:
Expand All @@ -1023,9 +1028,9 @@ def get_linked_tables(self, **kwargs):
while curr_index < len(curr_tables):
for col_index, col in enumerate(curr_tables[curr_index].columns):
if isinstance(col, DynamicTableRegion):
foreign_cols.append({'source_table': curr_tables[curr_index],
'source_column': col,
'target_table': col.table})
foreign_cols.append(link_type(source_table=curr_tables[curr_index],
source_column=col,
target_table=col.table))
curr_table_visited = False
for t in curr_tables:
if t is col.table:
Expand Down
66 changes: 33 additions & 33 deletions tests/unit/common/test_linkedtables.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,13 @@ def test_get_linked_tables(self):
# check with subcateogries
linked_tables = self.aligned_table.get_linked_tables()
self.assertEqual(len(linked_tables), 2)
self.assertTupleEqual((linked_tables[0]['source_table'].name,
linked_tables[0]['source_column'].name,
linked_tables[0]['target_table'].name),
self.assertTupleEqual((linked_tables[0].source_table.name,
linked_tables[0].source_column.name,
linked_tables[0].target_table.name),
('category0', 'child_table_ref1', 'level0_0'))
self.assertTupleEqual((linked_tables[1]['source_table'].name,
linked_tables[1]['source_column'].name,
linked_tables[1]['target_table'].name),
self.assertTupleEqual((linked_tables[1].source_table.name,
linked_tables[1].source_column.name,
linked_tables[1].target_table.name),
('category1', 'child_table_ref1', 'level0_1'))

def test_get_linked_tables_none(self):
Expand Down Expand Up @@ -305,17 +305,17 @@ def test_get_linked_tables_complex_link(self):
linked_tables = temp_aligned_table.get_linked_tables()
self.assertEqual(len(linked_tables), 2)
for i, v in enumerate([('my_aligned_table', 'a2', 't1'), ('t1', 'c2', 't0')]):
self.assertTupleEqual((linked_tables[i]['source_table'].name,
linked_tables[i]['source_column'].name,
linked_tables[i]['target_table'].name), v)
self.assertTupleEqual((linked_tables[i].source_table.name,
linked_tables[i].source_column.name,
linked_tables[i].target_table.name), v)
# Now, since our main table links to the category table the result should remain the same
# even if we ignore the category table
linked_tables = temp_aligned_table.get_linked_tables(ignore_category_tables=True)
self.assertEqual(len(linked_tables), 2)
for i, v in enumerate([('my_aligned_table', 'a2', 't1'), ('t1', 'c2', 't0')]):
self.assertTupleEqual((linked_tables[i]['source_table'].name,
linked_tables[i]['source_column'].name,
linked_tables[i]['target_table'].name), v)
self.assertTupleEqual((linked_tables[i].source_table.name,
linked_tables[i].source_column.name,
linked_tables[i].target_table.name), v)

def test_get_linked_tables_simple_link(self):
temp_table0 = DynamicTable(name='t0', description='t1',
Expand All @@ -339,17 +339,17 @@ def test_get_linked_tables_simple_link(self):
linked_tables = temp_aligned_table.get_linked_tables()
self.assertEqual(len(linked_tables), 2)
for i, v in enumerate([('my_aligned_table', 'a2', 't0'), ('t1', 'c2', 't0')]):
self.assertTupleEqual((linked_tables[i]['source_table'].name,
linked_tables[i]['source_column'].name,
linked_tables[i]['target_table'].name), v)
self.assertTupleEqual((linked_tables[i].source_table.name,
linked_tables[i].source_column.name,
linked_tables[i].target_table.name), v)
# Since no table ever link to our category temp_table we should only get the link from our
# main table here, in contrast to what happens in the test_get_linked_tables_complex_link case
linked_tables = temp_aligned_table.get_linked_tables()
self.assertEqual(len(linked_tables), 2)
for i, v in enumerate([('my_aligned_table', 'a2', 't0'), ]):
self.assertTupleEqual((linked_tables[i]['source_table'].name,
linked_tables[i]['source_column'].name,
linked_tables[i]['target_table'].name), v)
self.assertTupleEqual((linked_tables[i].source_table.name,
linked_tables[i].source_column.name,
linked_tables[i].target_table.name), v)


class TestHierarchicalTable(TestCase):
Expand Down Expand Up @@ -696,21 +696,21 @@ def test_get_linked_tables(self):
# check level1
temp = self.table_level1.get_linked_tables()
self.assertEqual(len(temp), 2)
self.assertEqual(temp[0]['source_table'].name, self.table_level1.name)
self.assertEqual(temp[0]['source_column'].name, 'child_table_ref1')
self.assertEqual(temp[0]['target_table'].name, self.table_level0_0.name)
self.assertEqual(temp[1]['source_table'].name, self.table_level1.name)
self.assertEqual(temp[1]['source_column'].name, 'child_table_ref2')
self.assertEqual(temp[1]['target_table'].name, self.table_level0_1.name)
self.assertEqual(temp[0].source_table.name, self.table_level1.name)
self.assertEqual(temp[0].source_column.name, 'child_table_ref1')
self.assertEqual(temp[0].target_table.name, self.table_level0_0.name)
self.assertEqual(temp[1].source_table.name, self.table_level1.name)
self.assertEqual(temp[1].source_column.name, 'child_table_ref2')
self.assertEqual(temp[1].target_table.name, self.table_level0_1.name)
# check level2
temp = self.table_level2.get_linked_tables()
self.assertEqual(len(temp), 3)
self.assertEqual(temp[0]['source_table'].name, self.table_level2.name)
self.assertEqual(temp[0]['source_column'].name, 'child_table_ref1')
self.assertEqual(temp[0]['target_table'].name, self.table_level1.name)
self.assertEqual(temp[1]['source_table'].name, self.table_level1.name)
self.assertEqual(temp[1]['source_column'].name, 'child_table_ref1')
self.assertEqual(temp[1]['target_table'].name, self.table_level0_0.name)
self.assertEqual(temp[2]['source_table'].name, self.table_level1.name)
self.assertEqual(temp[2]['source_column'].name, 'child_table_ref2')
self.assertEqual(temp[2]['target_table'].name, self.table_level0_1.name)
self.assertEqual(temp[0].source_table.name, self.table_level2.name)
self.assertEqual(temp[0].source_column.name, 'child_table_ref1')
self.assertEqual(temp[0].target_table.name, self.table_level1.name)
self.assertEqual(temp[1].source_table.name, self.table_level1.name)
self.assertEqual(temp[1].source_column.name, 'child_table_ref1')
self.assertEqual(temp[1].target_table.name, self.table_level0_0.name)
self.assertEqual(temp[2].source_table.name, self.table_level1.name)
self.assertEqual(temp[2].source_column.name, 'child_table_ref2')
self.assertEqual(temp[2].target_table.name, self.table_level0_1.name)

0 comments on commit df31c59

Please sign in to comment.