diff --git a/scripts/add_object_id.py b/scripts/add_object_id.py new file mode 100644 index 000000000..9453e58a4 --- /dev/null +++ b/scripts/add_object_id.py @@ -0,0 +1,28 @@ +""" +This script adds an 'object_id' attribute to each neurodata_type of an hdf5 file that was written before object IDs +existed. Specifically, it traverses through the hierarchy of objects in the file and sets the 'object_id' attribute +to a UUID4 string on each group, dataset, and link that has a 'neurodata_type' attribute and does not have an +'object_id' attribute. + +Usage: python add_object_id.py filename +""" + +from h5py import File +from uuid import uuid4 +import sys + + +def add_uuid(name, obj): + if 'neurodata_type' in obj.attrs and 'object_id' not in obj.attrs: + obj.attrs['object_id'] = str(uuid4()) + print('Adding uuid4 %s to %s' % (obj.attrs['object_id'], str(obj))) + + +def main(): + filename = sys.argv[1] + with File(filename, 'a') as f: + f.visititems(add_uuid) + + +if __name__ == '__main__': + main() diff --git a/src/pynwb/core.py b/src/pynwb/core.py index 1f7d83392..0ffe50725 100644 --- a/src/pynwb/core.py +++ b/src/pynwb/core.py @@ -3,6 +3,7 @@ import pandas as pd from hdmf.utils import docval, getargs, ExtenderMeta, call_docval_func, popargs, get_docval, fmt_docval_args, pystr +from hdmf.data_utils import DataIO from hdmf import Container, Data, DataRegion, get_region_slicer from . 
import CORE_NAMESPACE, register_class @@ -13,29 +14,18 @@ def _not_parent(arg): return arg['name'] != 'parent' -def set_parents(container, parent): - if isinstance(container, list): - for c in container: - if c.parent is None: - c.parent = parent - ret = container - else: - ret = [container] - if container.parent is None: - container.parent = parent - return ret - - class LabelledDict(dict): ''' A dict wrapper class for aggregating Timeseries from the standard locations ''' - @docval({'name': 'label', 'type': str, 'doc': 'the TimeSeries type ('}) + @docval({'name': 'label', 'type': str, 'doc': 'the label on this dictionary'}, + {'name': 'def_key_name', 'type': str, 'doc': 'the default key name', 'default': 'name'}) def __init__(self, **kwargs): - label = getargs('label', kwargs) + label, def_key_name = getargs('label', 'def_key_name', kwargs) self.__label = label + self.__defkey = def_key_name @property def label(self): @@ -47,7 +37,7 @@ def __getitem__(self, args): key, val = args.split("==") key = key.strip() val = val.strip() - if key != 'name': + if key != self.__defkey: ret = list() for item in self.values(): if getattr(item, key, None) == val: @@ -165,7 +155,7 @@ def __repr__(self): template = "\n{} {}\nFields:\n""".format(getattr(self, 'name'), type(self)) for k in sorted(self.fields): # sorted to enable tests v = self.fields[k] - if not hasattr(v, '__len__') or len(v) > 0: + if isinstance(v, DataIO) or not hasattr(v, '__len__') or len(v) > 0: template += " {}: {}\n".format(k, self.__smart_str(v, 1)) return template @@ -295,7 +285,10 @@ def nwbdi_setter(self, val): else: val = [val] for v in val: - self.add_child(v) + if not isinstance(v.parent, Container): + v.parent = self + # else, the ObjectMapper will create a link from self (parent) to v (child with existing + # parent) ret.append(nwbdi_setter) return ret[-1] @@ -767,7 +760,9 @@ def _func(self, **kwargs): containers = container d = getattr(self, attr_name) for tmp in containers: - self.add_child(tmp) + 
if not isinstance(tmp.parent, Container): + tmp.parent = self + # else, the ObjectMapper will create a link from self (parent) to tmp (child with existing parent) if tmp.name in d: msg = "'%s' already exists in '%s'" % (tmp.name, self.name) raise ValueError(msg) @@ -1199,7 +1194,7 @@ def add_column(self, **kwargs): ckwargs['table'] = table col = cls(**ckwargs) - self.add_child(col) + col.parent = self columns = [col] # Add index if it's been specified @@ -1215,7 +1210,9 @@ def add_column(self, **kwargs): raise ValueError("cannot pass non-empty index with empty data to index") col_index = VectorIndex(name + "_index", index, col) columns.insert(0, col_index) - self.add_child(col_index) + if not isinstance(col_index.parent, Container): + col_index.parent = self + # else, the ObjectMapper will create a link from self (parent) to col_index (child with existing parent) col = col_index self.__indices[col_index.name] = col_index diff --git a/src/pynwb/file.py b/src/pynwb/file.py index 3fc953e3e..ebe7866f2 100644 --- a/src/pynwb/file.py +++ b/src/pynwb/file.py @@ -18,7 +18,8 @@ from .ophys import ImagingPlane from .ogen import OptogeneticStimulusSite from .misc import Units -from .core import NWBContainer, NWBDataInterface, MultiContainerInterface, DynamicTable, DynamicTableRegion +from .core import NWBContainer, NWBDataInterface, MultiContainerInterface, DynamicTable, DynamicTableRegion,\ + LabelledDict def _not_parent(arg): @@ -358,17 +359,30 @@ def __init__(self, **kwargs): if getargs('source_script', kwargs) is None and getargs('source_script_file_name', kwargs) is not None: raise ValueError("'source_script' cannot be None when 'source_script_file_name' is set") + self.__obj = None + def all_children(self): stack = [self] ret = list() + self.__obj = LabelledDict(label='all_objects', def_key_name='object_id') while len(stack): n = stack.pop() ret.append(n) + if n.object_id is not None: + self.__obj[n.object_id] = n + else: + warn('%s "%s" does not have an object_id' % 
(n.neurodata_type, n.name)) if hasattr(n, 'children'): for c in n.children: stack.append(c) return ret + @property + def objects(self): + if self.__obj is None: + self.all_children() + return self.__obj + @property def modules(self): warn("replaced by NWBFile.processing", DeprecationWarning) diff --git a/src/pynwb/spec.py b/src/pynwb/spec.py index e248469a4..7ad3155c9 100644 --- a/src/pynwb/spec.py +++ b/src/pynwb/spec.py @@ -73,26 +73,17 @@ class BaseStorageOverride(object): @classmethod def type_key(cls): - ''' Get the key used to store data type on an instance - - Override this method to use a different name for 'data_type' - ''' + ''' Get the key used to store data type on an instance''' return cls.__type_key @classmethod def inc_key(cls): - ''' Get the key used to define a data_type include. - - Override this method to use a different keyword for 'data_type_inc' - ''' + ''' Get the key used to define a data_type include.''' return cls.__inc_key @classmethod def def_key(cls): - ''' Get the key used to define a data_type definition. 
- - Override this method to use a different keyword for 'data_type_def' - ''' + ''' Get the key used to define a data_type definition.''' return cls.__def_key @property diff --git a/tests/integration/ui_write/base.py b/tests/integration/ui_write/base.py index d5d1e315b..01facf287 100644 --- a/tests/integration/ui_write/base.py +++ b/tests/integration/ui_write/base.py @@ -144,6 +144,7 @@ class TestMapRoundTrip(TestMapNWBContainer): def setUp(self): super(TestMapRoundTrip, self).setUp() self.container = self.setUpContainer() + self.object_id = self.container.object_id self.start_time = datetime(1971, 1, 1, 12, tzinfo=tzutc()) self.create_date = datetime(2018, 4, 15, 12, tzinfo=tzlocal()) self.container_type = self.container.__class__.__name__ @@ -182,8 +183,10 @@ def roundtripContainer(self, cache_spec=False): def test_roundtrip(self): self.read_container = self.roundtripContainer() # make sure we get a completely new object - str(self.container) # added as a test to make sure printing works + self.assertIsNotNone(str(self.container)) # added as a test to make sure printing works + self.assertIsNotNone(str(self.read_container)) self.assertNotEqual(id(self.container), id(self.read_container)) + self.assertIs(self.read_nwbfile.objects[self.container.object_id], self.read_container) self.assertContainerEqual(self.read_container, self.container) self.validate() diff --git a/tests/integration/ui_write/test_icephys.py b/tests/integration/ui_write/test_icephys.py index 380bac797..d1df40666 100644 --- a/tests/integration/ui_write/test_icephys.py +++ b/tests/integration/ui_write/test_icephys.py @@ -153,9 +153,9 @@ def setUpSweepTable(self): starting_time=123.6, rate=10e3, electrode=self.elec, gain=0.126, stimulus_description="gotcha ya!", sweep_number=4711) self.sweep_table = SweepTable(name='sweep_table') - self.sweep_table.add_entry(self.pcs) def addContainer(self, nwbfile): + nwbfile.sweep_table = self.sweep_table nwbfile.add_device(self.device) 
nwbfile.add_ic_electrode(self.elec) nwbfile.add_acquisition(self.pcs) @@ -306,12 +306,10 @@ def setUpSweepTable(self): stimulus_description="gotcha ya!", sweep_number=4712) self.sweep_table = SweepTable(name='sweep_table') - self.sweep_table.add_entry(self.pcs1) - self.sweep_table.add_entry(self.pcs2a) - self.sweep_table.add_entry(self.pcs2b) def addContainer(self, nwbfile): ''' Should take an NWBFile object and add the SweepTable container to it ''' + nwbfile.sweep_table = self.sweep_table nwbfile.add_device(self.device) nwbfile.add_ic_electrode(self.elec) diff --git a/tests/integration/ui_write/test_modular_storage.py b/tests/integration/ui_write/test_modular_storage.py index f033a7c02..f4a8d1833 100644 --- a/tests/integration/ui_write/test_modular_storage.py +++ b/tests/integration/ui_write/test_modular_storage.py @@ -13,7 +13,7 @@ from . import base -class TestTimeSeriesModular(base.TestMapRoundTrip): +class TestTimeSeriesModular(base.TestMapNWBContainer): _required_tests = ('test_roundtrip',) @@ -38,14 +38,30 @@ def setUp(self): self.data_filename = 'test_time_series_modular_data.nwb' self.link_filename = 'test_time_series_modular_link.nwb' + self.read_container = None + self.link_read_io = None + self.data_read_io = None + def tearDown(self): - self.read_container.data.file.close() - self.read_container.timestamps.file.close() + if self.read_container: + self.read_container.data.file.close() + self.read_container.timestamps.file.close() + if self.link_read_io: + self.link_read_io.close() + if self.data_read_io: + self.data_read_io.close() + + # necessary to remove all references to the file and garbage + # collect on windows in order to be able to truncate/overwrite + # the file later. 
see pynwb GH issue #975 + if os.name == 'nt': + gc.collect() self.remove_file(self.data_filename) self.remove_file(self.link_filename) def roundtripContainer(self): + # create and write data file data_file = NWBFile( session_description='a test file', identifier='data_file', @@ -53,31 +69,53 @@ def roundtripContainer(self): ) data_file.add_acquisition(self.container) - with HDF5IO(self.data_filename, 'w', manager=get_manager()) as self.data_write_io: - self.data_write_io.write(data_file) + with HDF5IO(self.data_filename, 'w', manager=get_manager()) as data_write_io: + data_write_io.write(data_file) + # read data file with HDF5IO(self.data_filename, 'r', manager=get_manager()) as self.data_read_io: data_file_obt = self.data_read_io.read() + # write "link file" with timeseries.data that is an external link to the timeseries in "data file" + # also link timeseries.timestamps.data to the timeseries.timestamps in "data file" with HDF5IO(self.link_filename, 'w', manager=get_manager()) as link_write_io: link_file = NWBFile( session_description='a test file', identifier='link_file', session_start_time=self.start_time ) - link_file.add_acquisition(TimeSeries( + self.link_container = TimeSeries( name='test_mod_ts', unit='V', - data=data_file_obt.get_acquisition('data_ts'), + data=data_file_obt.get_acquisition('data_ts'), # test direct link timestamps=H5DataIO( data=data_file_obt.get_acquisition('data_ts').timestamps, - link_data=True + link_data=True # test with setting link data ) - )) + ) + link_file.add_acquisition(self.link_container) link_write_io.write(link_file) - with HDF5IO(self.link_filename, 'r', manager=get_manager()) as self.link_file_reader: - return self.getContainer(self.link_file_reader.read()) + # note that self.link_container contains a link to a dataset that is now closed + + # read the link file + self.link_read_io = HDF5IO(self.link_filename, 'r', manager=get_manager()) + self.read_nwbfile = self.link_read_io.read() + return 
self.getContainer(self.read_nwbfile) + + def test_roundtrip(self): + self.read_container = self.roundtripContainer() + + # make sure we get a completely new object + self.assertIsNotNone(str(self.container)) # added as a test to make sure printing works + self.assertIsNotNone(str(self.link_container)) + self.assertIsNotNone(str(self.read_container)) + self.assertFalse(self.link_container.timestamps.valid) + self.assertTrue(self.read_container.timestamps.id.valid) + self.assertNotEqual(id(self.link_container), id(self.read_container)) + self.assertIs(self.read_nwbfile.objects[self.link_container.object_id], self.read_container) + self.assertContainerEqual(self.read_container, self.container) + self.validate() def validate(self): filenames = [self.data_filename, self.link_filename] @@ -89,15 +127,5 @@ def validate(self): for err in errors: raise Exception(err) - # necessary to remove all references to the file and garbage - # collect on windows in order to be able to truncate/overwrite - # the file later. see pynwb GH issue #975 - if os.name == 'nt': - del io - gc.collect() - def getContainer(self, nwbfile): return nwbfile.get_acquisition('test_mod_ts') - - def addContainer(self, nwbfile): - pass