Add concat algorithm parameter to vcf_to_zarr (#365)
Use variable-length strings for storing alleles in Zarr #643
tomwhite authored and mergify[bot] committed Sep 16, 2021
1 parent 19f0b1b commit 45f1267
Showing 4 changed files with 141 additions and 31 deletions.
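In practice the new parameter looks like this — a minimal sketch, assuming a local VCF at a hypothetical path and the public ``sgkit.io.vcf`` entry point; the ``concat_algorithm`` values (``None`` and ``"xarray_internal"``) are the ones introduced in the diff below:

    from sgkit.io.vcf import vcf_to_zarr

    # Default (concat_algorithm=None): the optimized concatenation path,
    # suitable for large files.
    vcf_to_zarr("sample.vcf.gz", "out.zarr", regions=["20", "21"])

    # Opt in to the Xarray-internal path instead, which can exhibit high
    # memory usage (see https://github.com/dask/dask/issues/6745).
    vcf_to_zarr(
        "sample.vcf.gz",
        "out_xarray.zarr",
        regions=["20", "21"],
        concat_algorithm="xarray_internal",
    )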
88 changes: 80 additions & 8 deletions sgkit/io/vcf/vcf_reader.py
@@ -22,13 +22,17 @@
import xarray as xr
from cyvcf2 import VCF, Variant
from numcodecs import PackBits
from typing_extensions import Literal

from sgkit import variables
from sgkit.io.dataset import load_dataset
from sgkit.io.utils import zarrs_to_dataset
from sgkit.io.vcf import partition_into_regions
from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename
from sgkit.io.vcfzarr_reader import vcf_number_to_dimension_and_size
from sgkit.io.vcfzarr_reader import (
concat_zarrs_optimized,
vcf_number_to_dimension_and_size,
)
from sgkit.model import (
DIM_PLOIDY,
DIM_SAMPLE,
@@ -529,6 +533,7 @@ def vcf_to_zarr_parallel(
fields: Optional[Sequence[str]] = None,
exclude_fields: Optional[Sequence[str]] = None,
field_defs: Optional[Dict[str, Dict[str, Any]]] = None,
concat_algorithm: Optional[Literal["xarray_internal"]] = None,
) -> None:
"""Convert specified regions of one or more VCF files to zarr files, then concat, rechunk, write to zarr"""

@@ -557,11 +562,15 @@ def vcf_to_zarr(
field_defs=field_defs,
)

ds = zarrs_to_dataset(paths, chunk_length, chunk_width, tempdir_storage_options)

# Ensure Dask task graph is efficient, see https://github.com/dask/dask/issues/5105
with dask.config.set({"optimization.fuse.ave-width": dask_fuse_avg_width}):
ds.to_zarr(output, mode="w")
concat_zarrs(
paths,
output,
concat_algorithm=concat_algorithm,
chunk_length=chunk_length,
chunk_width=chunk_width,
storage_options=tempdir_storage_options,
dask_fuse_avg_width=dask_fuse_avg_width,
)


def vcf_to_zarrs(
@@ -703,6 +712,64 @@ def vcf_to_zarrs(
return parts


def concat_zarrs(
urls: Sequence[str],
output: Union[PathType, MutableMapping[str, bytes]],
*,
concat_algorithm: Optional[Literal["xarray_internal"]] = None,
chunk_length: int = 10_000,
chunk_width: int = 1_000,
storage_options: Optional[Dict[str, str]] = None,
dask_fuse_avg_width: int = 50,
) -> None:
"""Concatenate multiple Zarr stores into a single Zarr store.
The Zarr stores are concatenated and rechunked to produce a single combined store.
Parameters
----------
urls
A list of URLs to the Zarr stores to combine, typically the return value of
:func:`vcf_to_zarrs`.
output
Zarr store or path to directory in file system.
concat_algorithm
The algorithm to use to concatenate and rechunk Zarr files. The default None means
use the optimized version suitable for large files, whereas ``xarray_internal`` will
use built-in Xarray APIs, which can exhibit high memory usage, see https://github.com/dask/dask/issues/6745.
chunk_length
Length (number of variants) of chunks in which data are stored, by default 10,000.
This is only used when ``concat_algorithm`` is ``xarray_internal``.
chunk_width
Width (number of samples) to use when storing chunks in output, by default 1,000.
This is only used when ``concat_algorithm`` is ``xarray_internal``.
storage_options
Any additional parameters for the storage backend (see ``fsspec.open``).
dask_fuse_avg_width
Setting for Dask's ``optimization.fuse.ave-width``, see https://github.com/dask/dask/issues/5105
"""
if concat_algorithm == "xarray_internal":
ds = zarrs_to_dataset(urls, chunk_length, chunk_width, storage_options)

with dask.config.set({"optimization.fuse.ave-width": dask_fuse_avg_width}):
ds.to_zarr(output, mode="w")
else:

vars_to_rechunk = []
vars_to_copy = []
storage_options = storage_options or {}
ds = xr.open_zarr( # type: ignore[no-untyped-call]
fsspec.get_mapper(urls[0], **storage_options), concat_characters=False
)
for (var, arr) in ds.data_vars.items():
if arr.dims[0] == "variants":
vars_to_rechunk.append(var)
else:
vars_to_copy.append(var)

concat_zarrs_optimized(urls, output, vars_to_rechunk, vars_to_copy)


def vcf_to_zarr(
input: Union[PathType, Sequence[PathType]],
output: Union[PathType, MutableMapping[str, bytes]],
@@ -723,6 +790,7 @@ def vcf_to_zarr(
fields: Optional[Sequence[str]] = None,
exclude_fields: Optional[Sequence[str]] = None,
field_defs: Optional[Dict[str, Dict[str, Any]]] = None,
concat_algorithm: Optional[Literal["xarray_internal"]] = None,
) -> None:
"""Convert VCF files to a single Zarr on-disk store.
@@ -735,8 +803,7 @@
is None.
For more control over these two steps, consider using :func:`vcf_to_zarrs` followed by
:func:`zarrs_to_dataset`, then saving the dataset using Xarray's
:meth:`xarray.Dataset.to_zarr` method.
:func:`concat_zarrs`.
Parameters
----------
@@ -811,6 +878,10 @@
(which is defined as Number 2 in the VCF header) as ``haplotypes``.
(Note that Number ``A`` is the number of alternate alleles, see section 1.4.2 of the
VCF spec https://samtools.github.io/hts-specs/VCFv4.3.pdf.)
concat_algorithm
The algorithm to use to concatenate and rechunk Zarr files. The default None means
use the optimized version suitable for large files, whereas ``xarray_internal`` will
use built-in Xarray APIs, which can exhibit high memory usage, see https://github.com/dask/dask/issues/6745.
"""

if temp_chunk_length is not None:
@@ -842,6 +913,7 @@
temp_chunk_length=temp_chunk_length,
tempdir=tempdir,
tempdir_storage_options=tempdir_storage_options,
concat_algorithm=concat_algorithm,
)
convert_func(
input, # type: ignore
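The updated docstring now recommends :func:`vcf_to_zarrs` followed by :func:`concat_zarrs` for finer control over the two steps. A minimal sketch of that flow, with hypothetical paths (``concat_zarrs`` is imported from its defining module here, since its public export point isn't shown in this diff):

    from sgkit.io.vcf import vcf_to_zarrs
    from sgkit.io.vcf.vcf_reader import concat_zarrs

    # Step 1: write one intermediate Zarr store per region.
    paths = vcf_to_zarrs("sample.vcf.gz", "parts", regions=["20", "21"])

    # Step 2: concatenate and rechunk the parts into a single store.
    # concat_algorithm=None (the default) takes the optimized branch,
    # which delegates to concat_zarrs_optimized.
    concat_zarrs(paths, "combined.zarr")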
52 changes: 38 additions & 14 deletions sgkit/io/vcfzarr_reader.py
@@ -1,9 +1,20 @@
import tempfile
from pathlib import Path
from typing import Any, Dict, Hashable, List, Optional, Tuple
from typing import (
Any,
Dict,
Hashable,
List,
MutableMapping,
Optional,
Sequence,
Tuple,
Union,
)

import dask
import dask.array as da
import numcodecs
import xarray as xr
import zarr
from fsspec import get_mapper
@@ -151,8 +162,8 @@ def vcfzarr_to_zarr(
ds.to_zarr(output, mode="w")
else:
# Use the optimized algorithm in `concatenate_and_rechunk`
_concat_zarrs_optimized(
zarr_files, output, vars_to_rechunk, vars_to_copy
concat_zarrs_optimized(
zarr_files, output, vars_to_rechunk, vars_to_copy, fix_strings=True
)


@@ -291,18 +302,22 @@ def _get_max_len(zarr_groups: List[zarr.Group], attr_name: str) -> int:
return max_len


def _concat_zarrs_optimized(
zarr_files: List[str],
output: PathType,
def concat_zarrs_optimized(
zarr_files: Sequence[str],
output: Union[PathType, MutableMapping[str, bytes]],
vars_to_rechunk: List[Hashable],
vars_to_copy: List[Hashable],
fix_strings: bool = False,
) -> None:
if isinstance(output, Path):
output = str(output)

zarr_groups = [zarr.open_group(f) for f in zarr_files]

first_zarr_group = zarr_groups[0]

# create the top-level group
zarr.open_group(str(output), mode="w")
zarr.open_group(output, mode="w")

# copy variables that are to be rechunked
# NOTE: that this uses _to_zarr function defined here that is needed to avoid
@@ -311,38 +326,47 @@ def _concat_zarrs_optimized(
delayed = [] # do all the rechunking operations in one computation
for var in vars_to_rechunk:
dtype = None
if var in {"variant_id", "variant_allele"}:
if fix_strings and var in {"variant_id", "variant_allele"}:
max_len = _get_max_len(zarr_groups, f"max_length_{var}")
dtype = f"S{max_len}"

arr = concatenate_and_rechunk(
[group[var] for group in zarr_groups], dtype=dtype
)

_to_zarr_kwargs = dict(fill_value=None)
if not fix_strings and arr.dtype == "O":
# We assume that all object dtypes are variable length strings
_to_zarr_kwargs["object_codec"] = numcodecs.VLenUTF8()

d = _to_zarr( # type: ignore[no-untyped-call]
arr,
str(output),
output,
component=var,
overwrite=True,
compute=False,
fill_value=None,
attrs=first_zarr_group[var].attrs.asdict(),
**_to_zarr_kwargs,
)
delayed.append(d)
da.compute(*delayed)

# copy unchanged variables and top-level metadata
with zarr.open_group(str(output)) as output_zarr:
with zarr.open_group(output) as output_zarr:

# copy variables that are not rechunked (e.g. sample_id)
for var in vars_to_copy:
output_zarr[var] = first_zarr_group[var]
output_zarr[var].attrs.update(first_zarr_group[var].attrs)

# copy top-level attributes
output_zarr.attrs.update(first_zarr_group.attrs)
group_attrs = dict(first_zarr_group.attrs)
if "max_alt_alleles_seen" in group_attrs:
max_alt_alleles_seen = _get_max_len(zarr_groups, "max_alt_alleles_seen")
group_attrs["max_alt_alleles_seen"] = max_alt_alleles_seen
output_zarr.attrs.update(group_attrs)

# consolidate metadata
zarr.consolidate_metadata(str(output))
zarr.consolidate_metadata(output)


def _to_zarr( # type: ignore[no-untyped-def]
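For context on the new ``object_codec`` branch above: Zarr cannot store object-dtype (variable-length string) arrays without an explicit codec. A standalone sketch of the same idea, using made-up allele data rather than anything from this commit:

    import numcodecs
    import numpy as np
    import zarr

    # Object dtype is assumed to mean variable-length strings, so a
    # VLenUTF8 codec is attached — mirroring the
    # _to_zarr_kwargs["object_codec"] = numcodecs.VLenUTF8() branch above.
    alleles = np.array([["A", "T"], ["G", "ACGT"]], dtype=object)
    z = zarr.array(alleles, object_codec=numcodecs.VLenUTF8())
    print(z[:])  # values round-trip as Python str of any length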
30 changes: 22 additions & 8 deletions sgkit/tests/io/vcf/test_vcf_reader.py
@@ -233,13 +233,23 @@ def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path):
"is_path",
[True, False],
)
@pytest.mark.parametrize(
"concat_algorithm",
[None, "xarray_internal"],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path):
def test_vcf_to_zarr__parallel(shared_datadir, is_path, concat_algorithm, tmp_path):
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
regions = ["20", "21"]

vcf_to_zarr(path, output, regions=regions, chunk_length=5_000)
vcf_to_zarr(
path,
output,
regions=regions,
chunk_length=5_000,
concat_algorithm=concat_algorithm,
)
ds = xr.open_zarr(output)

assert ds["sample_id"].shape == (1,)
@@ -252,8 +262,12 @@ def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path):
assert ds["variant_id_mask"].shape == (19910,)
assert ds["variant_position"].shape == (19910,)

assert ds["variant_allele"].dtype == "S48"
assert ds["variant_id"].dtype == "S1"
if concat_algorithm is None:
assert ds["variant_allele"].dtype == "O"
assert ds["variant_id"].dtype == "O"
else:
assert ds["variant_allele"].dtype == "S48"
assert ds["variant_id"].dtype == "S1"


@pytest.mark.parametrize(
@@ -305,8 +319,8 @@ def test_vcf_to_zarr__parallel_temp_chunk_length(shared_datadir, is_path, tmp_pa
assert ds["variant_id_mask"].shape == (19910,)
assert ds["variant_position"].shape == (19910,)

assert ds["variant_allele"].dtype == "S48"
assert ds["variant_id"].dtype == "S1"
assert ds["variant_allele"].dtype == "O"
assert ds["variant_id"].dtype == "O"


def test_vcf_to_zarr__parallel_temp_chunk_length_not_divisible(
@@ -531,7 +545,7 @@ def test_vcf_to_zarr__mixed_ploidy_vcf(
)
ds = load_dataset(output)

variant_dtype = "|S1" if regions else "O"
variant_dtype = "O"
assert ds.attrs["contigs"] == ["CHR1", "CHR2", "CHR3"]
assert_array_equal(ds["variant_contig"], [0, 0])
assert_array_equal(ds["variant_position"], [2, 7])
@@ -728,7 +742,7 @@ def test_vcf_to_zarr__parallel_with_fields(shared_datadir, tmp_path):
assert_allclose(ds["variant_MQ"], [58.33, np.nan, 57.45])
assert ds["variant_MQ"].attrs["comment"] == "RMS Mapping Quality"

assert_array_equal(ds["call_PGT"], [[b"0|1"], [b""], [b"0|1"]])
assert_array_equal(ds["call_PGT"], [["0|1"], [""], ["0|1"]])
assert (
ds["call_PGT"].attrs["comment"]
== "Physical phasing haplotype information, describing how the alternate alleles are phased in relation to one another"
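The ``call_PGT`` assertion change above is the same variable-length string switch surfacing in a FORMAT string field: values now compare equal to Python ``str`` rather than fixed-width ``bytes``. A minimal illustration of that dtype difference (illustrative arrays only, not from the test data):

    import numpy as np

    old = np.array([b"0|1"], dtype="S3")     # fixed-width byte strings
    new = np.array(["0|1"], dtype=object)    # variable-length strings
    assert old[0] == b"0|1"
    assert new[0] == "0|1"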
2 changes: 1 addition & 1 deletion sgkit/tests/test_vcfzarr_reader.py
@@ -140,7 +140,7 @@ def test_vcfzarr_to_zarr(
consolidated=consolidated,
)

ds = xr.open_zarr(output)
ds = xr.open_zarr(output, concat_characters=False)

# Note that variant_allele values are byte strings, not unicode strings (unlike for read_vcfzarr)
# We should make the two consistent.
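One note on the ``concat_characters=False`` argument added in this test: by default :func:`xarray.open_zarr` joins fixed-width single-character ("S1") arrays along their trailing dimension into longer strings, which would change the dtypes these tests assert on. A sketch with a hypothetical store path:

    import xarray as xr

    # Keep byte-string arrays exactly as stored; do not let Xarray
    # concatenate a trailing character dimension into one string column.
    ds = xr.open_zarr("vcfzarr.zarr", concat_characters=False)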
