diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index bf85664..c130856 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -1,6 +1,9 @@ name: integration tests -on: [push, pull_request] +on: + push: + pull_request: + workflow_dispatch: jobs: build: diff --git a/examples/example1.py b/examples/example1.py index f5c1b6b..cb5fe6f 100644 --- a/examples/example1.py +++ b/examples/example1.py @@ -22,3 +22,9 @@ with pynwb.NWBHDF5IO(file=client, mode="r") as io: nwbfile = io.read() print(nwbfile) + + print('Electrode group at shank0:') + print(nwbfile.electrode_groups["shank0"]) # type: ignore + + print('Electrode group at index 0:') + print(nwbfile.electrodes.group[0]) # type: ignore diff --git a/lindi/LindiH5ZarrStore/LindiH5ZarrStore.py b/lindi/LindiH5ZarrStore/LindiH5ZarrStore.py index d1d1a20..0ca71eb 100644 --- a/lindi/LindiH5ZarrStore/LindiH5ZarrStore.py +++ b/lindi/LindiH5ZarrStore/LindiH5ZarrStore.py @@ -268,7 +268,8 @@ def _get_zgroup_bytes(self, parent_key: str): raise Exception("Store is closed") h5_item = self._h5f.get('/' + parent_key, None) if not isinstance(h5_item, h5py.Group): - raise Exception(f"Item {parent_key} is not a group") + # Important to raise a KeyError here because that's what zarr expects + raise KeyError(f"Item {parent_key} is not a group") # We create a dummy zarr group and then get the .zgroup JSON text # from it. memory_store = MemoryStore() @@ -287,7 +288,8 @@ def _get_zarray_bytes(self, parent_key: str): raise Exception("Store is closed") h5_item = self._h5f.get('/' + parent_key, None) if not isinstance(h5_item, h5py.Dataset): - raise Exception(f"Item {parent_key} is not a dataset") + # Important to raise a KeyError here because that's what zarr expects + raise KeyError(f"Item {parent_key} is not a dataset") # get the shape, chunks, dtype, and filters from the h5 dataset inline_array = self._get_inline_array(parent_key, h5_item) if inline_array.is_inline: diff --git a/lindi/LindiH5pyFile/LindiH5pyDataset.py b/lindi/LindiH5pyFile/LindiH5pyDataset.py index 598279a..83c540a 100644 --- a/lindi/LindiH5pyFile/LindiH5pyDataset.py +++ b/lindi/LindiH5pyFile/LindiH5pyDataset.py @@ -14,11 +14,6 @@ from .LindiH5pyFile import LindiH5pyFile # pragma: no cover -class LindiH5pyDatasetId: - def __init__(self, _h5py_dataset_id): - self._h5py_dataset_id = _h5py_dataset_id - - # This is a global list of external hdf5 clients, which are used by # possibly multiple LindiH5pyFile objects. The key is the URL of the # external hdf5 file, and the value is the h5py.File object. @@ -32,6 +27,9 @@ def __init__(self, _dataset_object: Union[h5py.Dataset, zarr.Array], _file: "Lin self._file = _file self._readonly = _file.mode not in ['r+'] + # see comment in LindiH5pyGroup + self._id = f'{id(self._file)}/{self._dataset_object.name}' + # See if we have the _COMPOUND_DTYPE attribute, which signifies that # this is a compound dtype if isinstance(_dataset_object, zarr.Array): @@ -74,10 +72,8 @@ def __init__(self, _dataset_object: Union[h5py.Dataset, zarr.Array], _file: "Lin @property def id(self): - if isinstance(self._dataset_object, h5py.Dataset): - return LindiH5pyDatasetId(self._dataset_object.id) - else: - return LindiH5pyDatasetId(None) + # see comment in LindiH5pyGroup + return self._id @property def shape(self): # type: ignore diff --git a/lindi/LindiH5pyFile/LindiH5pyFile.py b/lindi/LindiH5pyFile/LindiH5pyFile.py index ab6f775..d107535 100644 --- a/lindi/LindiH5pyFile/LindiH5pyFile.py +++ b/lindi/LindiH5pyFile/LindiH5pyFile.py @@ -26,6 +26,9 @@ def __init__(self, _file_object: Union[h5py.File, zarr.Group], *, _zarr_store: U self._mode: Literal['r', 'r+'] = _mode self._the_group = LindiH5pyGroup(_file_object, self) + # see comment in LindiH5pyGroup + self._id = f'{id(self._file_object)}/' + @staticmethod def from_reference_file_system(rfs: Union[dict, str], mode: Literal["r", "r+"] = "r"): """ @@ -299,7 +302,8 @@ def __contains__(self, name): @property def id(self): - return self._the_group.id + # see comment in LindiH5pyGroup + return self._id @property def file(self): diff --git a/lindi/LindiH5pyFile/LindiH5pyGroup.py b/lindi/LindiH5pyFile/LindiH5pyGroup.py index 02020a2..c9acdaf 100644 --- a/lindi/LindiH5pyFile/LindiH5pyGroup.py +++ b/lindi/LindiH5pyFile/LindiH5pyGroup.py @@ -12,17 +12,23 @@ from .LindiH5pyFile import LindiH5pyFile # pragma: no cover -class LindiH5pyGroupId: - def __init__(self, _h5py_group_id): - self._h5py_group_id = _h5py_group_id - - class LindiH5pyGroup(h5py.Group): def __init__(self, _group_object: Union[h5py.Group, zarr.Group], _file: "LindiH5pyFile"): self._group_object = _group_object self._file = _file self._readonly = _file.mode not in ['r+'] + # In h5py, the id property is an object that exposes low-level + # operations specific to the HDF5 library. LINDI aims to override the + # high-level methods such that the low-level operations on id are not + # needed. However, sometimes packages (e.g., pynwb) use the id as a + # unique identifier for purposes of caching. Therefore, we make the id + # to be a string that is unique for each object. If any of the low-level + # operations are attempted on this id string, then an exception will be + # raised, which will usually indicate that one of the high-level methods + # should be overridden. + self._id = f'{id(self._file)}/{self._group_object.name}' + # The self._write object handles all the writing operations from .writers.LindiH5pyGroupWriter import LindiH5pyGroupWriter # avoid circular import if self._readonly: @@ -132,16 +138,8 @@ def __repr__(self): @property def id(self): - if isinstance(self._group_object, h5py.Group): - return LindiH5pyGroupId(self._group_object.id) - elif isinstance(self._group_object, zarr.Group): - # This is commented out for now because pynwb gets the id of a group - # in at least one place. But that could be avoided in the future, at - # which time, we could uncomment this. - # print('WARNING: Accessing low-level id of LindiH5pyGroup. This should be avoided.') - return LindiH5pyGroupId('') - else: - raise Exception(f'Unexpected group object type: {type(self._group_object)}') + # see comment above + return self._id @property def file(self): diff --git a/lindi/conversion/attr_conversion.py b/lindi/conversion/attr_conversion.py index 15416fa..a3f0a19 100644 --- a/lindi/conversion/attr_conversion.py +++ b/lindi/conversion/attr_conversion.py @@ -23,9 +23,9 @@ def h5_to_zarr_attr(attr: Any, *, label: str = '', h5f: Union[h5py.File, None]): raise Exception(f"Unexpected h5 attribute: None at {label}") elif type(attr) in [int, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64]: return int(attr) - elif type(attr) in [float, np.float16, np.float32, np.float64]: + elif isinstance(attr, (float, np.floating)): return encode_nan_inf_ninf(float(attr)) - elif isinstance(attr, complex) or (isinstance(attr, np.ndarray) and np.issubdtype(attr.dtype, np.complexfloating)): + elif isinstance(attr, (complex, np.complexfloating)): raise Exception(f"Complex number is not supported at {label}") elif type(attr) in [bool, np.bool_]: return bool(attr) diff --git a/lindi/conversion/nan_inf_ninf.py b/lindi/conversion/nan_inf_ninf.py index 3ab012b..cc7595d 100644 --- a/lindi/conversion/nan_inf_ninf.py +++ b/lindi/conversion/nan_inf_ninf.py @@ -21,7 +21,7 @@ def encode_nan_inf_ninf(val): return [encode_nan_inf_ninf(v) for v in val] elif isinstance(val, dict): return {k: encode_nan_inf_ninf(v) for k, v in val.items()} - elif type(val) in [float, np.float16, np.float32, np.float64]: + elif isinstance(val, (float, np.floating)): if np.isnan(val): return 'NaN' elif val == float('inf'): diff --git a/pyproject.toml b/pyproject.toml index ae41614..88d37fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,12 @@ [tool.poetry] name = "lindi" -version = "0.1.1" +version = "0.2.0" description = "" -authors = [] +authors = [ + "Jeremy Magland ", + "Ryan Ly ", + "Oliver Ruebel " +] readme = "README.md" [tool.poetry.dependencies] diff --git a/tests/test_store.py b/tests/test_store.py index aff4667..c2ed53b 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -49,3 +49,7 @@ def test_store(): def _lists_are_equal_as_sets(a, b): return set(a) == set(b) + + +if __name__ == "__main__": + test_store()