Skip to content

Commit

Permalink
Support Object IDs in NWBFile (#991)
Browse files Browse the repository at this point in the history
* add LabelledDict default search key

* add neurodata specific name for data_id

* index all objects in an NWB file

* drop neurodata_id rename, rename data_id to object_id

* fix import, add object_id resolution to roundtrip test

* Clean up tests and check object id matches in roundtrip tests

* Make test modular storage not extend testroundtrip

* Test string printing in roundtrip properly

* Set parent attribute instead of using add_child (deprecated)

* Do not save object_ids in map if none. Warn if None.

* Add script to add object ID to files without it

* Remove unused set_parents function

* Do not set parent when it already exists so that a link can be created

* Fix flake8

* Fix for when parent is still a Proxy

* Fix printing of DataIO wrapped dataset that has been closed

* Fix modular storage test errors

* Enhance modular storage test, requires latest hdmf
  • Loading branch information
ajtritt authored Jul 30, 2019
1 parent 6bd603f commit 4cd273e
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 60 deletions.
28 changes: 28 additions & 0 deletions scripts/add_object_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
This script adds an 'object_id' attribute to each neurodata_type of an hdf5 file that was written before object IDs
existed. Specifically, it traverses through the hierarchy of objects in the file and sets the 'object_id' attribute
to a UUID4 string on each group, dataset, and link that has a 'neurodata_type' attribute and does not have an
'object_id' attribute.
Usage: python add_object_id.py filename
"""

from h5py import File
from uuid import uuid4
import sys


def add_uuid(name, obj):
    """Assign a fresh UUID4 'object_id' attribute to *obj* if it needs one.

    Intended as an h5py ``visititems`` callback: *name* is the path within
    the file (unused here, but required by the callback signature) and
    *obj* is the visited group/dataset.  Only objects that carry a
    'neurodata_type' attribute and do not already have an 'object_id'
    attribute are modified.
    """
    attrs = obj.attrs
    # Guard clause: skip anything that is not a neurodata type or that
    # already carries an object ID, so reruns are idempotent.
    if 'neurodata_type' not in attrs or 'object_id' in attrs:
        return
    attrs['object_id'] = str(uuid4())
    print('Adding uuid4 %s to %s' % (attrs['object_id'], str(obj)))


def main():
    """Open the HDF5 file named on the command line and add missing object IDs.

    The file is opened in append ('a') mode so existing data is preserved
    and only the new 'object_id' attributes are written.  Exits with a
    usage message (status 1) if no filename argument is supplied, instead
    of failing with an IndexError.
    """
    if len(sys.argv) != 2:
        # Print usage to stderr and signal failure to the shell.
        print('Usage: python add_object_id.py filename', file=sys.stderr)
        sys.exit(1)
    filename = sys.argv[1]
    with File(filename, 'a') as f:
        # visititems walks every group/dataset; add_uuid tags each one.
        f.visititems(add_uuid)


if __name__ == '__main__':
    main()
39 changes: 18 additions & 21 deletions src/pynwb/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pandas as pd

from hdmf.utils import docval, getargs, ExtenderMeta, call_docval_func, popargs, get_docval, fmt_docval_args, pystr
from hdmf.data_utils import DataIO
from hdmf import Container, Data, DataRegion, get_region_slicer

from . import CORE_NAMESPACE, register_class
Expand All @@ -13,29 +14,18 @@ def _not_parent(arg):
return arg['name'] != 'parent'


def set_parents(container, parent):
if isinstance(container, list):
for c in container:
if c.parent is None:
c.parent = parent
ret = container
else:
ret = [container]
if container.parent is None:
container.parent = parent
return ret


class LabelledDict(dict):
'''
A dict wrapper class for aggregating Timeseries
from the standard locations
'''

@docval({'name': 'label', 'type': str, 'doc': 'the TimeSeries type ('})
@docval({'name': 'label', 'type': str, 'doc': 'the label on this dictionary'},
{'name': 'def_key_name', 'type': str, 'doc': 'the default key name', 'default': 'name'})
def __init__(self, **kwargs):
label = getargs('label', kwargs)
label, def_key_name = getargs('label', 'def_key_name', kwargs)
self.__label = label
self.__defkey = def_key_name

@property
def label(self):
Expand All @@ -47,7 +37,7 @@ def __getitem__(self, args):
key, val = args.split("==")
key = key.strip()
val = val.strip()
if key != 'name':
if key != self.__defkey:
ret = list()
for item in self.values():
if getattr(item, key, None) == val:
Expand Down Expand Up @@ -165,7 +155,7 @@ def __repr__(self):
template = "\n{} {}\nFields:\n""".format(getattr(self, 'name'), type(self))
for k in sorted(self.fields): # sorted to enable tests
v = self.fields[k]
if not hasattr(v, '__len__') or len(v) > 0:
if isinstance(v, DataIO) or not hasattr(v, '__len__') or len(v) > 0:
template += " {}: {}\n".format(k, self.__smart_str(v, 1))
return template

Expand Down Expand Up @@ -295,7 +285,10 @@ def nwbdi_setter(self, val):
else:
val = [val]
for v in val:
self.add_child(v)
if not isinstance(v.parent, Container):
v.parent = self
# else, the ObjectMapper will create a link from self (parent) to v (child with existing
# parent)

ret.append(nwbdi_setter)
return ret[-1]
Expand Down Expand Up @@ -767,7 +760,9 @@ def _func(self, **kwargs):
containers = container
d = getattr(self, attr_name)
for tmp in containers:
self.add_child(tmp)
if not isinstance(tmp.parent, Container):
tmp.parent = self
# else, the ObjectMapper will create a link from self (parent) to tmp (child with existing parent)
if tmp.name in d:
msg = "'%s' already exists in '%s'" % (tmp.name, self.name)
raise ValueError(msg)
Expand Down Expand Up @@ -1199,7 +1194,7 @@ def add_column(self, **kwargs):
ckwargs['table'] = table

col = cls(**ckwargs)
self.add_child(col)
col.parent = self
columns = [col]

# Add index if it's been specified
Expand All @@ -1215,7 +1210,9 @@ def add_column(self, **kwargs):
raise ValueError("cannot pass non-empty index with empty data to index")
col_index = VectorIndex(name + "_index", index, col)
columns.insert(0, col_index)
self.add_child(col_index)
if not isinstance(col_index.parent, Container):
col_index.parent = self
# else, the ObjectMapper will create a link from self (parent) to col_index (child with existing parent)
col = col_index
self.__indices[col_index.name] = col_index

Expand Down
16 changes: 15 additions & 1 deletion src/pynwb/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from .ophys import ImagingPlane
from .ogen import OptogeneticStimulusSite
from .misc import Units
from .core import NWBContainer, NWBDataInterface, MultiContainerInterface, DynamicTable, DynamicTableRegion
from .core import NWBContainer, NWBDataInterface, MultiContainerInterface, DynamicTable, DynamicTableRegion,\
LabelledDict


def _not_parent(arg):
Expand Down Expand Up @@ -358,17 +359,30 @@ def __init__(self, **kwargs):
if getargs('source_script', kwargs) is None and getargs('source_script_file_name', kwargs) is not None:
raise ValueError("'source_script' cannot be None when 'source_script_file_name' is set")

self.__obj = None

def all_children(self):
stack = [self]
ret = list()
self.__obj = LabelledDict(label='all_objects', def_key_name='object_id')
while len(stack):
n = stack.pop()
ret.append(n)
if n.object_id is not None:
self.__obj[n.object_id] = n
else:
warn('%s "%s" does not have an object_id' % (n.neurodata_type, n.name))
if hasattr(n, 'children'):
for c in n.children:
stack.append(c)
return ret

@property
def objects(self):
if self.__obj is None:
self.all_children()
return self.__obj

@property
def modules(self):
warn("replaced by NWBFile.processing", DeprecationWarning)
Expand Down
15 changes: 3 additions & 12 deletions src/pynwb/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,26 +73,17 @@ class BaseStorageOverride(object):

@classmethod
def type_key(cls):
''' Get the key used to store data type on an instance
Override this method to use a different name for 'data_type'
'''
''' Get the key used to store data type on an instance'''
return cls.__type_key

@classmethod
def inc_key(cls):
''' Get the key used to define a data_type include.
Override this method to use a different keyword for 'data_type_inc'
'''
''' Get the key used to define a data_type include.'''
return cls.__inc_key

@classmethod
def def_key(cls):
''' Get the key used to define a data_type definition.
Override this method to use a different keyword for 'data_type_def'
'''
''' Get the key used to define a data_type definition.'''
return cls.__def_key

@property
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/ui_write/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class TestMapRoundTrip(TestMapNWBContainer):
def setUp(self):
super(TestMapRoundTrip, self).setUp()
self.container = self.setUpContainer()
self.object_id = self.container.object_id
self.start_time = datetime(1971, 1, 1, 12, tzinfo=tzutc())
self.create_date = datetime(2018, 4, 15, 12, tzinfo=tzlocal())
self.container_type = self.container.__class__.__name__
Expand Down Expand Up @@ -182,8 +183,10 @@ def roundtripContainer(self, cache_spec=False):
def test_roundtrip(self):
self.read_container = self.roundtripContainer()
# make sure we get a completely new object
str(self.container) # added as a test to make sure printing works
self.assertIsNotNone(str(self.container)) # added as a test to make sure printing works
self.assertIsNotNone(str(self.read_container))
self.assertNotEqual(id(self.container), id(self.read_container))
self.assertIs(self.read_nwbfile.objects[self.container.object_id], self.read_container)
self.assertContainerEqual(self.read_container, self.container)
self.validate()

Expand Down
6 changes: 2 additions & 4 deletions tests/integration/ui_write/test_icephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ def setUpSweepTable(self):
starting_time=123.6, rate=10e3, electrode=self.elec, gain=0.126,
stimulus_description="gotcha ya!", sweep_number=4711)
self.sweep_table = SweepTable(name='sweep_table')
self.sweep_table.add_entry(self.pcs)

