Write scalar datasets with compound data type (#1176)
* add support for scalar compound datasets

* add scalar compound dset io and validation tests

* update CHANGELOG.md

* Update tests/unit/test_io_hdf5_h5tools.py

Co-authored-by: Ryan Ly <[email protected]>

* update container repr conditionals

---------

Co-authored-by: Ryan Ly <[email protected]>
stephprince and rly authored Aug 22, 2024
1 parent 2b167ae commit acc3d78
Showing 6 changed files with 60 additions and 9 deletions.
CHANGELOG.md (3 additions, 0 deletions)
@@ -8,6 +8,9 @@
 - Improved "already exists" error message when adding a container to a `MultiContainerInterface`. @rly [#1165](https://github.com/hdmf-dev/hdmf/pull/1165)
 - Added support to write multidimensional string arrays. @stephprince [#1173](https://github.com/hdmf-dev/hdmf/pull/1173)
 
+### Bug fixes
+- Fixed issue where scalar datasets with a compound data type were being written as non-scalar datasets. @stephprince [#1176](https://github.com/hdmf-dev/hdmf/pull/1176)
+
 ## HDMF 3.14.3 (July 29, 2024)
 
 ### Enhancements
src/hdmf/backends/hdf5/h5tools.py (4 additions, 0 deletions)
@@ -698,6 +698,8 @@ def __read_dataset(self, h5obj, name=None):
                 d = ReferenceBuilder(target_builder)
                 kwargs['data'] = d
                 kwargs['dtype'] = d.dtype
+            elif h5obj.dtype.kind == 'V':  # scalar compound data type
+                kwargs['data'] = np.array(scalar, dtype=h5obj.dtype)
             else:
                 kwargs["data"] = scalar
         else:
@@ -1227,6 +1229,8 @@ def _filler():

                 return
             # If the compound data type contains only regular data (i.e., no references) then we can write it as usual
+            elif len(np.shape(data)) == 0:
+                dset = self.__scalar_fill__(parent, name, data, options)
             else:
                 dset = self.__list_fill__(parent, name, data, options)
         # Write a dataset containing references, i.e., a region or object reference.
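
For context, a minimal standalone sketch (plain h5py and numpy, not HDMF's API; 'scratch.h5' is a made-up file name) of the HDF5 behavior the two new branches rely on: a compound value written from a 0-d array becomes a scalar dataset, which h5py reports with shape () and dtype kind 'V'.

import h5py
import numpy as np

cmpd_dtype = np.dtype([('x', np.int32), ('y', np.float64)])

with h5py.File('scratch.h5', 'w') as f:
    # A 0-d compound array yields a *scalar* dataset -- the case
    # __scalar_fill__ now handles on the write path.
    f.create_dataset('scalar_cmpd', data=np.array((1, 0.1), dtype=cmpd_dtype))

with h5py.File('scratch.h5', 'r') as f:
    dset = f['scalar_cmpd']
    print(dset.shape)       # () -- scalar
    print(dset.dtype.kind)  # 'V' -- the compound kind __read_dataset checks
    print(dset[()])         # (1, 0.1) -- scalars are read with the ()-index
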
src/hdmf/container.py (1 addition, 5 deletions)
@@ -629,12 +629,8 @@ def __repr__(self):
             template += "\nFields:\n"
             for k in sorted(self.fields):  # sorted to enable tests
                 v = self.fields[k]
-                # if isinstance(v, DataIO) or not hasattr(v, '__len__') or len(v) > 0:
                 if hasattr(v, '__len__'):
-                    if isinstance(v, (np.ndarray, list, tuple)):
-                        if len(v) > 0:
-                            template += " {}: {}\n".format(k, self.__smart_str(v, 1))
-                    elif v:
+                    if isinstance(v, (np.ndarray, list, tuple)) or v:
                         template += " {}: {}\n".format(k, self.__smart_str(v, 1))
                 else:
                     template += " {}: {}\n".format(k, v)
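
The short-circuit matters because a 0-d structured array still exposes __len__ as an attribute, but calling len() on it raises, so the old nested len(v) > 0 check crashed the repr for scalar compound fields. A standalone numpy illustration (assumed, not HDMF code):

import numpy as np

v = np.array((1, 0.1), dtype=[('x', np.int32), ('y', np.float64)])  # 0-d array

print(hasattr(v, '__len__'))                     # True -- the method exists on ndarray
print(isinstance(v, (np.ndarray, list, tuple)))  # True -- so `or v` is never evaluated
try:
    len(v)                                       # fails on 0-d arrays
except TypeError as err:
    print(err)                                   # len() of unsized object
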
src/hdmf/validate/validator.py (9 additions, 4 deletions)
@@ -134,7 +134,7 @@ def get_type(data, builder_dtype=None):
     elif isinstance(data, ReferenceResolver):
         return data.dtype, None
     # Numpy nd-array data
-    elif isinstance(data, np.ndarray):
+    elif isinstance(data, np.ndarray) and len(data.dtype) <= 1:
         if data.size > 0:
             return get_type(data[0], builder_dtype)
         else:
@@ -147,11 +147,14 @@
     # Case for h5py.Dataset and other I/O specific array types
     else:
         # Compound dtype
-        if builder_dtype and isinstance(builder_dtype, list):
+        if builder_dtype and len(builder_dtype) > 1:
             dtypes = []
             string_formats = []
             for i in range(len(builder_dtype)):
-                dtype, string_format = get_type(data[0][i])
+                if len(np.shape(data)) == 0:
+                    dtype, string_format = get_type(data[()][i])
+                else:
+                    dtype, string_format = get_type(data[0][i])
                 dtypes.append(dtype)
                 string_formats.append(string_format)
             return dtypes, string_formats
@@ -438,7 +441,9 @@ def validate(self, **kwargs):
         except EmptyArrayError:
             # do not validate dtype of empty array. HDMF does not yet set dtype when writing a list/tuple
             pass
-        if isinstance(builder.dtype, list):
+        if builder.dtype is not None and len(builder.dtype) > 1 and len(np.shape(builder.data)) == 0:
+            shape = ()  # scalar compound dataset
+        elif isinstance(builder.dtype, list):
             shape = (len(builder.data), )  # only 1D datasets with compound types are supported
         else:
             shape = get_data_shape(data)
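
The shape logic above condenses into a standalone numpy sketch (builder_dtype and data are stand-ins for the builder's dtype list and dataset): a scalar compound dataset reports shape (), and its fields are reached through the ()-index rather than data[0].

import numpy as np

data = np.array((1, 2.2), dtype=[('x', 'int32'), ('y', 'float64')])
builder_dtype = ['x', 'y']  # stand-in for a list of compound-field specs

if builder_dtype is not None and len(builder_dtype) > 1 and len(np.shape(data)) == 0:
    shape = ()  # scalar compound dataset, as in the new branch above
else:
    shape = (len(data),)  # 1-D compound dataset

print(shape)                     # ()
print(data[()][0], data[()][1])  # field access on the scalar: 1 2.2
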
tests/unit/test_io_hdf5_h5tools.py (21 additions, 0 deletions)
@@ -144,6 +144,16 @@ def test_write_dataset_string(self):
         read_a = read_a.decode('utf-8')
         self.assertEqual(read_a, a)
 
+    def test_write_dataset_scalar_compound(self):
+        cmpd_dtype = np.dtype([('x', np.int32), ('y', np.float64)])
+        a = np.array((1, 0.1), dtype=cmpd_dtype)
+        self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a,
+                                                     dtype=[DtypeSpec('x', doc='x', dtype='int32'),
+                                                            DtypeSpec('y', doc='y', dtype='float64')]))
+        dset = self.f['test_dataset']
+        self.assertTupleEqual(dset.shape, ())
+        self.assertEqual(dset[()].tolist(), a.tolist())
+
     ##########################################
     # write_dataset tests: TermSetWrapper
     ##########################################
@@ -787,6 +797,17 @@ def test_read_str(self):
         self.assertEqual(str(bldr['test_dataset'].data),
                          '<HDF5 dataset "test_dataset": shape (5,), type "|O">')
 
+    def test_read_scalar_compound(self):
+        cmpd_dtype = np.dtype([('x', np.int32), ('y', np.float64)])
+        a = np.array((1, 0.1), dtype=cmpd_dtype)
+        self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a,
+                                                     dtype=[DtypeSpec('x', doc='x', dtype='int32'),
+                                                            DtypeSpec('y', doc='y', dtype='float64')]))
+        self.io.close()
+        with HDF5IO(self.path, 'r') as io:
+            bldr = io.read_builder()
+            np.testing.assert_array_equal(bldr['test_dataset'].data[()], a)
+

 class TestRoundTrip(TestCase):

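
A note on the assertions in these tests, as a standalone numpy snippet (assumed, not part of the commit): compound scalars compare most conveniently via tolist(), which unpacks the record into a plain tuple, while assert_array_equal also accepts 0-d structured arrays.

import numpy as np

a = np.array((1, 0.1), dtype=[('x', np.int32), ('y', np.float64)])
print(a[()].tolist())                       # (1, 0.1) -- a plain tuple, handy for assertEqual
np.testing.assert_array_equal(a, a.copy())  # also fine for 0-d structured arrays
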
tests/unit/validator_tests/test_validate.py (22 additions, 0 deletions)
@@ -501,6 +501,28 @@ def test_np_bool_for_bool(self):
         results = self.vmap.validate(bar_builder)
         self.assertEqual(len(results), 0)
 
+    def test_scalar_compound_dtype(self):
+        """Test that validator allows scalar compound dtype data where a compound dtype is specified."""
+        spec_catalog = SpecCatalog()
+        dtype = [DtypeSpec('x', doc='x', dtype='int'), DtypeSpec('y', doc='y', dtype='float')]
+        spec = GroupSpec('A test group specification with a data type',
+                         data_type_def='Bar',
+                         datasets=[DatasetSpec('an example dataset', dtype, name='data')],
+                         attributes=[AttributeSpec('attr1', 'an example attribute', 'text')])
+        spec_catalog.register_spec(spec, 'test2.yaml')
+        self.namespace = SpecNamespace(
+            'a test namespace', CORE_NAMESPACE, [{'source': 'test2.yaml'}], version='0.1.0', catalog=spec_catalog)
+        self.vmap = ValidatorMap(self.namespace)
+
+        value = np.array((1, 2.2), dtype=[('x', 'int'), ('y', 'float')])
+        bar_builder = GroupBuilder('my_bar',
+                                   attributes={'data_type': 'Bar', 'attr1': 'test'},
+                                   datasets=[DatasetBuilder(name='data',
+                                                            data=value,
+                                                            dtype=[DtypeSpec('x', doc='x', dtype='int'),
+                                                                   DtypeSpec('y', doc='y', dtype='float')])])
+        results = self.vmap.validate(bar_builder)
+        self.assertEqual(len(results), 0)
+
 class Test1DArrayValidation(TestCase):

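
Before this change, the validator effectively indexed scalar compound data with data[0], which fails on 0-d arrays; the ()-index is the correct scalar access. A standalone numpy sketch (assumed, mirroring the test value above):

import numpy as np

data = np.array((1, 2.2), dtype=[('x', 'int'), ('y', 'float')])
try:
    data[0]      # roughly what the old get_type path did for compound data
except IndexError as err:
    print(err)   # too many indices for array: array is 0-dimensional...
print(data[()])  # (1, 2.2) -- the scalar record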
