diff --git a/src/fibsem_tools/io/h5.py b/src/fibsem_tools/io/h5.py index 30801d0..1d877b7 100644 --- a/src/fibsem_tools/io/h5.py +++ b/src/fibsem_tools/io/h5.py @@ -7,39 +7,69 @@ H5_ACCESS_MODES = ("r", "r+", "w", "w-", "x", "a") -H5_DATASET_KWDS = ("name", - "shape", - "dtype", - "data", - "chunks", - "compression", - "compression_opts", - "scaleoffset", - "shuffle", - "fletcher32", - "maxshape", - "fillvalue", - "track_times", - "track_order", - "external", - "allow_unknown_filter") - -H5_GROUP_KWDS = ("name", - "track_order") - -H5_FILE_KWDS = ("name", - "mode", - "driver", - "libver", - "userblock_size", - "swmr", - "rdcc_nslots", - "rdcc_nbytes", - "rdcc_w0", - "track_order", - "fs_strategy", - "fs_persist", - "fs_threshold") +H5_DATASET_KWDS = ( + "name", + "shape", + "dtype", + "data", + "chunks", + "compression", + "compression_opts", + "scaleoffset", + "shuffle", + "fletcher32", + "maxshape", + "fillvalue", + "track_times", + "track_order", + "dcpl", + "external", + "allow_unknown_filter", +) + +H5_GROUP_KWDS = ("name", "track_order") + +H5_FILE_KWDS = ( + "name", + "mode", + "driver", + "libver", + "userblock_size", + "swmr", + "rdcc_nslots", + "rdcc_nbytes", + "rdcc_w0", + "track_order", + "fs_strategy", + "fs_persist", + "fs_threshold", +) + + +# Could use multiple inheritance here +class ManagedDataset(h5py.Dataset): + """ + h5py.Dataset with context manager behavior + """ + + def __enter__(self): + return self + + def __exit__(self, ex_type, ex_value, ex_traceback): + self.file.close() + + +class ManagedGroup(h5py.Group): + """ + h5py.Group with context manager behavior + """ + + def __enter__(self): + return self + + def __exit__(self, ex_type, ex_value, ex_traceback): + self.file.close() + def partition_h5_kwargs(**kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ @@ -47,38 +77,46 @@ def partition_h5_kwargs(**kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ file_kwargs = kwargs.copy() dataset_kwargs = {} - for key in H5_DATASET_KWDS: - if key in file_kwargs: + for key in kwargs: + if key in H5_DATASET_KWDS: dataset_kwargs[key] = file_kwargs.pop(key) return file_kwargs, dataset_kwargs def access_h5( - store: Pathlike, path: Pathlike, mode: str, **kwargs + store: Union[h5py.File, Pathlike], path: Pathlike, **kwargs ) -> Union[h5py.Dataset, h5py.Group]: """ Docstring """ - if mode not in H5_ACCESS_MODES: - raise ValueError(f"Invalid access mode. Got {mode}, expected one of {H5_ACCESS_MODES}.") - attrs = kwargs.pop("attrs", {}) + mode = kwargs.get("mode", "r") file_kwargs, dataset_kwargs = partition_h5_kwargs(**kwargs) - - h5f = h5py.File(store, mode=mode, **file_kwargs) - if mode in ("r", "r+", "a") and (result := h5f.get(path)) is not None: - return result + if isinstance(store, h5py.File): + h5f = store + else: + h5f = h5py.File(store, **file_kwargs) + + if mode in ("r", "r+", "a"): + # let h5py handle keyerrors + result = h5f[path] else: if len(dataset_kwargs) > 0: - if 'name' in dataset_kwargs: - warnings.warn('"Name" was provided to this function as a keyword argument. This value will be replaced with the second argument to this function.') + if "name" in dataset_kwargs: + warnings.warn( + '"Name" was provided to this function as a keyword argument. This value will be replaced with the second argument to this function.' + ) dataset_kwargs["name"] = path result = h5f.create_dataset(**dataset_kwargs) else: result = h5f.require_group(path) - result.attrs.update(**attrs) - return result + if isinstance(result, h5py.Group): + result = ManagedGroup(result.id) + else: + result = ManagedDataset(result.id) + + return result diff --git a/tests/test_h5.py b/tests/test_h5.py new file mode 100644 index 0000000..affc4c1 --- /dev/null +++ b/tests/test_h5.py @@ -0,0 +1,20 @@ +from h5py._hl.dataset import make_new_dset +from fibsem_tools.io.h5 import partition_h5_kwargs +from inspect import signature, Parameter + + +def test_kwarg_partition(): + dataset_creation_sig = signature(make_new_dset) + dataset_kwargs = { + k: None + for k, v in filter( + lambda p: p[1].default is not Parameter.empty, + dataset_creation_sig.parameters.items(), + ) + } + file_kwargs = {"foo": None, "bar": None} + file_kwargs_out, dataset_kwargs_out = partition_h5_kwargs( + **dataset_kwargs, **file_kwargs + ) + assert file_kwargs == file_kwargs_out + assert dataset_kwargs == dataset_kwargs_out diff --git a/tests/test_storage.py b/tests/test_storage.py index 3509665..6c6ea84 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -110,15 +110,20 @@ def test_access_array_h5(): data = np.random.randint(0, 255, size=(10, 10, 10), dtype="uint8") attrs = {"resolution": "1000"} with tempfile.TemporaryFile(suffix=".h5") as store: - arr = access_h5(store, key, data=data, attrs=attrs, mode="w") - assert dict(arr.attrs) == attrs - assert np.array_equal(arr[:], data) - arr.file.close() + with access_h5(store, key, data=data, attrs=attrs, mode="w") as arr1: + assert dict(arr1.attrs) == attrs + assert np.array_equal(arr1[:], data) - arr2 = access_h5(store, key, mode="r") - assert dict(arr2.attrs) == attrs - assert np.array_equal(arr2[:], data) - arr2.file.close() + with access_h5(store, key, mode="r") as arr2: + assert dict(arr2.attrs) == attrs + assert np.array_equal(arr2[:], data) + + with access_h5(store, key, mode="r") as arr3: + h5d = arr3.file[key] + assert h5d.shape == arr3.shape + assert h5d.attrs == arr3.attrs + assert h5d.chunks == arr3.chunks + assert h5d.compression == arr3.compression def test_access_group_h5(): @@ -126,13 +131,15 @@ def test_access_group_h5(): attrs = {"resolution": "1000"} with tempfile.TemporaryFile(suffix=".h5") as store: - grp = access_h5(store, key, attrs=attrs, mode="w") - assert dict(grp.attrs) == attrs - grp.file.close() + with access_h5(store, key, attrs=attrs, mode="w") as grp1: + assert dict(grp1.attrs) == attrs + + with access_h5(store, key, mode="r") as grp2: + assert dict(grp2.attrs) == attrs - grp2 = access_h5(store, key, mode="r") - assert dict(grp2.attrs) == attrs - grp2.file.close() + with access_h5(store, key, mode="r") as grp3: + h5g = grp3.file[key] + assert h5g.attrs == grp3.attrs def test_list_files():