def addContainer(self, nwbfile):
nwbfile.sweep_table = self.sweep_table
nwbfile.add_device(self.device)
nwbfile.add_ic_electrode(self.elec)
nwbfile.add_acquisition(self.pcs)
Expand Down Expand Up @@ -306,12 +306,10 @@ def setUpSweepTable(self):
stimulus_description="gotcha ya!", sweep_number=4712)

self.sweep_table = SweepTable(name='sweep_table')
self.sweep_table.add_entry(self.pcs1)
self.sweep_table.add_entry(self.pcs2a)
self.sweep_table.add_entry(self.pcs2b)

def addContainer(self, nwbfile):
''' Should take an NWBFile object and add the SweepTable container to it '''
nwbfile.sweep_table = self.sweep_table
nwbfile.add_device(self.device)
nwbfile.add_ic_electrode(self.elec)

Expand Down
70 changes: 49 additions & 21 deletions tests/integration/ui_write/test_modular_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from . import base


class TestTimeSeriesModular(base.TestMapRoundTrip):
class TestTimeSeriesModular(base.TestMapNWBContainer):

_required_tests = ('test_roundtrip',)

Expand All @@ -38,46 +38,84 @@ def setUp(self):
self.data_filename = 'test_time_series_modular_data.nwb'
self.link_filename = 'test_time_series_modular_link.nwb'

