Skip to content

Commit

Permalink
Add NWBFile copy method (#994)
Browse files Browse the repository at this point in the history
* add LabelledDict default search key

* add neurodata specific name for data_id

* index all objects in an NWB file

* drop neurodata_id rename, rename data_id to object_id

* fix import, add object_id resolution to roundtrip test

* Clean up tests and check object id matches in roundtrip tests

* Make test modular storage not extend testroundtrip

* Test string printing in roundtrip properly

* Set parent attribute instead of using add_child (deprecated)

* Add script to add object ID to files without it

* Add copy() to nwbfile, dynamictable, add settable field to nwbfields
  • Loading branch information
rly authored Jul 31, 2019
1 parent 4cd273e commit 48139e5
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 86 deletions.
1 change: 1 addition & 0 deletions scripts/add_object_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Usage: python add_object_id filename
"""


from h5py import File
from uuid import uuid4
import sys
Expand Down
14 changes: 13 additions & 1 deletion src/pynwb/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def nwbbt_getter(self):
def _setter(cls, nwbfield):
name = nwbfield['name']

if not nwbfield.get('settable', True):
return None

def nwbbt_setter(self, val):
if val is None:
return
Expand Down Expand Up @@ -249,7 +252,7 @@ class NWBContainer(NWBBaseType, Container):
def __init__(self, **kwargs):
call_docval_func(super(NWBContainer, self).__init__, kwargs)

__pconf_allowed_keys = {'name', 'child', 'required_name', 'doc'}
__pconf_allowed_keys = {'name', 'child', 'required_name', 'doc', 'settable'}

@classmethod
def _setter(cls, nwbfield):
Expand Down Expand Up @@ -1376,6 +1379,15 @@ def from_dataframe(cls, **kwargs):

return cls(name=name, id=ids, columns=columns, description=table_description, **kwargs)

def copy(self):
"""
Return a copy of this DynamicTable.
This is useful for linking.
"""
kwargs = dict(name=self.name, id=self.id, columns=self.columns, description=self.description,
colnames=self.colnames)
return self.__class__(**kwargs)


@register_class('DynamicTableRegion', CORE_NAMESPACE)
class DynamicTableRegion(VectorData):
Expand Down
148 changes: 73 additions & 75 deletions src/pynwb/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,11 @@ class NWBFile(MultiContainerInterface):
}
]

__nwbfields__ = ('timestamps_reference_time',
'file_create_date',
__nwbfields__ = ({'name': 'session_description', 'settable': False},
{'name': 'identifier', 'settable': False},
{'name': 'session_start_time', 'settable': False},
{'name': 'timestamps_reference_time', 'settable': False},
{'name': 'file_create_date', 'settable': False},
'experimenter',
'experiment_description',
'session_id',
Expand Down Expand Up @@ -280,62 +283,48 @@ def __init__(self, **kwargs):
pargs, pkwargs = fmt_docval_args(super(NWBFile, self).__init__, kwargs)
pkwargs['name'] = 'root'
super(NWBFile, self).__init__(*pargs, **pkwargs)
self.__session_description = getargs('session_description', kwargs)
self.__identifier = getargs('identifier', kwargs)
self.fields['session_description'] = getargs('session_description', kwargs)
self.fields['identifier'] = getargs('identifier', kwargs)

self.__session_start_time = getargs('session_start_time', kwargs)
if self.__session_start_time.tzinfo is None:
self.__session_start_time = _add_missing_timezone(self.__session_start_time)
self.fields['session_start_time'] = getargs('session_start_time', kwargs)
if self.fields['session_start_time'].tzinfo is None:
self.fields['session_start_time'] = _add_missing_timezone(self.fields['session_start_time'])

self.__timestamps_reference_time = getargs('timestamps_reference_time', kwargs)
if self.__timestamps_reference_time is None:
self.__timestamps_reference_time = self.__session_start_time
elif self.__timestamps_reference_time.tzinfo is None:
self.fields['timestamps_reference_time'] = getargs('timestamps_reference_time', kwargs)
if self.fields['timestamps_reference_time'] is None:
self.fields['timestamps_reference_time'] = self.fields['session_start_time']
elif self.fields['timestamps_reference_time'].tzinfo is None:
raise ValueError("'timestamps_reference_time' must be a timezone-aware datetime object.")

self.__file_create_date = getargs('file_create_date', kwargs)
if self.__file_create_date is None:
self.__file_create_date = datetime.now(tzlocal())
if isinstance(self.__file_create_date, datetime):
self.__file_create_date = [self.__file_create_date]
self.__file_create_date = list(map(_add_missing_timezone, self.__file_create_date))

self.acquisition = getargs('acquisition', kwargs)
self.analysis = getargs('analysis', kwargs)
self.stimulus = getargs('stimulus', kwargs)
self.stimulus_template = getargs('stimulus_template', kwargs)
self.keywords = getargs('keywords', kwargs)

self.processing = getargs('processing', kwargs)
epochs = getargs('epochs', kwargs)
if epochs is not None:
if epochs.name != 'epochs':
raise ValueError("NWBFile.epochs must be named 'epochs'")
self.epochs = epochs
self.epoch_tags = getargs('epoch_tags', kwargs)

trials = getargs('trials', kwargs)
if trials is not None:
self.trials = trials
invalid_times = getargs('invalid_times', kwargs)
if invalid_times is not None:
self.invalid_times = invalid_times
units = getargs('units', kwargs)
if units is not None:
self.units = units

self.electrodes = getargs('electrodes', kwargs)
self.electrode_groups = getargs('electrode_groups', kwargs)
self.devices = getargs('devices', kwargs)
self.ic_electrodes = getargs('ic_electrodes', kwargs)
self.imaging_planes = getargs('imaging_planes', kwargs)
self.ogen_sites = getargs('ogen_sites', kwargs)
self.intervals = getargs('intervals', kwargs)
self.subject = getargs('subject', kwargs)
self.sweep_table = getargs('sweep_table', kwargs)
self.lab_meta_data = getargs('lab_meta_data', kwargs)

recommended = [
self.fields['file_create_date'] = getargs('file_create_date', kwargs)
if self.fields['file_create_date'] is None:
self.fields['file_create_date'] = datetime.now(tzlocal())
if isinstance(self.fields['file_create_date'], datetime):
self.fields['file_create_date'] = [self.fields['file_create_date']]
self.fields['file_create_date'] = list(map(_add_missing_timezone, self.fields['file_create_date']))

fieldnames = [
'acquisition',
'analysis',
'stimulus',
'stimulus_template',
'keywords',
'processing',
'epoch_tags',
'electrodes',
'electrode_groups',
'devices',
'ic_electrodes',
'imaging_planes',
'ogen_sites',
'intervals',
'subject',
'sweep_table',
'lab_meta_data',
'epochs',
'trials',
'invalid_times',
'units',
'experimenter',
'experiment_description',
'session_id',
Expand All @@ -353,7 +342,7 @@ def __init__(self, **kwargs):
'virus',
'stimulus_notes',
]
for attr in recommended:
for attr in fieldnames:
setattr(self, attr, kwargs.get(attr, None))

if getargs('source_script', kwargs) is None and getargs('source_script_file_name', kwargs) is not None:
Expand Down Expand Up @@ -398,26 +387,6 @@ def ec_electrodes(self):
warn("replaced by NWBFile.electrodes", DeprecationWarning)
return self.electrodes

@property
def identifier(self):
return self.__identifier

@property
def session_description(self):
return self.__session_description

@property
def file_create_date(self):
return self.__file_create_date

@property
def session_start_time(self):
return self.__session_start_time

@property
def timestamps_reference_time(self):
return self.__timestamps_reference_time

def __check_epochs(self):
if self.epochs is None:
self.epochs = TimeIntervals('epochs', 'experimental epochs')
Expand Down Expand Up @@ -624,6 +593,35 @@ def add_stimulus_template(self, timeseries):
self._add_stimulus_template_internal(timeseries)
self._update_sweep_table(timeseries)

def copy(self):
"""
Shallow copy of an NWB file.
Useful for linking across files.
"""
kwargs = self.fields.copy()
for key in self.fields:
if isinstance(self.fields[key], LabelledDict):
kwargs[key] = list(self.fields[key].values())

# HDF5 object references cannot point to objects external to the file. Both DynamicTables such as TimeIntervals
# contain such object references and types such as ElectricalSeries contain references to DynamicTables.
# Below, copy the table and link to the columns so that object references work.
fields_to_copy = ['electrodes', 'epochs', 'trials', 'units', 'subject', 'sweep_table', 'invalid_times']
for field in fields_to_copy:
if field in kwargs:
if isinstance(self.fields[field], DynamicTable):
kwargs[field] = self.fields[field].copy()
else:
warn('Cannot copy child of NWBFile that is not a DynamicTable.')

# handle dictionaries of DynamicTables
dt_to_copy = ['scratch', 'intervals']
for dt in dt_to_copy:
if dt in kwargs:
kwargs[dt] = [v.copy() if isinstance(v, DynamicTable) else v for v in kwargs[dt]]

return NWBFile(**kwargs)


def _add_missing_timezone(date):
"""
Expand Down
45 changes: 41 additions & 4 deletions tests/integration/ui_write/test_modular_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def remove_file(self, path):
os.remove(path)

def setUp(self):
self.__manager = get_manager()
self.start_time = datetime(1971, 1, 1, 12, tzinfo=tzutc())

self.data = np.arange(2000).reshape((2, 1000))
Expand All @@ -35,8 +34,8 @@ def setUp(self):
timestamps=self.timestamps
)

