Add tests and change zarr array handling
rly committed Aug 20, 2024
1 parent 2b90d21 commit bb3de8b
Showing 5 changed files with 194 additions and 14 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -8,7 +8,7 @@
- Improved "already exists" error message when adding a container to a `MultiContainerInterface`. @rly [#1165](https://github.com/hdmf-dev/hdmf/pull/1165)

### Bug fixes
- Fixed bug when converting string datasets from Zarr to HDF5. @oruebel [#1171](https://github.com/hdmf-dev/hdmf/pull/1171)
- Fixed bug when converting string datasets from Zarr to HDF5. @oruebel @rly [#1171](https://github.com/hdmf-dev/hdmf/pull/1171)

## HDMF 3.14.3 (July 29, 2024)

22 changes: 18 additions & 4 deletions src/hdmf/backends/hdf5/h5tools.py
@@ -22,6 +22,13 @@
from ...utils import docval, getargs, popargs, get_data_shape, get_docval, StrDataset
from ..utils import NamespaceToBuilderHelper, WriteStatusTracker

# try:
# from zarr import Array as ZarrArray
# import numcodecs
# ZARR_INSTALLED = True
# except ImportError:
# ZARR_INSTALLED = False

ROOT_NAME = 'root'
SPEC_LOC_ATTR = '.specloc'
H5_TEXT = special_dtype(vlen=str)
@@ -34,6 +41,10 @@
H5PY_3 = h5py.__version__.startswith('3')


# def _is_zarr_array(value):
# return ZARR_INSTALLED and isinstance(value, ZarrArray)


class HDF5IO(HDMFIO):

__ns_spec_path = 'namespace' # path to the namespace dataset within a namespace group
@@ -924,10 +935,13 @@ def __resolve_dtype__(cls, dtype, data):
# binary
# number

# Use text dtype for Zarr datasets of strings. Zarr stores variable length strings
# as objects, so we need to detect this special case here
if hasattr(data, 'attrs') and 'zarr_dtype' in data.attrs and data.attrs['zarr_dtype'] == 'str':
return cls.__dtypes['text']
# # Use text dtype for Zarr datasets of strings. Zarr stores variable length strings
# # as objects, so we need to detect this special case here
# if _is_zarr_array(data) and data.filters:
# if numcodecs.VLenUTF8() in data.filters:
# return cls.__dtypes['text']
# elif numcodecs.VLenBytes() in data.filters:
# return cls.__dtypes['ascii']

dtype = cls.__resolve_dtype_helper__(dtype)
if dtype is None:
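For context, a minimal sketch of what the commented-out filter check above would detect, assuming zarr 2.x and numcodecs are installed: variable-length Zarr strings are stored with dtype=object, and the object codec shows up in the array's filters.

import numcodecs
import zarr

# Variable-length unicode: stored as objects, with VLenUTF8 as the object codec
arr = zarr.array(['a', 'b'], dtype=object, object_codec=numcodecs.VLenUTF8())
assert arr.dtype == object
assert numcodecs.VLenUTF8() in arr.filters   # would resolve to the 'text' dtype

# Variable-length bytes: VLenBytes appears in filters instead
arr = zarr.array([b'a', b'b'], dtype=object, object_codec=numcodecs.VLenBytes())
assert numcodecs.VLenBytes() in arr.filters  # would resolve to the 'ascii' dtype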
20 changes: 16 additions & 4 deletions src/hdmf/build/objectmapper.py
@@ -20,6 +20,16 @@
from ..spec.spec import BaseStorageSpec
from ..utils import docval, getargs, ExtenderMeta, get_docval, get_data_shape

try:
from zarr import Array as ZarrArray
ZARR_INSTALLED = True

except ImportError:
ZARR_INSTALLED = False


def _is_zarr_array(value):
return ZARR_INSTALLED and isinstance(value, ZarrArray)

_const_arg = '__constructor_arg'


@@ -206,17 +216,19 @@ def convert_dtype(cls, spec, value, spec_dtype=None):  # noqa: C901
spec_dtype_type = cls.__dtypes[spec_dtype]
warning_msg = None
# Numpy Array or Zarr array
if (isinstance(value, np.ndarray) or
(hasattr(value, 'astype') and hasattr(value, 'dtype'))):
# NOTE: Numpy < 2.0 has only fixed-length strings.
# Numpy 2.0 introduces variable-length strings (dtype=np.dtypes.StringDType()).
# HDMF does not yet do any special handling of numpy arrays with variable-length strings.
if isinstance(value, np.ndarray) or _is_zarr_array(value):
if spec_dtype_type is _unicode:
if hasattr(value, 'attrs') and 'zarr_dtype' in value.attrs:
if _is_zarr_array(value):
# Zarr stores strings as objects, so we cannot convert to unicode dtype
ret = value
else:
ret = value.astype('U')
ret_dtype = "utf8"
elif spec_dtype_type is _ascii:
if hasattr(value, 'attrs') and 'zarr_dtype' in value.attrs:
if _is_zarr_array(value):
# Zarr stores strings as objects, so we cannot convert to ascii (bytes) dtype
ret = value
else:
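For reference, a minimal usage sketch of the behavior the new _is_zarr_array branch implements (assuming zarr is installed): a Zarr string array passes through convert_dtype unchanged, while a NumPy array is cast. This mirrors the new tests below.

import numpy as np
import zarr
from hdmf.build import ObjectMapper
from hdmf.spec import DatasetSpec

spec = DatasetSpec('an example dataset', 'text', name='data')

np_value = np.array(['a', 'b'])                  # fixed-length unicode (<U1)
ret, ret_dtype = ObjectMapper.convert_dtype(spec, np_value)
# NumPy arrays are cast with astype('U'); ret_dtype is 'utf8'

zarr_value = zarr.array(['a', 'b'])
ret, ret_dtype = ObjectMapper.convert_dtype(spec, zarr_value)
# Zarr arrays cannot be cast in place, so ret is the original array;
# ret_dtype is still 'utf8'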
50 changes: 49 additions & 1 deletion tests/unit/build_tests/test_convert_dtype.py
@@ -1,12 +1,20 @@
from datetime import datetime, date

import numpy as np
import unittest

from hdmf.backends.hdf5 import H5DataIO
from hdmf.build import ObjectMapper
from hdmf.data_utils import DataChunkIterator
from hdmf.spec import DatasetSpec, RefSpec, DtypeSpec
from hdmf.testing import TestCase

try:
import zarr
import numcodecs
SKIP_ZARR_TESTS = False
except ImportError:
SKIP_ZARR_TESTS = True


class TestConvertDtype(TestCase):

@@ -551,3 +559,43 @@ def test_isodate_spec(self):
self.assertEqual(ret, b'2020-11-10')
self.assertIs(type(ret), bytes)
self.assertEqual(ret_dtype, 'ascii')

@unittest.skipIf(SKIP_ZARR_TESTS, "Zarr is not installed")
def test_zarr_array_spec_vlen_utf8(self):
"""Test that converting a zarr array with utf8 dtype for a variable length utf8 dtype spec
returns the same object with a utf8 ret_dtype."""
spec = DatasetSpec('an example dataset', 'text', name='data')

value = zarr.array(['a', 'b']) # fixed length unicode (dtype = <U1)
ret, ret_dtype = ObjectMapper.convert_dtype(spec, value)
self.assertEqual(ret, value)
self.assertIs(type(ret), zarr.Array)
self.assertIs(ret.dtype.type, np.str_)
self.assertEqual(ret_dtype, 'utf8')

value = zarr.array(['a', 'b'], dtype=object, object_codec=numcodecs.VLenUTF8()) # variable length unicode
ret, ret_dtype = ObjectMapper.convert_dtype(spec, value)
self.assertEqual(ret, value)
self.assertIs(type(ret), zarr.Array)
self.assertIs(ret.dtype.type, np.object_)
self.assertEqual(ret_dtype, 'utf8')

@unittest.skipIf(SKIP_ZARR_TESTS, "Zarr is not installed")
def test_zarr_array_spec_vlen_ascii(self):
"""Test that converting a zarr array with fixed length utf8 dtype for a variable length ascii dtype spec
returns the same object with a ascii ret_dtype."""
spec = DatasetSpec('an example dataset', 'ascii', name='data')

value = zarr.array(['a', 'b']) # fixed length unicode (dtype = <U1)
ret, ret_dtype = ObjectMapper.convert_dtype(spec, value)
self.assertEqual(ret, value)
self.assertIs(type(ret), zarr.Array)
self.assertIs(ret.dtype.type, np.str_) # the zarr array is not converted
self.assertEqual(ret_dtype, 'ascii') # the dtype of the builder will be ascii

value = zarr.array(['a', 'b'], dtype=object, object_codec=numcodecs.VLenUTF8()) # variable length unicode
ret, ret_dtype = ObjectMapper.convert_dtype(spec, value)
self.assertEqual(ret, value)
self.assertIs(type(ret), zarr.Array)
self.assertIs(ret.dtype.type, np.object_) # the zarr array is not converted
self.assertEqual(ret_dtype, 'ascii') # the dtype of the builder will be ascii
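The two Zarr constructions used throughout these tests differ only in how the strings are stored; a minimal sketch of the distinction, assuming zarr and numcodecs are installed:

import numpy as np
import numcodecs
import zarr

fixed = zarr.array(['a', 'b'])   # NumPy-style fixed-length unicode
assert fixed.dtype == np.dtype('<U1')

variable = zarr.array(['a', 'b'], dtype=object, object_codec=numcodecs.VLenUTF8())
assert variable.dtype == object  # strings stored as Python objects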
114 changes: 110 additions & 4 deletions tests/unit/test_io_hdf5_h5tools.py
@@ -37,6 +37,7 @@

try:
import zarr
import numcodecs
SKIP_ZARR_TESTS = False
except ImportError:
SKIP_ZARR_TESTS = True
@@ -3538,16 +3539,121 @@ def test_write_zarr_float32_dataset(self):
self.assertListEqual(dset[:].tolist(),
base_data.tolist())

def test_write_zarr_string_dataset(self):
def test_write_zarr_flen_utf8_dataset(self):
# fixed length unicode zarr array
base_data = np.array(['string1', 'string2'], dtype=str)
zarr.save(self.zarr_path, base_data)
zarr_data = zarr.open(self.zarr_path, 'r')
io = HDF5IO(self.path, mode='a')
f = io._file
io.write_dataset(f, DatasetBuilder('test_dataset', zarr_data, attributes={}))
dset = f['test_dataset']

io.write_dataset(f, DatasetBuilder('test_dataset1', zarr_data)) # no dtype specified
dset = f['test_dataset1']
self.assertIs(dset.dtype.type, np.object_)
self.assertEqual(h5py.check_dtype(vlen=dset.dtype), str) # check that the dtype is str
self.assertTupleEqual(dset.shape, (2,))
np.testing.assert_array_equal(dset[:].astype(str), base_data)

io.write_dataset(f, DatasetBuilder('test_dataset2', zarr_data, dtype="utf8")) # utf8 dtype specified
dset = f['test_dataset2']
self.assertIs(dset.dtype.type, np.object_)
self.assertEqual(h5py.check_dtype(vlen=dset.dtype), str) # check that the dtype is str
self.assertTupleEqual(dset.shape, (2,))
np.testing.assert_array_equal(dset[:].astype(str), zarr_data[:])

io.write_dataset(f, DatasetBuilder('test_dataset3', zarr_data, dtype="ascii")) # ascii dtype specified
dset = f['test_dataset3']
self.assertIs(dset.dtype.type, np.object_)
self.assertEqual(h5py.check_dtype(vlen=dset.dtype), bytes) # check that the dtype is bytes
self.assertTupleEqual(dset.shape, (2,))
np.testing.assert_array_equal(dset[:].astype(str), zarr_data[:])

def test_write_zarr_flen_ascii_dataset(self):
# fixed length ascii zarr array
base_data = np.array(['string1', 'string2'], dtype=bytes)
zarr.save(self.zarr_path, base_data)
zarr_data = zarr.open(self.zarr_path, 'r')
io = HDF5IO(self.path, mode='a')
f = io._file

io.write_dataset(f, DatasetBuilder('test_dataset1', zarr_data)) # no dtype specified
dset = f['test_dataset1']
self.assertIs(dset.dtype.type, np.object_)
self.assertEqual(h5py.check_dtype(vlen=dset.dtype), bytes) # check that the dtype is bytes
self.assertTupleEqual(dset.shape, (2,))
np.testing.assert_array_equal(dset[:].astype(bytes), base_data)

io.write_dataset(f, DatasetBuilder('test_dataset2', zarr_data, dtype="utf8")) # utf8 dtype specified
dset = f['test_dataset2']
self.assertIs(dset.dtype.type, np.object_)
self.assertEqual(h5py.check_dtype(vlen=dset.dtype), str) # check that the dtype is str
self.assertTupleEqual(dset.shape, (2,))
np.testing.assert_array_equal(dset[:].astype(bytes), zarr_data[:])

io.write_dataset(f, DatasetBuilder('test_dataset3', zarr_data, dtype="ascii")) # ascii dtype specified
dset = f['test_dataset3']
self.assertIs(dset.dtype.type, np.object_)
self.assertEqual(h5py.check_dtype(vlen=dset.dtype), bytes) # check that the dtype is bytes
self.assertTupleEqual(dset.shape, (2,))
np.testing.assert_array_equal(dset[:].astype(bytes), zarr_data[:])

def test_write_zarr_vlen_utf8_dataset(self):
# variable length unicode zarr array
base_data = np.array(['string1', 'string2'], dtype=str)
zarr_data = zarr.open(self.zarr_path, shape=(2,), dtype=object, object_codec=numcodecs.VLenUTF8())
zarr_data[:] = base_data
io = HDF5IO(self.path, mode='a')
f = io._file

io.write_dataset(f, DatasetBuilder('test_dataset1', zarr_data)) # no dtype specified
dset = f['test_dataset1']
self.assertIs(dset.dtype.type, np.object_)
self.assertEqual(h5py.check_dtype(vlen=dset.dtype), str) # check that the dtype is str
self.assertTupleEqual(dset.shape, (2,))
np.testing.assert_array_equal(dset[:].astype(str), base_data)

io.write_dataset(f, DatasetBuilder('test_dataset2', zarr_data, dtype="utf8")) # utf8 dtype specified
dset = f['test_dataset2']
self.assertIs(dset.dtype.type, np.object_)
self.assertEqual(h5py.check_dtype(vlen=dset.dtype), str) # check that the dtype is str
self.assertTupleEqual(dset.shape, (2,))
np.testing.assert_array_equal(dset[:].astype(str), zarr_data[:])

io.write_dataset(f, DatasetBuilder('test_dataset3', zarr_data, dtype="ascii")) # ascii dtype specified
dset = f['test_dataset3']
self.assertIs(dset.dtype.type, np.object_)
self.assertEqual(h5py.check_dtype(vlen=dset.dtype), bytes) # check that the dtype is bytes
self.assertTupleEqual(dset.shape, (2,))
np.testing.assert_array_equal(dset[:].astype(str), zarr_data[:])

def test_write_zarr_vlen_ascii_dataset(self):
# variable length ascii zarr array
base_data = np.array(['string1', 'string2'], dtype=bytes)
zarr_data = zarr.open(self.zarr_path, shape=(2,), dtype=object, object_codec=numcodecs.VLenBytes())
zarr_data[:] = base_data
io = HDF5IO(self.path, mode='a')
f = io._file

io.write_dataset(f, DatasetBuilder('test_dataset1', zarr_data)) # no dtype specified
dset = f['test_dataset1']
self.assertIs(dset.dtype.type, np.object_)
self.assertEqual(h5py.check_dtype(vlen=dset.dtype), bytes) # check that the dtype is bytes
self.assertTupleEqual(dset.shape, (2,))
np.testing.assert_array_equal(dset[:].astype(bytes), base_data)

io.write_dataset(f, DatasetBuilder('test_dataset2', zarr_data, dtype="utf8")) # utf8 dtype specified
dset = f['test_dataset2']
self.assertIs(dset.dtype.type, np.object_)
self.assertEqual(h5py.check_dtype(vlen=dset.dtype), str) # check that the dtype is str
self.assertTupleEqual(dset.shape, (2,))
np.testing.assert_array_equal(dset[:].astype(bytes), zarr_data[:])

io.write_dataset(f, DatasetBuilder('test_dataset3', zarr_data, dtype="ascii")) # ascii dtype specified
dset = f['test_dataset3']
self.assertIs(dset.dtype.type, np.object_)
self.assertEqual(h5py.check_dtype(vlen=dset.dtype), bytes) # check that the dtype is bytes
self.assertTupleEqual(dset.shape, (2,))
self.assertListEqual(dset[:].astype(bytes).tolist(), base_data.astype(bytes).tolist())
np.testing.assert_array_equal(dset[:].astype(bytes), zarr_data[:])

def test_write_zarr_dataset_compress_gzip(self):
base_data = np.arange(50).reshape(5, 10).astype('float32')
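The assertions in these tests rely on h5py's variable-length string dtypes (H5_TEXT in h5tools.py is defined the same way; a vlen=bytes dtype gives the ascii counterpart). A minimal standalone sketch, using a hypothetical scratch file:

import h5py
import numpy as np

with h5py.File('scratch.h5', 'w') as f:           # hypothetical scratch file
    utf8_dtype = h5py.special_dtype(vlen=str)     # equivalent to H5_TEXT
    data = np.array(['string1', 'string2'], dtype=object)
    dset = f.create_dataset('strings', data=data, dtype=utf8_dtype)
    assert dset.dtype.type is np.object_          # vlen strings read back as objects
    assert h5py.check_dtype(vlen=dset.dtype) is str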
