Skip to content

Commit

Permalink
handle and test of mode="a"
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Oct 5, 2023
1 parent 9700c57 commit ac1d150
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 30 deletions.
63 changes: 38 additions & 25 deletions python/kvikio/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import numpy as np
import zarr
import zarr.creation
import zarr.errors
import zarr.storage
import zarr.util
from numcodecs.abc import Codec
from numcodecs.compat import ensure_contiguous_ndarray_like
from numcodecs.registry import register_codec
Expand Down Expand Up @@ -360,35 +362,46 @@ def open_cupy_array(
if not hasattr(meta_array, "__cuda_array_interface__"):
raise ValueError("meta_array must implement __cuda_array_interface__")

if mode in ("r", "r+"):
ret = zarr.open_array(
store=kvikio.zarr.GDSStore(path=store),
mode=mode,
meta_array=meta_array,
**kwargs,
)
# If we are reading a LZ4-CPU compressed file, we overwrite the metadata
# on-the-fly to make Zarr use LZ4-GPU for both compression and decompression.
compat_lz4 = CompatCompressor.lz4()
if ret.compressor == compat_lz4.cpu:
if mode in ("r", "r+", "a"):
# In order to handle "a", we start by trying to open the file in read mode.
try:
ret = zarr.open_array(
store=kvikio.zarr.GDSStore(
path=store,
compressor_config_overwrite=compat_lz4.cpu.get_config(),
decompressor_config_overwrite=compat_lz4.gpu.get_config(),
),
mode=mode,
store=kvikio.zarr.GDSStore(path=store),
mode="r+",
meta_array=meta_array,
**kwargs,
)
elif not isinstance(ret.compressor, CudaCodec):
raise ValueError(
"The Zarr file was written using a non-CUDA compatible "
f"compressor, {ret.compressor}, please use something "
"like kvikio.zarr.CompatCompressor"
)
return ret

except (zarr.errors.ContainsGroupError, zarr.errors.ArrayNotFoundError):
# If the array doesn't exist, we re-raise the error when reading ("r", "r+")
# and continue when appendding ("a").
if mode in ("r", "r+"):
raise
else:
# If we were able to read the file without error, we handle CPU/GPU mixing.
# If we are reading a LZ4-CPU compressed file, we overwrite the
# metadata on-the-fly to make Zarr use LZ4-GPU for both compression
# and decompression.
compat_lz4 = CompatCompressor.lz4()
if ret.compressor == compat_lz4.cpu:
ret = zarr.open_array(
store=kvikio.zarr.GDSStore(
path=store,
compressor_config_overwrite=compat_lz4.cpu.get_config(),
decompressor_config_overwrite=compat_lz4.gpu.get_config(),
),
mode=mode,
meta_array=meta_array,
**kwargs,
)
elif not isinstance(ret.compressor, CudaCodec):
raise ValueError(
"The Zarr file was written using a non-CUDA compatible "
f"compressor, {ret.compressor}, please use something "
"like kvikio.zarr.CompatCompressor"
)
return ret

# At this point, we known that we are writing a new array ("w", "w-", "a")
if isinstance(compressor, CompatCompressor):
compressor_config_overwrite = compressor.cpu.get_config()
decompressor_config_overwrite = compressor.gpu.get_config()
Expand Down
12 changes: 7 additions & 5 deletions python/tests/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,13 @@ def test_compressor_config_overwrite(tmp_path, xp, algo):
numpy.testing.assert_array_equal(z[:], range(10))


def test_open_cupy_array(tmp_path):
@pytest.mark.parametrize("write_mode", ["w", "w-", "a"])
@pytest.mark.parametrize("read_mode", ["r", "r+", "a"])
def test_open_cupy_array(tmp_path, write_mode, read_mode):
a = cupy.arange(10)
z = kvikio_zarr.open_cupy_array(
tmp_path,
mode="w",
mode=write_mode,
shape=a.shape,
dtype=a.dtype,
chunks=(2,),
Expand All @@ -231,23 +233,23 @@ def test_open_cupy_array(tmp_path):

z = kvikio_zarr.open_cupy_array(
tmp_path,
mode="r",
mode=read_mode,
)
assert a.shape == z.shape
assert a.dtype == z.dtype
assert isinstance(z[:], type(a))
assert z.compressor == kvikio_nvcomp_codec.NvCompBatchCodec("lz4")
cupy.testing.assert_array_equal(a, z[:])

z = zarr.open_array(tmp_path, mode="r")
z = zarr.open_array(tmp_path, mode=read_mode)
assert a.shape == z.shape
assert a.dtype == z.dtype
assert isinstance(z[:], numpy.ndarray)
assert z.compressor == kvikio_zarr.CompatCompressor.lz4().cpu
numpy.testing.assert_array_equal(a.get(), z[:])


@pytest.mark.parametrize("mode", ["r", "r+"])
@pytest.mark.parametrize("mode", ["r", "r+", "a"])
def test_open_cupy_array_incompatible_compressor(tmp_path, mode):
zarr.create((10,), store=tmp_path, compressor=numcodecs.Blosc())

Expand Down

0 comments on commit ac1d150

Please sign in to comment.