Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
mavaylon1 committed Sep 20, 2023
1 parent 04b6e2b commit e170307
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 30 deletions.
31 changes: 11 additions & 20 deletions src/hdmf_zarr/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,6 @@ def __setup_chunked_dataset__(cls, parent, name, data, options=None):
returns='the Zarr array that was created', rtype=Array)
def write_dataset(self, **kwargs): # noqa: C901
parent, builder, link_data, exhaust_dci = getargs('parent', 'builder', 'link_data', 'exhaust_dci', kwargs)
# breakpoint()
force_data = getargs('force_data', kwargs)
if self.get_written(builder):
return None
Expand Down Expand Up @@ -1085,9 +1084,7 @@ def __read_group(self, zarr_obj, name=None):
for sub_name, sub_group in zarr_obj.groups():
sub_builder = self.__read_group(sub_group, sub_name)
ret.set_group(sub_builder)
# breakpoint()
# read sub datasets
# breakpoint()
for sub_name, sub_array in zarr_obj.arrays():
sub_builder = self.__read_dataset(sub_array, sub_name)
ret.set_dataset(sub_builder)
Expand Down Expand Up @@ -1124,9 +1121,7 @@ def __read_links(self, zarr_obj, parent):
parent.set_link(link_builder)

def __read_dataset(self, zarr_obj, name):
# breakpoint()
ret = self.__get_built(zarr_obj)
# breakpoint()
if ret is not None:
return ret

Expand Down Expand Up @@ -1154,30 +1149,27 @@ def __read_dataset(self, zarr_obj, name):
if dtype == 'scalar':
data = zarr_obj[0]

obj_refs = False
reg_refs = False
has_reference = False
if isinstance(dtype, list):
# compound data type
# Check compound dataset where one of the subsets contains references
has_reference = False
for i, dts in enumerate(dtype):
if dts['dtype'] == 'object': # check items for object reference
"""
This is a compound dataset where one of the subsets contains references (one or more)
"""
has_reference = True
break
elif dts['dtype'] == 'region':
has_reference = True
break
# TODO: Region reference not supported
# elif dts['dtype'] == 'region':
# has_reference = True
# break
retrieved_dtypes = [dtype_dict['dtype'] for dtype_dict in dtype]
data = BuilderZarrTableDataset(zarr_obj, self, retrieved_dtypes)
if has_reference:
data = BuilderZarrTableDataset(zarr_obj, self, retrieved_dtypes)
elif self.__is_ref(dtype):
# reference array
has_reference = True #TODO: REMOVE
if dtype == 'object': # wrap with dataset ref
data = BuilderZarrReferenceDataset(data, self)
elif dtype == 'region':
reg_refs = True #TODO: Region reference not wrapped yet
# TODO: Region reference not wrapped yet
# elif dtype == 'region':
# reg_refs = True

kwargs['data'] = data
if name is None:
Expand All @@ -1194,7 +1186,6 @@ def __read_attrs(self, zarr_obj):
if k not in self.__reserve_attribute:
v = zarr_obj.attrs[k]
if isinstance(v, dict) and 'zarr_dtype' in v:
# TODO Is this the correct way to resolve references?
if v['zarr_dtype'] == 'object':
target_name, target_zarr_obj = self.resolve_ref(v['value'])
if isinstance(target_zarr_obj, zarr.hierarchy.Group):
Expand Down
10 changes: 0 additions & 10 deletions tests/unit/test_io_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,10 @@ def test_export_roundtrip(self):
self.filenames.append(write_path if isinstance(write_path, str) else write_path.path)
self.filenames.append(export_path if isinstance(export_path, str) else export_path.path)
# roundtrip the container
# breakpoint()
exported_container = self.roundtripExportContainer(
container=container,
write_path=write_path,
export_path=export_path)
# breakpoint()
if self.REFERENCES:
if self.TARGET_FORMAT == "H5":
num_bazs = 10
Expand All @@ -182,10 +180,8 @@ def test_export_roundtrip(self):
num_bazs = 10
for i in range(num_bazs):
baz_name = 'baz%d' % i
# breakpoint()
self.assertEqual(exported_container.baz_data.data.__class__.__name__, 'ContainerZarrReferenceDataset')
self.assertIs(exported_container.baz_data.data[i], exported_container.bazs[baz_name])
# breakpoint()
# assert that the roundtrip worked correctly
message = "Using: write_path=%s, export_path=%s" % (str(write_path), str(export_path))
self.assertIsNotNone(str(container), message) # added as a test to make sure printing works
Expand Down Expand Up @@ -227,13 +223,11 @@ def get_manager(self):
def roundtripExportContainer(self, container, write_path, export_path):
with HDF5IO(write_path, manager=self.get_manager(), mode='w') as write_io:
write_io.write(container, cache_spec=True)
# breakpoint()
with HDF5IO(write_path, manager=self.get_manager(), mode='r') as read_io:
with ZarrIO(export_path, mode='w') as export_io:
export_io.export(src_io=read_io, write_args={'link_data': False})

read_io = ZarrIO(export_path, manager=self.get_manager(), mode='r')
# breakpoint()
self.ios.append(read_io)
exportContainer = read_io.read()
return exportContainer
Expand All @@ -259,13 +253,11 @@ def get_manager(self):
def roundtripExportContainer(self, container, write_path, export_path):
with ZarrIO(write_path, manager=self.get_manager(), mode='w') as write_io:
write_io.write(container)
# breakpoint()
with ZarrIO(write_path, manager=self.get_manager(), mode='r') as read_io:
with HDF5IO(export_path, mode='w') as export_io:
export_io.export(src_io=read_io, write_args={'link_data': False})

read_io = HDF5IO(export_path, manager=self.get_manager(), mode='r')
# breakpoint()
self.ios.append(read_io)
exportContainer = read_io.read()
return exportContainer
Expand Down Expand Up @@ -295,13 +287,11 @@ def get_manager(self):
def roundtripExportContainer(self, container, write_path, export_path):
with ZarrIO(write_path, manager=self.get_manager(), mode='w') as write_io:
write_io.write(container, cache_spec=True)
# breakpoint()
with ZarrIO(write_path, manager=self.get_manager(), mode='r') as read_io:
with ZarrIO(export_path, mode='w') as export_io:
export_io.export(src_io=read_io, write_args={'link_data': False})

read_io = ZarrIO(export_path, manager=self.get_manager(), mode='r')
# breakpoint()
self.ios.append(read_io)
exportContainer = read_io.read()
return exportContainer
Expand Down

0 comments on commit e170307

Please sign in to comment.