From df31c59aa396a9920077eb3970d966e9d0f7a75b Mon Sep 17 00:00:00 2001 From: Oliver Ruebel Date: Thu, 29 Jul 2021 16:55:01 -0700 Subject: [PATCH] Update DynamicTableRegion.get_linked_tables to return named tuples (#660) * Update DynamicTableRegion.get_linked_tables to return named tuples rather than dicts * Update changelog Co-authored-by: Ryan Ly --- CHANGELOG.md | 6 +++ src/hdmf/common/table.py | 13 +++-- tests/unit/common/test_linkedtables.py | 66 +++++++++++++------------- 3 files changed, 48 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a1a942ad..2e414a102 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # HDMF Changelog +## HDMF 3.1.1 (July 29, 2021) + +### Fixes +- Updated the new ``DynamicTableRegion.get_linked_tables`` function (added in 3.1.0) to return lists of ``typing.NamedTuple`` + objects rather than lists of dicts. @oruebel (#660) + ## HDMF 3.1.0 (July 29, 2021) ### New features diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 1b3ef0510..27ae57b45 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -5,6 +5,7 @@ import re from collections import OrderedDict +from typing import NamedTuple, Union from warnings import warn import numpy as np @@ -1009,11 +1010,15 @@ def get_linked_tables(self, **kwargs): from this table via foreign DynamicTableColumns included in this table or in any table that can be reached through DynamicTableRegion columns - Returns: List of dicts with the following keys: + Returns: List of NamedTuple objects with: * 'source_table' : The source table containing the DynamicTableRegion column * 'source_column' : The relevant DynamicTableRegion column in the 'source_table' * 'target_table' : The target DynamicTable; same as source_column.table. """ + link_type = NamedTuple('DynamicTableLink', + [('source_table', DynamicTable), + ('source_column', Union[DynamicTableRegion, VectorIndex]), + ('target_table', DynamicTable)]) curr_tables = [self, ] # Set of tables other_tables = getargs('other_tables', kwargs) if other_tables is not None: @@ -1023,9 +1028,9 @@ def get_linked_tables(self, **kwargs): while curr_index < len(curr_tables): for col_index, col in enumerate(curr_tables[curr_index].columns): if isinstance(col, DynamicTableRegion): - foreign_cols.append({'source_table': curr_tables[curr_index], - 'source_column': col, - 'target_table': col.table}) + foreign_cols.append(link_type(source_table=curr_tables[curr_index], + source_column=col, + target_table=col.table)) curr_table_visited = False for t in curr_tables: if t is col.table: diff --git a/tests/unit/common/test_linkedtables.py b/tests/unit/common/test_linkedtables.py index 14db1c471..48a9fd6a0 100644 --- a/tests/unit/common/test_linkedtables.py +++ b/tests/unit/common/test_linkedtables.py @@ -257,13 +257,13 @@ def test_get_linked_tables(self): # check with subcateogries linked_tables = self.aligned_table.get_linked_tables() self.assertEqual(len(linked_tables), 2) - self.assertTupleEqual((linked_tables[0]['source_table'].name, - linked_tables[0]['source_column'].name, - linked_tables[0]['target_table'].name), + self.assertTupleEqual((linked_tables[0].source_table.name, + linked_tables[0].source_column.name, + linked_tables[0].target_table.name), ('category0', 'child_table_ref1', 'level0_0')) - self.assertTupleEqual((linked_tables[1]['source_table'].name, - linked_tables[1]['source_column'].name, - linked_tables[1]['target_table'].name), + self.assertTupleEqual((linked_tables[1].source_table.name, + linked_tables[1].source_column.name, + linked_tables[1].target_table.name), ('category1', 'child_table_ref1', 'level0_1')) def test_get_linked_tables_none(self): @@ -305,17 +305,17 @@ def test_get_linked_tables_complex_link(self): linked_tables = temp_aligned_table.get_linked_tables() self.assertEqual(len(linked_tables), 2) for i, v in enumerate([('my_aligned_table', 'a2', 't1'), ('t1', 'c2', 't0')]): - self.assertTupleEqual((linked_tables[i]['source_table'].name, - linked_tables[i]['source_column'].name, - linked_tables[i]['target_table'].name), v) + self.assertTupleEqual((linked_tables[i].source_table.name, + linked_tables[i].source_column.name, + linked_tables[i].target_table.name), v) # Now, since our main table links to the category table the result should remain the same # even if we ignore the category table linked_tables = temp_aligned_table.get_linked_tables(ignore_category_tables=True) self.assertEqual(len(linked_tables), 2) for i, v in enumerate([('my_aligned_table', 'a2', 't1'), ('t1', 'c2', 't0')]): - self.assertTupleEqual((linked_tables[i]['source_table'].name, - linked_tables[i]['source_column'].name, - linked_tables[i]['target_table'].name), v) + self.assertTupleEqual((linked_tables[i].source_table.name, + linked_tables[i].source_column.name, + linked_tables[i].target_table.name), v) def test_get_linked_tables_simple_link(self): temp_table0 = DynamicTable(name='t0', description='t1', @@ -339,17 +339,17 @@ def test_get_linked_tables_simple_link(self): linked_tables = temp_aligned_table.get_linked_tables() self.assertEqual(len(linked_tables), 2) for i, v in enumerate([('my_aligned_table', 'a2', 't0'), ('t1', 'c2', 't0')]): - self.assertTupleEqual((linked_tables[i]['source_table'].name, - linked_tables[i]['source_column'].name, - linked_tables[i]['target_table'].name), v) + self.assertTupleEqual((linked_tables[i].source_table.name, + linked_tables[i].source_column.name, + linked_tables[i].target_table.name), v) # Since no table ever link to our category temp_table we should only get the link from our # main table here, in contrast to what happens in the test_get_linked_tables_complex_link case linked_tables = temp_aligned_table.get_linked_tables() self.assertEqual(len(linked_tables), 2) for i, v in enumerate([('my_aligned_table', 'a2', 't0'), ]): - self.assertTupleEqual((linked_tables[i]['source_table'].name, - linked_tables[i]['source_column'].name, - linked_tables[i]['target_table'].name), v) + self.assertTupleEqual((linked_tables[i].source_table.name, + linked_tables[i].source_column.name, + linked_tables[i].target_table.name), v) class TestHierarchicalTable(TestCase): @@ -696,21 +696,21 @@ def test_get_linked_tables(self): # check level1 temp = self.table_level1.get_linked_tables() self.assertEqual(len(temp), 2) - self.assertEqual(temp[0]['source_table'].name, self.table_level1.name) - self.assertEqual(temp[0]['source_column'].name, 'child_table_ref1') - self.assertEqual(temp[0]['target_table'].name, self.table_level0_0.name) - self.assertEqual(temp[1]['source_table'].name, self.table_level1.name) - self.assertEqual(temp[1]['source_column'].name, 'child_table_ref2') - self.assertEqual(temp[1]['target_table'].name, self.table_level0_1.name) + self.assertEqual(temp[0].source_table.name, self.table_level1.name) + self.assertEqual(temp[0].source_column.name, 'child_table_ref1') + self.assertEqual(temp[0].target_table.name, self.table_level0_0.name) + self.assertEqual(temp[1].source_table.name, self.table_level1.name) + self.assertEqual(temp[1].source_column.name, 'child_table_ref2') + self.assertEqual(temp[1].target_table.name, self.table_level0_1.name) # check level2 temp = self.table_level2.get_linked_tables() self.assertEqual(len(temp), 3) - self.assertEqual(temp[0]['source_table'].name, self.table_level2.name) - self.assertEqual(temp[0]['source_column'].name, 'child_table_ref1') - self.assertEqual(temp[0]['target_table'].name, self.table_level1.name) - self.assertEqual(temp[1]['source_table'].name, self.table_level1.name) - self.assertEqual(temp[1]['source_column'].name, 'child_table_ref1') - self.assertEqual(temp[1]['target_table'].name, self.table_level0_0.name) - self.assertEqual(temp[2]['source_table'].name, self.table_level1.name) - self.assertEqual(temp[2]['source_column'].name, 'child_table_ref2') - self.assertEqual(temp[2]['target_table'].name, self.table_level0_1.name) + self.assertEqual(temp[0].source_table.name, self.table_level2.name) + self.assertEqual(temp[0].source_column.name, 'child_table_ref1') + self.assertEqual(temp[0].target_table.name, self.table_level1.name) + self.assertEqual(temp[1].source_table.name, self.table_level1.name) + self.assertEqual(temp[1].source_column.name, 'child_table_ref1') + self.assertEqual(temp[1].target_table.name, self.table_level0_0.name) + self.assertEqual(temp[2].source_table.name, self.table_level1.name) + self.assertEqual(temp[2].source_column.name, 'child_table_ref2') + self.assertEqual(temp[2].target_table.name, self.table_level0_1.name)