Skip to content

Commit

Permalink
checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
mavaylon1 committed Sep 21, 2023
1 parent 7f87966 commit 1603d66
Show file tree
Hide file tree
Showing 3 changed files with 1,038 additions and 848 deletions.
4 changes: 2 additions & 2 deletions docs/gallery/plot_convert_nwb_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@

zf.trials.to_dataframe()[['start_time', 'stop_time', 'type', 'photo_stim_type']]
zr.close()
# breakpoint()

###############################################################################
# Convert the Zarr file back to HDF5
# ----------------------------------
#
# Using the same approach as above, we can now convert our Zarr file back to HDF5.
try: # TODO: This is a temporary ignore on the convert_dtype exception.
try: # TODO: This is a temporary ignore on the convert_dtype exception.
with NWBZarrIO(zarr_filename, mode='r') as read_io: # Create Zarr IO object for read
with NWBHDF5IO(hdf_filename, 'w') as export_io: # Create HDF5 IO object for write
export_io.export(src_io=read_io, write_args=dict(link_data=False)) # Export from Zarr to HDF5
Expand Down
60 changes: 58 additions & 2 deletions src/hdmf_zarr/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ def export(self, **kwargs):
"link_data=True." % src_io.__class__.__name__)

# write_args['export_source'] = src_io.source # pass export_source=src_io.source to write_builder
# write_args['export_source'] = os.path.abspath(src_io.source) if src_io.source is not None else None


ckwargs = kwargs.copy()
ckwargs['write_args'] = write_args
super().export(**ckwargs)
Expand Down Expand Up @@ -300,6 +303,8 @@ def get_builder_disk_path(self, **kwargs):
def write_builder(self, **kwargs):
"""Write a builder to disk"""
f_builder, link_data, exhaust_dci = getargs('builder', 'link_data', 'exhaust_dci', kwargs)
# f_builder = popargs('builder', kwargs)
# link_data, exhaust_dci, export_source = getargs('link_data', 'exhaust_dci', 'export_source', kwargs)
for name, gbldr in f_builder.groups.items():
self.write_group(parent=self.__file,
builder=gbldr,
Expand All @@ -310,7 +315,9 @@ def write_builder(self, **kwargs):
builder=dbldr,
link_data=link_data,
exhaust_dci=exhaust_dci)
self.write_attributes(self.__file, f_builder.attributes)
# for name, lbldr in f_builder.links.items():
# self.write_link(self.__file, lbldr, export_source=kwargs.get("export_source"))
self.write_attributes(self.__file, f_builder.attributes) # the same as set_attributes in HDMF
self.__dci_queue.exhaust_queue() # Write all DataChunkIterators that have been queued
self._written_builders.set_written(f_builder)
self.logger.debug("Done writing %s '%s' to path '%s'" %
Expand Down Expand Up @@ -517,8 +524,13 @@ def resolve_ref(self, zarr_ref):
else:
source_file = str(zarr_ref['source'])
# Resolve the path relative to the current file
# breakpoint()
source_file = os.path.abspath(os.path.join(self.source, source_file))
object_path = zarr_ref.get('path', None)
# if object_path == "/bucket1/bazs/baz0":
# object_path = "/root/bazs/baz0"
# if object_path == "/bazs/baz0":

# full_path = None
# if os.path.isdir(source_file):
# if object_path is not None:
Expand All @@ -534,7 +546,15 @@ def resolve_ref(self, zarr_ref):
try:
target_zarr_obj = target_zarr_obj[object_path]
except Exception:
raise ValueError("Found bad link to object %s in file %s" % (object_path, source_file))
# breakpoint()
try:
import pathlib
object_path = pathlib.Path(object_path)
rel_obj_path = object_path.relative_to(*object_path.parts[:2])
target_zarr_obj = target_zarr_obj[rel_obj_path]
except Exception:
# breakpoint()
raise ValueError("Found bad link to object %s in file %s" % (object_path, source_file))
# Return the create path
return target_name, target_zarr_obj

Expand Down Expand Up @@ -698,6 +718,7 @@ def __setup_chunked_dataset__(cls, parent, name, data, options=None):
'doc': 'Used internally to force the data being used when we have to load the data', 'default': None},
returns='the Zarr array that was created', rtype=Array)
def write_dataset(self, **kwargs): # noqa: C901
# breakpoint()
parent, builder, link_data, exhaust_dci = getargs('parent', 'builder', 'link_data', 'exhaust_dci', kwargs)
force_data = getargs('force_data', kwargs)
if self.get_written(builder):
Expand Down Expand Up @@ -730,6 +751,7 @@ def write_dataset(self, **kwargs): # noqa: C901
dset = parent[name]
# When converting data between backends we may see an HDMFDataset, e.g., a H55ReferenceDataset, with references
elif isinstance(data, HDMFDataset):
# breakpoint()
# If we have a dataset of containers we need to make the references to the containers
if len(data) > 0 and isinstance(data[0], Container):
ref_data = [self.__get_ref(data[i]) for i in range(len(data))]
Expand All @@ -742,6 +764,7 @@ def write_dataset(self, **kwargs): # noqa: C901
**options['io_settings'])
dset.attrs['zarr_dtype'] = type_str
dset[:] = ref_data
# breakpoint()
self._written_builders.set_written(builder) # record that the builder has been written
# If we have a regular dataset, then load the data and write the builder after load
else:
Expand All @@ -752,6 +775,8 @@ def write_dataset(self, **kwargs): # noqa: C901
# We can/should not update the data in the builder itself so we load the data here and instead
# force write_dataset when we call it recursively to use the data we loaded, rather than the
# dataset that is set on the builder
# breakpoint()

