Skip to content

Commit

Permalink
Merge pull request #35 from NeurodataWithoutBorders/codec
Browse files Browse the repository at this point in the history
Specify zarr codec when creating a dataset
  • Loading branch information
magland authored Apr 19, 2024
2 parents 40c94a1 + 22fc137 commit 3c4562f
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 6 deletions.
46 changes: 44 additions & 2 deletions lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import h5py
import numpy as np
import zarr
import numcodecs
from numcodecs.abc import Codec

from ..LindiH5pyDataset import LindiH5pyDataset
from ..LindiH5pyReference import LindiH5pyReference
Expand All @@ -11,6 +13,8 @@

from ...conversion.create_zarr_dataset_from_h5_data import create_zarr_dataset_from_h5_data

_compression_not_specified_ = object()


class LindiH5pyGroupWriter:
def __init__(self, p: 'LindiH5pyGroup'):
Expand Down Expand Up @@ -39,15 +43,52 @@ def require_group(self, name):
return ret
return self.create_group(name)

def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds):
def create_dataset(
self,
name,
shape=None,
dtype=None,
data=None,
**kwds
):
chunks = None
compression = _compression_not_specified_
compression_opts = None
for k, v in kwds.items():
if k == 'chunks':
chunks = v
elif k == 'compression':
compression = v
elif k == 'compression_opts':
compression_opts = v
else:
raise Exception(f'Unsupported kwds in create_dataset: {k}')

if compression is _compression_not_specified_:
_zarr_compressor = 'default'
if compression_opts is not None:
raise Exception('compression_opts is only supported when compression is provided')
elif isinstance(compression, Codec):
_zarr_compressor = compression
if compression_opts is not None:
raise Exception('compression_opts is not supported when compression is provided as a Codec')
elif isinstance(compression, str):
if compression == 'gzip':
if compression_opts is None:
level = 4 # default for h5py
elif isinstance(compression_opts, int):
level = compression_opts
else:
raise Exception(f'Unexpected type for compression_opts: {type(compression_opts)}')
_zarr_compressor = numcodecs.GZip(level=level)
else:
raise Exception(f'Compression {compression} is not supported')
else:
raise Exception(f'Unexpected type for compression: {type(compression)}')

