handle compound dtypes
magland committed Mar 18, 2024
1 parent 8a6ef32 commit 9eeb63c
Showing 4 changed files with 177 additions and 5 deletions.
88 changes: 86 additions & 2 deletions lindi/LindiClient/LindiDataset.py
@@ -1,4 +1,5 @@
-from typing import Dict
+from typing import Dict, Any
import numpy as np
import zarr
import h5py
import remfile
@@ -10,6 +11,17 @@ def __init__(self, *, _zarr_array: zarr.Array):
self._zarr_array = _zarr_array
self._is_scalar = self._zarr_array.attrs.get("_SCALAR", False)

# See if we have the _COMPOUND_DTYPE attribute, which signifies that
# this is a compound dtype
compound_dtype_obj = self._zarr_array.attrs.get("_COMPOUND_DTYPE", None)
if compound_dtype_obj is not None:
# If we have a compound dtype, then create the numpy dtype
self._compound_dtype = np.dtype(
[(compound_dtype_obj[i][0], compound_dtype_obj[i][1]) for i in range(len(compound_dtype_obj))]
)
else:
self._compound_dtype = None

self._external_hdf5_clients: Dict[str, h5py.File] = {}

@property
@@ -35,6 +47,8 @@ def shape(self):

@property
def dtype(self):
if self._compound_dtype is not None:
return self._compound_dtype
return self._zarr_array.dtype

@property
@@ -74,6 +88,28 @@ def __getitem__(self, selection):
dataset = client[name]
assert isinstance(dataset, h5py.Dataset)
return dataset[selection]
if self._compound_dtype is not None:
# Compound dtype
# In this case we index into the compound dtype using the name of the field
# For example, if the dtype is [('x', 'f4'), ('y', 'f4')], then we can do
# dataset['x'][0] to get the first x value
assert self._compound_dtype.names is not None
if isinstance(selection, str):
# Find the index of this field in the compound dtype
ind = self._compound_dtype.names.index(selection)
# Get the dtype of this field
dtype = np.dtype(self._compound_dtype[ind])
# Return a new object that can be sliced further
# It's important that the return type is Any here, because otherwise we get linter problems
ret: Any = LindiDatasetCompoundFieldSelection(
dataset=self, ind=ind, dtype=dtype
)
return ret
else:
raise TypeError(
f"Compound dataset {self.name} does not support selection with {selection}"
)

# We use zarr's slicing, except in the case of a scalar dataset
if self.ndim == 0:
# make sure selection is ()
@@ -85,5 +121,53 @@ def __getitem__(self, selection):
def _get_external_hdf5_client(self, url: str) -> h5py.File:
if url not in self._external_hdf5_clients:
remf = remfile.File(url)
-self._external_hdf5_clients[url] = h5py.File(remf, 'r')
+self._external_hdf5_clients[url] = h5py.File(remf, "r")
return self._external_hdf5_clients[url]


class LindiDatasetCompoundFieldSelection:
"""
This class is returned when a compound dataset is indexed with a field name.
For example, if the dataset has dtype [('x', 'f4'), ('y', 'f4')], then
dataset['x'][0] gets the first x value. Here, dataset['x'] returns an
object of this class.
"""
def __init__(self, *, dataset: LindiDataset, ind: int, dtype: np.dtype):
self._dataset = dataset # The parent dataset
self._ind = ind # The index of the field in the compound dtype
self._dtype = dtype # The dtype of the field
if self._dataset.ndim != 1:
# For now we only support 1D datasets
raise TypeError(
f"Compound field selection only implemented for 1D datasets, not {self._dataset.ndim}D"
)
# Prepare the data in memory
za = self._dataset._zarr_array
d = [za[i][self._ind] for i in range(len(za))]
self._data = np.array(d, dtype=self._dtype)

def __len__(self):
return self._dataset._zarr_array.shape[0]

def __iter__(self):
for i in range(len(self)):
yield self[i]

@property
def ndim(self):
return self._dataset._zarr_array.ndim

@property
def shape(self):
return self._dataset._zarr_array.shape

@property
def dtype(self):
return self._dtype

@property
def size(self):
return self._data.size

def __getitem__(self, selection):
return self._data[selection]
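To make the new access pattern concrete, here is a minimal usage sketch mirroring the test added below. The top-level import path and the file/dataset names are assumptions for illustration:

import numpy as np
import h5py
from lindi import LindiH5Store, LindiClient  # assumed import path

# Create a small HDF5 file containing a 1D compound dataset
dt = np.dtype([("x", "i4"), ("y", "f8")])
with h5py.File("example.h5", "w") as f:
    f.create_dataset("points", data=[(1, 3.14), (2, 6.28)], dtype=dt)

# Convert to a reference file system and read it back through LindiClient
store = LindiH5Store.from_file("example.h5", url="example.h5")
client = LindiClient.from_reference_file_system(store.to_reference_file_system())

points = client["points"]
# Selecting by field name returns a LindiDatasetCompoundFieldSelection,
# which can then be sliced like an ordinary 1D array
print(points["x"][:])  # [1 2]
print(points["y"][0])  # 3.14
store.close()

Note that field selection materializes the entire column in memory (see the LindiDatasetCompoundFieldSelection constructor above), which is reasonable for the small 1D tables this feature targets.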
9 changes: 8 additions & 1 deletion lindi/LindiH5Store/LindiH5Store.py
@@ -209,6 +209,13 @@ def _get_zattrs_bytes(self, parent_key: str):
if isinstance(h5_item, h5py.Dataset):
if h5_item.ndim == 0:
dummy_group.attrs["_SCALAR"] = True
if h5_item.dtype.kind == "V": # compound type
compound_dtype = [
[name, str(h5_item.dtype[name])]
for name in h5_item.dtype.names
]
# For example: [['x', 'uint32'], ['y', 'uint32'], ['weight', 'float32']]
dummy_group.attrs["_COMPOUND_DTYPE"] = compound_dtype
external_array_link = self._get_external_array_link(parent_key, h5_item)
if external_array_link is not None:
dummy_group.attrs["_EXTERNAL_ARRAY_LINK"] = external_array_link
@@ -506,7 +513,7 @@ def _reformat_json(x: Union[bytes, None]) -> Union[bytes, None]:
if x is None:
return None
a = json.loads(x.decode("utf-8"))
-return json.dumps(a, cls=FloatJSONEncoder).encode("utf-8")
+return json.dumps(a, cls=FloatJSONEncoder, separators=(",", ":")).encode("utf-8")


# From https://github.com/rly/h5tojson/blob/b162ff7f61160a48f1dc0026acb09adafdb422fa/h5tojson/h5tojson.py#L121-L156
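For reference, the _COMPOUND_DTYPE attribute written in _get_zattrs_bytes above is a JSON list of [name, dtype-string] pairs, and LindiDataset.__init__ rebuilds a numpy dtype from those pairs on the reading side. A small sketch of that round trip, with an illustrative dtype:

import numpy as np

# Mirrors the conversion in _get_zattrs_bytes above
dt = np.dtype([("x", "u4"), ("y", "u4"), ("weight", "f4")])
compound_dtype = [[name, str(dt[name])] for name in dt.names]
# compound_dtype == [['x', 'uint32'], ['y', 'uint32'], ['weight', 'float32']]
# Stored in the dataset's zarr attributes under the key "_COMPOUND_DTYPE"

# The reading side rebuilds the numpy dtype from the pairs:
rebuilt = np.dtype([(name, type_str) for name, type_str in compound_dtype])
assert rebuilt == dt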
52 changes: 50 additions & 2 deletions lindi/LindiH5Store/_zarr_info_for_h5_dataset.py
@@ -80,7 +80,7 @@ def _zarr_info_for_h5_dataset(h5_dataset: h5py.Dataset) -> ZarrInfoForH5Dataset:
filters=None,
fill_value=' ',
object_codec=numcodecs.JSON(),
-inline_data=json.dumps([value, '|O', [1]]).encode('utf-8')
+inline_data=json.dumps([value, '|O', [1]], separators=(',', ':')).encode('utf-8')
)
else:
raise Exception(f'Not yet implemented (1): object scalar dataset with value {value} and dtype {dtype}')
@@ -124,7 +124,7 @@ def _zarr_info_for_h5_dataset(h5_dataset: h5py.Dataset) -> ZarrInfoForH5Dataset:
data_vec_view[i] = None
else:
raise Exception(f'Cannot handle dataset {h5_dataset.name} with dtype {dtype} and shape {shape}')
-inline_data = json.dumps(data.tolist() + ['|O', list(shape)]).encode('utf-8')
+inline_data = json.dumps(data.tolist() + ['|O', list(shape)], separators=(',', ':')).encode('utf-8')
return ZarrInfoForH5Dataset(
shape=shape,
chunks=shape, # be explicit about chunks
@@ -136,10 +136,58 @@ def _zarr_info_for_h5_dataset(h5_dataset: h5py.Dataset) -> ZarrInfoForH5Dataset:
)
elif dtype.kind in 'SU': # byte string or unicode string
raise Exception(f'Not yet implemented (2): dataset {h5_dataset.name} with dtype {dtype} and shape {shape}')
elif dtype.kind == 'V': # void (i.e. compound)
# This is an array representing the compound type
# For example: [['x', 'uint32'], ['y', 'uint32'], ['weight', 'float32']]
compound_dtype = [
[name, str(dtype[name])]
for name in dtype.names
]
if h5_dataset.ndim == 1:
# for now we only handle the case of a 1D compound dataset
data = h5_dataset[:]
# Create an array that would be for example like this
# [[3, 4, 5.3], [2, 1, 7.1], ...]
# where the first entry corresponds to x in the example above, the second to y, and the third to weight
# This is a more compact representation than a list of records like [{'x': ...}, ...]
# The _COMPOUND_DTYPE attribute will be set on the dataset in the zarr store
# which will be used to interpret the data
array_list = [
[
_json_serialize(data[name][i], type_str)
for name, type_str in compound_dtype
]
for i in range(h5_dataset.shape[0])
]
object_codec = numcodecs.JSON()
inline_data = array_list + ['|O', list(shape)]
return ZarrInfoForH5Dataset(
shape=shape,
chunks=shape, # be explicit about chunks
dtype='object',
filters=None,
fill_value=' ', # not sure what to put here
object_codec=object_codec,
inline_data=json.dumps(inline_data, separators=(',', ':')).encode('utf-8')
)
else:
raise Exception(f'More than one dimension not supported for compound dataset {h5_dataset.name} with dtype {dtype} and shape {shape}')
else:
raise Exception(f'Not yet implemented (3): dataset {h5_dataset.name} with dtype {dtype} and shape {shape}')


def _json_serialize(val: Any, type_str: str) -> Any:
if type_str.startswith('uint'):
return int(val)
elif type_str.startswith('int'):
return int(val)
elif type_str.startswith('float'):
return float(val)
else:
raise Exception(f'Unable to serialize {val} with type {type_str}')


def _get_numeric_format_str(dtype: Any) -> Union[str, None]:
"""Get the format string for a numeric dtype.
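As a sanity check on the inline encoding in the compound ('V') branch above, this sketch reproduces the JSON payload for a tiny compound array by hand. The field names match the example in the comments; a float64 weight is used so the JSON output is exact, and the int()/float() casts stand in for _json_serialize:

import json
import numpy as np

data = np.array([(3, 4, 5.3), (2, 1, 7.1)],
                dtype=[("x", "u4"), ("y", "u4"), ("weight", "f8")])
compound_dtype = [[name, str(data.dtype[name])] for name in data.dtype.names]

# Row-wise list of lists, one inner list per record, as built above
array_list = [
    [int(row["x"]), int(row["y"]), float(row["weight"])]
    for row in data
]
inline_data = json.dumps(array_list + ["|O", [len(data)]],
                         separators=(",", ":")).encode("utf-8")
assert inline_data == b'[[3,4,5.3],[2,1,7.1],"|O",[2]]'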
33 changes: 33 additions & 0 deletions tests/test_core.py
@@ -99,6 +99,39 @@ def test_numpy_array_of_strings():
raise ValueError("Arrays are not equal")


def test_compound_dtype():
print("Testing compound dtype")
with tempfile.TemporaryDirectory() as tmpdir:
filename = f"{tmpdir}/test.h5"
with h5py.File(filename, "w") as f:
dt = np.dtype([("x", "i4"), ("y", "f8")])
f.create_dataset("X", data=[(1, 3.14), (2, 6.28)], dtype=dt)
h5f = h5py.File(filename, "r")
store = LindiH5Store.from_file(filename, url=filename)
rfs = store.to_reference_file_system()
client = LindiClient.from_reference_file_system(rfs)
X1 = h5f["X"]
assert isinstance(X1, h5py.Dataset)
X2 = client["X"]
assert isinstance(X2, LindiDataset)
assert X1.shape == X2.shape
assert X1.dtype == X2.dtype
assert X1.size == X2.size
# assert X1.nbytes == X2.nbytes # nbytes are not going to match because the internal representation is different
assert len(X1) == len(X2)
if not _check_equal(X1['x'][:], X2['x'][:]):
print("WARNING. Arrays for x are not equal")
print(X1['x'][:])
print(X2['x'][:])
raise ValueError("Arrays are not equal")
if not _check_equal(X1['y'][:], X2['y'][:]):
print("WARNING. Arrays for y are not equal")
print(X1['y'][:])
print(X2['y'][:])
raise ValueError("Arrays are not equal")
store.close()


def test_attributes():
print("Testing attributes")
with tempfile.TemporaryDirectory() as tmpdir:
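One behavioral note, following directly from the __getitem__ logic in LindiDataset above: a compound dataset only supports selection by field name, and anything else raises TypeError (multi-dimensional compound datasets are likewise rejected at conversion time). A short illustration, reusing the hypothetical client and dataset from the first sketch:

points = client["points"]
try:
    points[0]  # non-string selection on a compound dataset
except TypeError as e:
    print(e)  # e.g. "Compound dataset ... does not support selection with 0"

# The supported path: select a field first, then index numerically
first_x = points["x"][0]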
