Skip to content

Commit

Permalink
Test zarr compressor and filters are correct in parallel case
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored and mergify[bot] committed Sep 16, 2021
1 parent b8326aa commit 28868ef
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
14 changes: 12 additions & 2 deletions sgkit/io/vcfzarr_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,20 @@ def concat_zarrs_optimized(
[group[var] for group in zarr_groups], dtype=dtype
)

_to_zarr_kwargs = dict(fill_value=None)
_to_zarr_kwargs = dict(
compressor=first_zarr_group[var].compressor,
filters=first_zarr_group[var].filters,
fill_value=None,
)
if not fix_strings and arr.dtype == "O":
# We assume that all object dtypes are variable length strings
_to_zarr_kwargs["object_codec"] = numcodecs.VLenUTF8()
var_len_str_codec = numcodecs.VLenUTF8()
_to_zarr_kwargs["object_codec"] = var_len_str_codec
# Remove from filters to avoid double encoding error
if var_len_str_codec in first_zarr_group[var].filters:
filters = list(first_zarr_group[var].filters)
filters.remove(var_len_str_codec)
_to_zarr_kwargs["filters"] = filters

d = _to_zarr( # type: ignore[no-untyped-call]
arr,
Expand Down
37 changes: 37 additions & 0 deletions sgkit/tests/io/vcf/test_vcf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,43 @@ def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path):
assert z["variant_id_mask"].filters is None


@pytest.mark.parametrize(
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__parallel_compressor_and_filters(
shared_datadir, is_path, tmp_path
):
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
regions = ["20", "21"]

default_compressor = Blosc("zlib", 1, Blosc.NOSHUFFLE)
variant_id_compressor = Blosc("zlib", 2, Blosc.NOSHUFFLE)
encoding = dict(
variant_id=dict(compressor=variant_id_compressor),
variant_id_mask=dict(filters=None),
)
vcf_to_zarr(
path,
output,
regions=regions,
chunk_length=5_000,
compressor=default_compressor,
encoding=encoding,
)

# look at actual Zarr store to check compressor and filters
z = zarr.open(output)
assert z["call_genotype"].compressor == default_compressor
assert z["call_genotype"].filters is None
assert z["call_genotype_mask"].filters == [PackBits()]

assert z["variant_id"].compressor == variant_id_compressor
assert z["variant_id_mask"].filters is None


@pytest.mark.parametrize(
"is_path",
[True, False],
Expand Down

0 comments on commit 28868ef

Please sign in to comment.