Write references in compound datasets at the end #149

Closed · wants to merge 2 commits
59 changes: 51 additions & 8 deletions src/hdmf_zarr/backend.py
@@ -1011,18 +1011,59 @@ def write_dataset(self, **kwargs):  # noqa: C901
                type_str.append(self.__serial_dtype__(t)[0])

            if len(refs) > 0:
                self._written_builders.set_written(builder)  # record that the builder has been written
                # Resolve all object references up front so the compound dataset
                # can be cast and written in a single pass at the end
                new_items = []
                for item in data:
                    new_item = list(item)
                    for i in refs:
                        new_item[i] = self.__get_ref(item[i], export_source=export_source)
                    new_items.append(tuple(new_item))

                # Create dtype for storage, replacing values to match hdmf's hdf5 behavior
                # ---
                # TODO: Replace with a simple one-liner once __resolve_dtype_helper__ is
                #       compatible with zarr's need for fixed-length string dtypes.
                # dtype = self.__resolve_dtype_helper__(options['dtype'])
                new_dtype = []
                for field in options['dtype']:
                    if field['dtype'] is str or field['dtype'] in ('str', 'text', 'utf', 'utf8', 'utf-8', 'isodatetime'):
                        new_dtype.append((field['name'], 'U'))
                    elif isinstance(field['dtype'], dict):
                        # e.g., for some references, dtype will be of the form
                        # {'target_type': 'Baz', 'reftype': 'object'}
                        # which should just get serialized as an object
                        new_dtype.append((field['name'], 'O'))
                    else:
                        new_dtype.append((field['name'], self.__resolve_dtype_helper__(field['dtype'])))
                dtype = np.dtype(new_dtype)

                # cast and store the compound dataset
                arr = np.array(new_items, dtype=dtype)
                dset = parent.require_dataset(
                    name,
                    shape=(len(arr),),
                    dtype=dtype,
                    object_codec=self.__codec_cls(),
                    **options['io_settings']
                )
                dset.attrs['zarr_dtype'] = type_str
                dset[...] = arr
            else:
                # write a compound datatype
                dset = self.__list_fill__(parent, name, data, options)
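The key change in the hunk above is that references are resolved first and the whole compound array is then written once with a fixed dtype. Below is a minimal, self-contained sketch of that dtype-mapping step using plain numpy; the spec, row values, and computed string width are invented for illustration (and sidestep the variable-length 'U' issue flagged in the TODO):

import numpy as np

# hypothetical compound spec in the shape handled above
spec = [
    {'name': 'id', 'dtype': 'int32'},
    {'name': 'name', 'dtype': 'text'},
    {'name': 'reference', 'dtype': {'target_type': 'Baz', 'reftype': 'object'}},
]
rows = [(0, 'row0', {'path': '/baz0'}),
        (1, 'row1', {'path': '/baz1'})]

fields = []
for field in spec:
    if field['dtype'] in ('str', 'text', 'utf', 'utf8', 'utf-8', 'isodatetime'):
        # size the fixed-length string from the data (text is position 1 in this toy)
        width = max(len(r[1]) for r in rows)
        fields.append((field['name'], f'U{width}'))
    elif isinstance(field['dtype'], dict):
        fields.append((field['name'], 'O'))  # references serialized as objects
    else:
        fields.append((field['name'], field['dtype']))

arr = np.array(rows, dtype=np.dtype(fields))
print(arr.dtype)    # [('id', '<i4'), ('name', '<U4'), ('reference', 'O')]
print(arr['name'])  # ['row0' 'row1']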
@@ -1147,8 +1188,10 @@ def __resolve_dtype_helper__(cls, dtype):
            return cls.__dtypes.get(dtype)
        elif isinstance(dtype, dict):
            return cls.__dtypes.get(dtype['reftype'])
        elif isinstance(dtype, list):
            return np.dtype([(x['name'], cls.__resolve_dtype_helper__(x['dtype'])) for x in dtype])
        else:
            raise ValueError(f"Can't resolve dtype {dtype}")

    @classmethod
    def get_type(cls, data):
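As a rough standalone illustration of the new list branch in __resolve_dtype_helper__ (the mapping table and function name below are simplified stand-ins, not the real private API):

import numpy as np

_dtypes = {'int32': np.int32, 'float64': np.float64, 'object': np.object_}

def resolve(dtype):
    if isinstance(dtype, str):
        return _dtypes[dtype]
    if isinstance(dtype, dict):  # reference spec, e.g. {'reftype': 'object', ...}
        return _dtypes[dtype['reftype']]
    if isinstance(dtype, list):  # compound spec: resolve each field recursively
        return np.dtype([(x['name'], resolve(x['dtype'])) for x in dtype])
    raise ValueError(f"Can't resolve dtype {dtype}")

print(resolve([{'name': 'id', 'dtype': 'int32'},
               {'name': 'ref', 'dtype': {'reftype': 'object'}}]))
# dtype([('id', '<i4'), ('ref', 'O')])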
3 changes: 2 additions & 1 deletion src/hdmf_zarr/zarr_utils.py
@@ -151,7 +151,8 @@ def dtype(self):
        return self.__dtype

    def __getitem__(self, arg):
        rows = list(copy(super().__getitem__(arg)))
Comment on lines +154 to +155
Contributor
If I understand what you are saying correctly, the issue appears when self.__swap_refs is called, because the individual rows are now represented as tuples. I.e., rows is either a tuple (if arg is an int) or a list of tuples (if arg selects multiple rows). If that is the issue, then I think we could update self.__swap_refs to return a new instance of the row (rather than updating the existing one). Here is what I think might work:

def __getitem__(self, arg):
    rows = copy(super().__getitem__(arg))
    if np.issubdtype(type(arg), np.integer):
        rows = self.__swap_refs(rows)
    else:
        rows = [self.__swap_refs(row) for row in rows]
    return rows

def __swap_refs(self, row):
    updated_row = list(row)  # convert tuple to a list so we can update it
    for i in self.__refgetters:
        getref = self.__refgetters[i]
        updated_row[i] = getref(row[i])
    return updated_row  # TODO: if we want to keep these as tuples, we could convert them back here

        if np.issubdtype(type(arg), np.integer):
            self.__swap_refs(rows)
        else:
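To make the failure mode concrete, here is a standalone illustration (values invented) of why the suggested __swap_refs copies each row into a list before updating it:

# tuples are immutable, so an in-place update fails:
row = (0, 'row0', '<unresolved reference>')
try:
    row[2] = '<resolved reference>'
except TypeError as err:
    print(err)  # 'tuple' object does not support item assignment

# converting to a list first, as in __swap_refs above, works:
updated_row = list(row)
updated_row[2] = '<resolved reference>'
print(tuple(updated_row))  # (0, 'row0', '<resolved reference>')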
19 changes: 12 additions & 7 deletions tests/unit/base_tests_zarrio.py
@@ -426,20 +426,25 @@ def test_read_reference(self):
        builder = self.createReferenceBuilder()['ref_dataset']
        read_builder = self.root['ref_dataset']
        # Load the linked arrays and confirm we get the same data as we had in the original builder
        for i, v in enumerate(read_builder['data']):
            self.assertTrue(np.all(builder['data'][i]['builder']['data'] == v['data'][:]))

    def test_read_reference_compound(self):
        self.test_write_reference_compound()
        self.read()
        builder = self.createReferenceCompoundBuilder()['ref_dataset']
        read_builder = self.root['ref_dataset']

        # ensure the array was written as a compound array
        ref_dtype = np.dtype([('id', '<i4'), ('name', '<U'), ('reference', 'O')])
        self.assertEqual(read_builder.data.dataset.dtype, ref_dtype)
        # Load the elements of each entry in the compound dataset and compare the index, string, and referenced array
        for i, v in enumerate(read_builder['data']):
            self.assertEqual(v[0], builder['data'][i][0])  # Compare index value from compound tuple
            self.assertEqual(v[1], builder['data'][i][1])  # Compare string value from compound tuple
            self.assertTrue(np.all(v[2]['data'][:] == builder['data'][i][2]['builder']['data'][:]))  # Compare ref array

    def test_read_reference_compound_buf(self):
        data_1 = np.arange(100, 200, 10).reshape(2, 5)
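As a side note on the dtype asserted in test_read_reference_compound, a quick standalone check with plain numpy (illustrative only; '<U' is a zero-width unicode type, which ties back to the fixed-length-string TODO in backend.py):

import numpy as np

ref_dtype = np.dtype([('id', '<i4'), ('name', '<U'), ('reference', 'O')])
print(ref_dtype.names)         # ('id', 'name', 'reference')
print(ref_dtype['reference'])  # object
print(ref_dtype['name'])       # <U0 (zero-width; see the TODO in backend.py)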