self.data_filename = 'test_time_series_modular_data.nwb'
self.link_filename = 'test_time_series_modular_link.nwb'
self.data_filename = os.path.join(os.getcwd(), 'test_time_series_modular_data.nwb')
self.link_filename = os.path.join(os.getcwd(), 'test_time_series_modular_link.nwb')

self.read_container = None
self.link_read_io = None
Expand All @@ -57,8 +56,8 @@ def tearDown(self):
if os.name == 'nt':
gc.collect()

self.remove_file(self.data_filename)
self.remove_file(self.link_filename)
self.remove_file(self.data_filename)

def roundtripContainer(self):
# create and write data file
Expand Down Expand Up @@ -117,6 +116,44 @@ def test_roundtrip(self):
self.assertContainerEqual(self.read_container, self.container)
self.validate()

def test_link_root(self):
# create and write data file
data_file = NWBFile(
session_description='a test file',
identifier='data_file',
session_start_time=self.start_time
)
data_file.add_acquisition(self.container)

with HDF5IO(self.data_filename, 'w', manager=get_manager()) as data_write_io:
data_write_io.write(data_file)

# read data file
manager = get_manager()
with HDF5IO(self.data_filename, 'r', manager=manager) as data_read_io:
data_file_obt = data_read_io.read()

link_file = NWBFile(
session_description='a test file',
identifier='link_file',
session_start_time=self.start_time
)
link_container = data_file_obt.acquisition[self.container.name]
link_file.add_acquisition(link_container)
self.assertIs(link_container.parent, data_file_obt)