if isinstance(self.p._group_object, h5py.Group):
if _zarr_compressor != 'default':
raise Exception('zarr_compressor is not supported when _group_object is h5py.Group')
return LindiH5pyDataset(
self._group_object.create_dataset(name, shape=shape, dtype=dtype, data=data, chunks=chunks), # type: ignore
self.p._file
Expand Down Expand Up @@ -77,7 +118,8 @@ def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds):
h5_shape=shape,
h5_dtype=dtype,
h5_data=data,
h5f=None
h5f=None,
zarr_compressor=_zarr_compressor
)
return LindiH5pyDataset(ds, self.p._file)
else:
Expand Down
23 changes: 20 additions & 3 deletions lindi/conversion/create_zarr_dataset_from_h5_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Union, List, Any, Tuple
from typing import Union, List, Any, Tuple, Literal
from dataclasses import dataclass
import numpy as np
import numcodecs
from numcodecs.abc import Codec
import h5py
import zarr
from .h5_ref_to_zarr_attr import h5_ref_to_zarr_attr
Expand All @@ -17,7 +18,8 @@ def create_zarr_dataset_from_h5_data(
h5f: Union[h5py.File, None],
name: str,
label: str,
h5_chunks: Union[Tuple, None]
h5_chunks: Union[Tuple, None],
zarr_compressor: Union[Codec, Literal['default']] = 'default'
):
"""Create a zarr dataset from an h5py dataset.
Expand All @@ -41,6 +43,9 @@ def create_zarr_dataset_from_h5_data(
The name of the h5py dataset for error messages.
h5_chunks : tuple
The chunk shape of the h5py dataset.
zarr_compressor : numcodecs.abc.Codec
The codec compressor to use when writing the dataset. If default, the
default compressor will be used.
"""
if h5_dtype is None:
raise Exception(f'No dtype in h5_to_zarr_dataset_prep for dataset {label}')
Expand All @@ -53,6 +58,9 @@ def create_zarr_dataset_from_h5_data(
if h5_data is None:
raise Exception(f'Data must be provided for scalar dataset {label}')

if zarr_compressor != 'default':
raise Exception('zarr_compressor is not supported for scalar datasets')

if _is_numeric_dtype(h5_dtype) or h5_dtype in [bool, np.bool_]:
# Handle the simple numeric types
ds = zarr_parent_group.create_dataset(
Expand Down Expand Up @@ -118,10 +126,13 @@ def create_zarr_dataset_from_h5_data(
shape=h5_shape,
chunks=h5_chunks,
dtype=h5_dtype,
data=h5_data
data=h5_data,
compressor=zarr_compressor
)
elif h5_dtype.kind == 'O':
# For type object, we are going to use the JSON codec
if zarr_compressor != 'default':
raise Exception('zarr_compressor is not supported for object datasets')
if h5_data is not None:
if isinstance(h5_data, h5py.Dataset):
h5_data = h5_data[:]
Expand All @@ -138,6 +149,8 @@ def create_zarr_dataset_from_h5_data(
object_codec=object_codec
)
elif h5_dtype.kind == 'S': # byte string
if zarr_compressor != 'default':
raise Exception('zarr_compressor is not supported for byte string datasets')
if h5_data is None:
raise Exception(f'Data must be provided when converting dataset {label} with dtype {h5_dtype}')
return zarr_parent_group.create_dataset(
Expand All @@ -148,8 +161,12 @@ def create_zarr_dataset_from_h5_data(
data=h5_data
)
elif h5_dtype.kind == 'U': # unicode string
if zarr_compressor != 'default':
raise Exception('zarr_compressor is not supported for unicode string datasets')
raise Exception(f'Array of unicode strings not supported: dataset {label} with dtype {h5_dtype} and shape {h5_shape}')
elif h5_dtype.kind == 'V' and h5_dtype.fields is not None: # compound dtype
if zarr_compressor != 'default':
raise Exception('zarr_compressor is not supported for compound datasets')
if h5_data is None:
raise Exception(f'Data must be provided when converting compound dataset {label}')
h5_data_1d_view = h5_data.ravel()
Expand Down
25 changes: 24 additions & 1 deletion tests/test_zarr_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import h5py
import lindi
import pytest
from utils import assert_groups_equal
import numcodecs
from utils import assert_groups_equal, arrays_are_equal


def test_zarr_write():
Expand Down Expand Up @@ -42,6 +43,28 @@ def test_require_dataset():
h5f_backed_by_zarr.require_dataset('dset_float32', shape=(3,), dtype=np.float64, exact=True)


def test_zarr_write_with_zstd_compressor():
    """Round-trip a dataset written via the h5py-compatible API with an
    explicit numcodecs compressor (Zstd) through a zarr DirectoryStore.

    Zstd is chosen deliberately: it is not available in plain hdf5 without a
    plugin, so this exercises the codec path that only the zarr backend
    supports.
    """
    expected = np.array([1, 2, 3], dtype=np.float32)
    with tempfile.TemporaryDirectory() as tmpdir:
        dirname = f'{tmpdir}/test.zarr'
        store = zarr.DirectoryStore(dirname)
        zarr.group(store=store)  # initialize an empty root group in the store
        with lindi.LindiH5pyFile.from_zarr_store(store, mode='r+') as h5f_backed_by_zarr:
            h5f_backed_by_zarr.create_dataset(
                'dset_float32',
                data=expected,
                compression=numcodecs.Zstd(),  # this compressor not supported without plugin in hdf5
            )

        # Re-open the store fresh (read-only) and confirm the data decodes
        # back to the original values.
        store2 = zarr.DirectoryStore(dirname)
        with lindi.LindiH5pyFile.from_zarr_store(store2) as h5f_backed_by_zarr:
            dset = h5f_backed_by_zarr['dset_float32']
            assert isinstance(dset, h5py.Dataset)
            actual = dset[()]
            # Idiomatic pytest assertion instead of print + raise; the message
            # carries the same diagnostic information on failure.
            assert arrays_are_equal(actual, expected), \
                f'Data mismatch: {actual} != {expected}'


def write_example_h5_data(h5f: h5py.File):
h5f.attrs['attr_str'] = 'hello'
h5f.attrs['attr_int'] = 42
Expand Down

0 comments on commit 3c4562f

Please sign in to comment.