diff --git a/pyn5/dataset.py b/pyn5/dataset.py index 38f7124..07bd6e6 100644 --- a/pyn5/dataset.py +++ b/pyn5/dataset.py @@ -1,8 +1,11 @@ +from copy import copy + from typing import Union, Tuple, Optional, Any import numpy as np from h5py_like import DatasetBase, AttributeManagerBase, mutation, Name +from h5py_like.shape_utils import thread_read_fn, thread_write_fn from pyn5.attributes import AttributeManager from .pyn5 import ( DatasetUINT8, @@ -33,6 +36,8 @@ class Dataset(DatasetBase): + threads = None + def __init__(self, name: str, parent: "Group"): # noqa would need circular imports """ @@ -44,6 +49,7 @@ def __init__(self, name: str, parent: "Group"): # noqa would need circular impo self._parent = parent self._path = self.parent._path / name self._attrs = AttributeManager.from_parent(self) + self.threads = copy(self.threads) attrs = self._attrs._read_attributes() @@ -90,20 +96,46 @@ def resize(self, size: Union[int, Tuple[int, ...]], axis: Optional[int] = None): raise NotImplementedError() def __getitem__(self, args) -> np.ndarray: - def fn(translation, dimensions): + def inner_fn(translation, dimensions): return self._impl.read_ndarray( translation[::-1], dimensions[::-1] ).transpose() + if self.threads: + def fn(translation, dimensions): + return thread_read_fn( + translation, + dimensions, + self.chunks, + self.shape, + inner_fn, + self.threads + ) + else: + fn = inner_fn + return self._getitem(args, fn, self._astype) @mutation def __setitem__(self, args, val): - def fn(offset, arr): + def inner_fn(offset, arr): return self._impl.write_ndarray( offset[::-1], arr.transpose(), self.fillvalue ) + if self.threads: + def fn(offset, arr): + return thread_write_fn( + offset, + arr, + self.chunks, + self.shape, + inner_fn, + self.threads + ) + else: + fn = inner_fn + return self._setitem(args, val, fn) @property diff --git a/requirements_dev.txt b/requirements_dev.txt index 114b51e..55468ae 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -11,5 +11,5 @@ setuptools_rust==0.10.6 numpy==1.16.4 pytest==4.6.3 pytest-runner==5.1 -h5py_like==0.4.0 +h5py_like==0.5.2 h5py==2.9.0 diff --git a/setup.py b/setup.py index 2436020..e8f2163 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ with open("HISTORY.rst") as history_file: history = history_file.read() -requirements = ["numpy", "h5py_like>=0.4.0"] +requirements = ["numpy", "h5py_like>=0.5.2"] setup_requirements = [] test_requirements = [] diff --git a/tests/test_h5_like.py b/tests/test_h5_like.py index d5ca19b..c9276b8 100644 --- a/tests/test_h5_like.py +++ b/tests/test_h5_like.py @@ -8,13 +8,15 @@ from pathlib import Path from h5py_like import Mode, FileMixin -from h5py_like.test_utils import FileTestBase, DatasetTestBase, GroupTestBase, ModeTestBase +from h5py_like.test_utils import ( + FileTestBase, ThreadedDatasetTestBase, GroupTestBase, ModeTestBase, +) from pyn5 import File from .common import blocks_hash from .common import blocks_in, attrs_in -ds_kwargs = deepcopy(DatasetTestBase.dataset_kwargs) +ds_kwargs = deepcopy(ThreadedDatasetTestBase.dataset_kwargs) ds_kwargs["chunks"] = (5, 5, 5) @@ -28,8 +30,8 @@ class TestGroup(GroupTestBase): pass -class TestDataset(DatasetTestBase): - dataset_kwargs = ds_kwargs +class TestDataset(ThreadedDatasetTestBase): + dataset_kwargs = deepcopy(ThreadedDatasetTestBase.dataset_kwargs) def test_has_metadata(self, file_): ds = self.dataset(file_)