dset = self.write_dataset(parent=parent,
builder=builder,
link_data=link_data,
Expand All @@ -774,6 +799,7 @@ def write_dataset(self, **kwargs): # noqa: C901
i = list([dts, ])
t = self.__resolve_dtype_helper__(i)
type_str.append(self.__serial_dtype__(t)[0])
# breakpoint()

if len(refs) > 0:
dset = parent.require_dataset(name,
Expand Down Expand Up @@ -1050,7 +1076,23 @@ def get_builder(self, **kwargs): # move this to HDMFIO (define skeleton in there
"""
zarr_obj = kwargs['zarr_obj']
builder = self.__get_built(zarr_obj)
# ff = self.__built
# breakpoint()
# if zarr_obj.name == '/bazs/baz0':
# breakpoint()
if builder is None:
#
# breakpoint()
# builder = self.__temp_get_built(zarr_obj)
path = list(self.__built.keys())[0]
builder_source_path = path.replace(zarr_obj.path,'')

zarr_obj_path = zarr_obj.path
path = os.path.join(builder_source_path, path)
# breakpoint()
builder = self.__built.get(path, None)
if builder is None:
# breakpoint()
msg = '%s has not been built' % (zarr_obj.name)
raise ValueError(msg)
return builder
Expand All @@ -1064,9 +1106,23 @@ def __get_built(self, zarr_obj):
"""
fpath = zarr_obj.store.path
path = zarr_obj.path
ff = self.__built
# breakpoint()
path = os.path.join(fpath, path)
return self.__built.get(path, None)

def __temp_get_built(self, zarr_obj):
    """Temporary variant of ``__get_built`` used while debugging export.

    Looks up the Builder previously created for ``zarr_obj``, but inserts a
    hard-coded ``'bucket1'`` component between the store path and the
    object's in-file path before using the result as the key into
    ``self.__built``.

    HACK: ``'bucket1'`` appears to be a test-fixture bucket name -- TODO
    derive this prefix from the actual store layout instead of hard-coding.

    :param zarr_obj: Zarr group or array object whose Builder to look up
    :return: the matching Builder if one was recorded, otherwise ``None``
    """
    store_path = zarr_obj.store.path
    # Build the lookup key: <store path>/bucket1/<object path within store>
    lookup_key = os.path.join(store_path, os.path.join('bucket1', zarr_obj.path))
    return self.__built.get(lookup_key, None)


def __read_group(self, zarr_obj, name=None):
ret = self.__get_built(zarr_obj)
if ret is not None:
Expand Down
Loading

0 comments on commit 1603d66

Please sign in to comment.