with HDF5IO(self.link_filename, 'w', manager=manager) as link_write_io:
link_write_io.write(link_file)

# read the link file, check container sources
with HDF5IO(self.link_filename, 'r+', manager=get_manager()) as link_file_reader:
read_nwbfile = link_file_reader.read()
self.assertNotEqual(read_nwbfile.acquisition[self.container.name].container_source,
read_nwbfile.container_source)
self.assertEqual(read_nwbfile.acquisition[self.container.name].container_source,
self.data_filename)
self.assertEqual(read_nwbfile.container_source, self.link_filename)

def validate(self):
filenames = [self.data_filename, self.link_filename]
for fn in filenames:
Expand Down
15 changes: 10 additions & 5 deletions tests/unit/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,19 +327,24 @@ def test_print_file(self):
nwbfile.add_acquisition(ts)
nwbfile.add_acquisition(ts2)
nwbfile.add_epoch(start_time=1.0, stop_time=10.0, tags=['tag1', 'tag2'])
self.assertEqual(str(nwbfile),
"""
root <class 'pynwb.file.NWBFile'>
self.assertRegex(str(nwbfile),
r"""
root <class 'pynwb\.file\.NWBFile'>
Fields:
acquisition: {
name <class 'pynwb.base.TimeSeries'>,
name2 <class 'pynwb.base.TimeSeries'>
name <class 'pynwb\.base\.TimeSeries'>,
name2 <class 'pynwb\.base\.TimeSeries'>
}
epoch_tags: {
tag1,
tag2
}
epochs: epochs <class 'pynwb.epoch.TimeIntervals'>
file_create_date: \[datetime.datetime\(.*\)\]
identifier: identifier
session_description: session_description
session_start_time: .*
timestamps_reference_time: .*
""")


Expand Down
Loading

0 comments on commit 48139e5

Please sign in to comment.