self.read_container = None
self.link_read_io = None
self.data_read_io = None

def tearDown(self):
self.read_container.data.file.close()
self.read_container.timestamps.file.close()
if self.read_container:
self.read_container.data.file.close()
self.read_container.timestamps.file.close()
if self.link_read_io:
self.link_read_io.close()
if self.data_read_io:
self.data_read_io.close()

# necessary to remove all references to the file and garbage
# collect on windows in order to be able to truncate/overwrite
# the file later. see pynwb GH issue #975
if os.name == 'nt':
gc.collect()

self.remove_file(self.data_filename)
self.remove_file(self.link_filename)

def roundtripContainer(self):
# create and write data file
data_file = NWBFile(
session_description='a test file',
identifier='data_file',
session_start_time=self.start_time
)
data_file.add_acquisition(self.container)

with HDF5IO(self.data_filename, 'w', manager=get_manager()) as self.data_write_io:
self.data_write_io.write(data_file)
with HDF5IO(self.data_filename, 'w', manager=get_manager()) as data_write_io:
data_write_io.write(data_file)

# read data file
with HDF5IO(self.data_filename, 'r', manager=get_manager()) as self.data_read_io:
data_file_obt = self.data_read_io.read()

# write "link file" with timeseries.data that is an external link to the timeseries in "data file"
# also link timeseries.timestamps.data to the timeseries.timestamps in "data file"
with HDF5IO(self.link_filename, 'w', manager=get_manager()) as link_write_io:
link_file = NWBFile(
session_description='a test file',
identifier='link_file',
session_start_time=self.start_time
)
link_file.add_acquisition(TimeSeries(
self.link_container = TimeSeries(
name='test_mod_ts',
unit='V',
data=data_file_obt.get_acquisition('data_ts'),
data=data_file_obt.get_acquisition('data_ts'), # test direct link
timestamps=H5DataIO(
data=data_file_obt.get_acquisition('data_ts').timestamps,
link_data=True
link_data=True # test with setting link data
)
))
)
link_file.add_acquisition(self.link_container)
link_write_io.write(link_file)

with HDF5IO(self.link_filename, 'r', manager=get_manager()) as self.link_file_reader:
return self.getContainer(self.link_file_reader.read())
# note that self.link_container contains a link to a dataset that is now closed

# read the link file
self.link_read_io = HDF5IO(self.link_filename, 'r', manager=get_manager())
self.read_nwbfile = self.link_read_io.read()
return self.getContainer(self.read_nwbfile)

def test_roundtrip(self):
self.read_container = self.roundtripContainer()

# make sure we get a completely new object
self.assertIsNotNone(str(self.container)) # added as a test to make sure printing works
self.assertIsNotNone(str(self.link_container))
self.assertIsNotNone(str(self.read_container))
self.assertFalse(self.link_container.timestamps.valid)
self.assertTrue(self.read_container.timestamps.id.valid)
self.assertNotEqual(id(self.link_container), id(self.read_container))
self.assertIs(self.read_nwbfile.objects[self.link_container.object_id], self.read_container)
self.assertContainerEqual(self.read_container, self.container)
self.validate()

def validate(self):
filenames = [self.data_filename, self.link_filename]
Expand All @@ -89,15 +127,5 @@ def validate(self):
for err in errors:
raise Exception(err)

# necessary to remove all references to the file and garbage
# collect on windows in order to be able to truncate/overwrite
# the file later. see pynwb GH issue #975
if os.name == 'nt':
del io
gc.collect()

def getContainer(self, nwbfile):
return nwbfile.get_acquisition('test_mod_ts')

def addContainer(self, nwbfile):
pass

0 comments on commit 4cd273e

Please sign in to comment.