diff --git a/python/kvikio/zarr.py b/python/kvikio/zarr.py index b6f8154b72..0b174316eb 100644 --- a/python/kvikio/zarr.py +++ b/python/kvikio/zarr.py @@ -14,7 +14,9 @@ import numpy as np import zarr import zarr.creation +import zarr.errors import zarr.storage +import zarr.util from numcodecs.abc import Codec from numcodecs.compat import ensure_contiguous_ndarray_like from numcodecs.registry import register_codec @@ -360,35 +362,46 @@ def open_cupy_array( if not hasattr(meta_array, "__cuda_array_interface__"): raise ValueError("meta_array must implement __cuda_array_interface__") - if mode in ("r", "r+"): - ret = zarr.open_array( - store=kvikio.zarr.GDSStore(path=store), - mode=mode, - meta_array=meta_array, - **kwargs, - ) - # If we are reading a LZ4-CPU compressed file, we overwrite the metadata - # on-the-fly to make Zarr use LZ4-GPU for both compression and decompression. - compat_lz4 = CompatCompressor.lz4() - if ret.compressor == compat_lz4.cpu: + if mode in ("r", "r+", "a"): + # In order to handle "a", we start by trying to open the file in read mode. + try: ret = zarr.open_array( - store=kvikio.zarr.GDSStore( - path=store, - compressor_config_overwrite=compat_lz4.cpu.get_config(), - decompressor_config_overwrite=compat_lz4.gpu.get_config(), - ), - mode=mode, + store=kvikio.zarr.GDSStore(path=store), + mode="r+", meta_array=meta_array, **kwargs, ) - elif not isinstance(ret.compressor, CudaCodec): - raise ValueError( - "The Zarr file was written using a non-CUDA compatible " - f"compressor, {ret.compressor}, please use something " - "like kvikio.zarr.CompatCompressor" - ) - return ret - + except (zarr.errors.ContainsGroupError, zarr.errors.ArrayNotFoundError): + # If the array doesn't exist, we re-raise the error when reading ("r", "r+") + # and continue when appending ("a"). + if mode in ("r", "r+"): + raise + else: + # If we were able to read the file without error, we handle CPU/GPU mixing. 
+ # If we are reading a LZ4-CPU compressed file, we overwrite the + # metadata on-the-fly to make Zarr use LZ4-GPU for both compression + # and decompression. + compat_lz4 = CompatCompressor.lz4() + if ret.compressor == compat_lz4.cpu: + ret = zarr.open_array( + store=kvikio.zarr.GDSStore( + path=store, + compressor_config_overwrite=compat_lz4.cpu.get_config(), + decompressor_config_overwrite=compat_lz4.gpu.get_config(), + ), + mode=mode, + meta_array=meta_array, + **kwargs, + ) + elif not isinstance(ret.compressor, CudaCodec): + raise ValueError( + "The Zarr file was written using a non-CUDA compatible " + f"compressor, {ret.compressor}, please use something " + "like kvikio.zarr.CompatCompressor" + ) + return ret + + # At this point, we know that we are writing a new array ("w", "w-", "a") if isinstance(compressor, CompatCompressor): compressor_config_overwrite = compressor.cpu.get_config() decompressor_config_overwrite = compressor.gpu.get_config() diff --git a/python/tests/test_zarr.py b/python/tests/test_zarr.py index cc0ee0ebdd..a780263bee 100644 --- a/python/tests/test_zarr.py +++ b/python/tests/test_zarr.py @@ -212,11 +212,13 @@ def test_compressor_config_overwrite(tmp_path, xp, algo): numpy.testing.assert_array_equal(z[:], range(10)) -def test_open_cupy_array(tmp_path): +@pytest.mark.parametrize("write_mode", ["w", "w-", "a"]) +@pytest.mark.parametrize("read_mode", ["r", "r+", "a"]) +def test_open_cupy_array(tmp_path, write_mode, read_mode): a = cupy.arange(10) z = kvikio_zarr.open_cupy_array( tmp_path, - mode="w", + mode=write_mode, shape=a.shape, dtype=a.dtype, chunks=(2,), @@ -231,7 +233,7 @@ def test_open_cupy_array(tmp_path): z = kvikio_zarr.open_cupy_array( tmp_path, - mode="r", + mode=read_mode, ) assert a.shape == z.shape assert a.dtype == z.dtype @@ -239,7 +241,7 @@ def test_open_cupy_array(tmp_path): assert z.compressor == kvikio_nvcomp_codec.NvCompBatchCodec("lz4") cupy.testing.assert_array_equal(a, z[:]) - z = zarr.open_array(tmp_path, 
mode="r") + z = zarr.open_array(tmp_path, mode=read_mode) assert a.shape == z.shape assert a.dtype == z.dtype assert isinstance(z[:], numpy.ndarray) @@ -247,7 +249,7 @@ def test_open_cupy_array(tmp_path): numpy.testing.assert_array_equal(a.get(), z[:]) -@pytest.mark.parametrize("mode", ["r", "r+"]) +@pytest.mark.parametrize("mode", ["r", "r+", "a"]) def test_open_cupy_array_incompatible_compressor(tmp_path, mode): zarr.create((10,), store=tmp_path, compressor=numcodecs.Blosc())