diff --git a/sgkit/io/vcf/vcf_reader.py b/sgkit/io/vcf/vcf_reader.py
index 59b54ea16..0498e4ddb 100644
--- a/sgkit/io/vcf/vcf_reader.py
+++ b/sgkit/io/vcf/vcf_reader.py
@@ -22,13 +22,17 @@ import xarray as xr
 from cyvcf2 import VCF, Variant
 from numcodecs import PackBits
+from typing_extensions import Literal
 
 from sgkit import variables
 from sgkit.io.dataset import load_dataset
 from sgkit.io.utils import zarrs_to_dataset
 from sgkit.io.vcf import partition_into_regions
 from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename
-from sgkit.io.vcfzarr_reader import vcf_number_to_dimension_and_size
+from sgkit.io.vcfzarr_reader import (
+    concat_zarrs_optimized,
+    vcf_number_to_dimension_and_size,
+)
 from sgkit.model import (
     DIM_PLOIDY,
     DIM_SAMPLE,
@@ -529,6 +533,7 @@ def vcf_to_zarr_parallel(
     fields: Optional[Sequence[str]] = None,
     exclude_fields: Optional[Sequence[str]] = None,
     field_defs: Optional[Dict[str, Dict[str, Any]]] = None,
+    concat_algorithm: Optional[Literal["xarray_internal"]] = None,
 ) -> None:
     """Convert specified regions of one or more VCF files to zarr files, then concat, rechunk, write to zarr"""
 
@@ -557,11 +562,15 @@ def vcf_to_zarr_parallel(
         field_defs=field_defs,
     )
 
-    ds = zarrs_to_dataset(paths, chunk_length, chunk_width, tempdir_storage_options)
-
-    # Ensure Dask task graph is efficient, see https://github.com/dask/dask/issues/5105
-    with dask.config.set({"optimization.fuse.ave-width": dask_fuse_avg_width}):
-        ds.to_zarr(output, mode="w")
+    concat_zarrs(
+        paths,
+        output,
+        concat_algorithm=concat_algorithm,
+        chunk_length=chunk_length,
+        chunk_width=chunk_width,
+        storage_options=tempdir_storage_options,
+        dask_fuse_avg_width=dask_fuse_avg_width,
+    )
 
 
 def vcf_to_zarrs(
@@ -703,6 +712,64 @@ def vcf_to_zarrs(
     return parts
 
 
+def concat_zarrs(
+    urls: Sequence[str],
+    output: Union[PathType, MutableMapping[str, bytes]],
+    *,
+    concat_algorithm: Optional[Literal["xarray_internal"]] = None,
+    chunk_length: int = 10_000,
+    chunk_width: int = 1_000,
+    storage_options: Optional[Dict[str, str]] = None,
+    dask_fuse_avg_width: int = 50,
+) -> None:
+    """Concatenate multiple Zarr stores into a single Zarr store.
+
+    The Zarr stores are concatenated and rechunked to produce a single combined store.
+
+    Parameters
+    ----------
+    urls
+        A list of URLs to the Zarr stores to combine, typically the return value of
+        :func:`vcf_to_zarrs`.
+    output
+        Zarr store or path to directory in file system.
+    concat_algorithm
+        The algorithm to use to concatenate and rechunk Zarr files. The default None means
+        use the optimized version suitable for large files, whereas ``xarray_internal`` will
+        use built-in Xarray APIs, which can exhibit high memory usage, see https://github.com/dask/dask/issues/6745.
+    chunk_length
+        Length (number of variants) of chunks in which data are stored, by default 10,000.
+        This is only used when ``concat_algorithm`` is ``xarray_internal``.
+    chunk_width
+        Width (number of samples) to use when storing chunks in output, by default 1,000.
+        This is only used when ``concat_algorithm`` is ``xarray_internal``.
+    storage_options
+        Any additional parameters for the storage backend (see ``fsspec.open``).
+    dask_fuse_avg_width
+        Setting for Dask's ``optimization.fuse.ave-width``, see https://github.com/dask/dask/issues/5105
+    """
+    if concat_algorithm == "xarray_internal":
+        ds = zarrs_to_dataset(urls, chunk_length, chunk_width, storage_options)
+
+        with dask.config.set({"optimization.fuse.ave-width": dask_fuse_avg_width}):
+            ds.to_zarr(output, mode="w")
+    else:
+
+        vars_to_rechunk = []
+        vars_to_copy = []
+        storage_options = storage_options or {}
+        ds = xr.open_zarr(  # type: ignore[no-untyped-call]
+            fsspec.get_mapper(urls[0], **storage_options), concat_characters=False
+        )
+        for (var, arr) in ds.data_vars.items():
+            if arr.dims[0] == "variants":
+                vars_to_rechunk.append(var)
+            else:
+                vars_to_copy.append(var)
+
+        concat_zarrs_optimized(urls, output, vars_to_rechunk, vars_to_copy)
+
+
 def vcf_to_zarr(
     input: Union[PathType, Sequence[PathType]],
     output: Union[PathType, MutableMapping[str, bytes]],
@@ -723,6 +790,7 @@ def vcf_to_zarr(
     fields: Optional[Sequence[str]] = None,
     exclude_fields: Optional[Sequence[str]] = None,
     field_defs: Optional[Dict[str, Dict[str, Any]]] = None,
+    concat_algorithm: Optional[Literal["xarray_internal"]] = None,
 ) -> None:
     """Convert VCF files to a single Zarr on-disk store.
 
@@ -735,8 +803,7 @@ def vcf_to_zarr(
     is None.
 
     For more control over these two steps, consider using :func:`vcf_to_zarrs` followed by
-    :func:`zarrs_to_dataset`, then saving the dataset using Xarray's
-    :meth:`xarray.Dataset.to_zarr` method.
+    :func:`concat_zarrs`.
 
     Parameters
     ----------
@@ -811,6 +878,10 @@ def vcf_to_zarr(
         (which is defined as Number 2 in the VCF header) as ``haplotypes``.
         (Note that Number ``A`` is the number of alternate alleles, see section 1.4.2 of the VCF spec
         https://samtools.github.io/hts-specs/VCFv4.3.pdf.)
+    concat_algorithm
+        The algorithm to use to concatenate and rechunk Zarr files. The default None means
+        use the optimized version suitable for large files, whereas ``xarray_internal`` will
+        use built-in Xarray APIs, which can exhibit high memory usage, see https://github.com/dask/dask/issues/6745.
""" if temp_chunk_length is not None: @@ -842,6 +913,7 @@ def vcf_to_zarr( temp_chunk_length=temp_chunk_length, tempdir=tempdir, tempdir_storage_options=tempdir_storage_options, + concat_algorithm=concat_algorithm, ) convert_func( input, # type: ignore diff --git a/sgkit/io/vcfzarr_reader.py b/sgkit/io/vcfzarr_reader.py index 4ca5cf154..16a8a442d 100644 --- a/sgkit/io/vcfzarr_reader.py +++ b/sgkit/io/vcfzarr_reader.py @@ -1,9 +1,20 @@ import tempfile from pathlib import Path -from typing import Any, Dict, Hashable, List, Optional, Tuple +from typing import ( + Any, + Dict, + Hashable, + List, + MutableMapping, + Optional, + Sequence, + Tuple, + Union, +) import dask import dask.array as da +import numcodecs import xarray as xr import zarr from fsspec import get_mapper @@ -151,8 +162,8 @@ def vcfzarr_to_zarr( ds.to_zarr(output, mode="w") else: # Use the optimized algorithm in `concatenate_and_rechunk` - _concat_zarrs_optimized( - zarr_files, output, vars_to_rechunk, vars_to_copy + concat_zarrs_optimized( + zarr_files, output, vars_to_rechunk, vars_to_copy, fix_strings=True ) @@ -291,18 +302,22 @@ def _get_max_len(zarr_groups: List[zarr.Group], attr_name: str) -> int: return max_len -def _concat_zarrs_optimized( - zarr_files: List[str], - output: PathType, +def concat_zarrs_optimized( + zarr_files: Sequence[str], + output: Union[PathType, MutableMapping[str, bytes]], vars_to_rechunk: List[Hashable], vars_to_copy: List[Hashable], + fix_strings: bool = False, ) -> None: + if isinstance(output, Path): + output = str(output) + zarr_groups = [zarr.open_group(f) for f in zarr_files] first_zarr_group = zarr_groups[0] # create the top-level group - zarr.open_group(str(output), mode="w") + zarr.open_group(output, mode="w") # copy variables that are to be rechunked # NOTE: that this uses _to_zarr function defined here that is needed to avoid @@ -311,27 +326,32 @@ def _concat_zarrs_optimized( delayed = [] # do all the rechunking operations in one computation for var in vars_to_rechunk: dtype = None - if var in {"variant_id", "variant_allele"}: + if fix_strings and var in {"variant_id", "variant_allele"}: max_len = _get_max_len(zarr_groups, f"max_length_{var}") dtype = f"S{max_len}" - arr = concatenate_and_rechunk( [group[var] for group in zarr_groups], dtype=dtype ) + + _to_zarr_kwargs = dict(fill_value=None) + if not fix_strings and arr.dtype == "O": + # We assume that all object dtypes are variable length strings + _to_zarr_kwargs["object_codec"] = numcodecs.VLenUTF8() + d = _to_zarr( # type: ignore[no-untyped-call] arr, - str(output), + output, component=var, overwrite=True, compute=False, - fill_value=None, attrs=first_zarr_group[var].attrs.asdict(), + **_to_zarr_kwargs, ) delayed.append(d) da.compute(*delayed) # copy unchanged variables and top-level metadata - with zarr.open_group(str(output)) as output_zarr: + with zarr.open_group(output) as output_zarr: # copy variables that are not rechunked (e.g. 
sample_id) for var in vars_to_copy: @@ -339,10 +359,14 @@ def _concat_zarrs_optimized( output_zarr[var].attrs.update(first_zarr_group[var].attrs) # copy top-level attributes - output_zarr.attrs.update(first_zarr_group.attrs) + group_attrs = dict(first_zarr_group.attrs) + if "max_alt_alleles_seen" in group_attrs: + max_alt_alleles_seen = _get_max_len(zarr_groups, "max_alt_alleles_seen") + group_attrs["max_alt_alleles_seen"] = max_alt_alleles_seen + output_zarr.attrs.update(group_attrs) # consolidate metadata - zarr.consolidate_metadata(str(output)) + zarr.consolidate_metadata(output) def _to_zarr( # type: ignore[no-untyped-def] diff --git a/sgkit/tests/io/vcf/test_vcf_reader.py b/sgkit/tests/io/vcf/test_vcf_reader.py index 497483fee..bd80c96f7 100644 --- a/sgkit/tests/io/vcf/test_vcf_reader.py +++ b/sgkit/tests/io/vcf/test_vcf_reader.py @@ -233,13 +233,23 @@ def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path): "is_path", [True, False], ) +@pytest.mark.parametrize( + "concat_algorithm", + [None, "xarray_internal"], +) @pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path): +def test_vcf_to_zarr__parallel(shared_datadir, is_path, concat_algorithm, tmp_path): path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) output = tmp_path.joinpath("vcf_concat.zarr").as_posix() regions = ["20", "21"] - vcf_to_zarr(path, output, regions=regions, chunk_length=5_000) + vcf_to_zarr( + path, + output, + regions=regions, + chunk_length=5_000, + concat_algorithm=concat_algorithm, + ) ds = xr.open_zarr(output) assert ds["sample_id"].shape == (1,) @@ -252,8 +262,12 @@ def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path): assert ds["variant_id_mask"].shape == (19910,) assert ds["variant_position"].shape == (19910,) - assert ds["variant_allele"].dtype == "S48" - assert ds["variant_id"].dtype == "S1" + if concat_algorithm is None: + assert ds["variant_allele"].dtype == "O" + assert ds["variant_id"].dtype == "O" + else: + assert ds["variant_allele"].dtype == "S48" + assert ds["variant_id"].dtype == "S1" @pytest.mark.parametrize( @@ -305,8 +319,8 @@ def test_vcf_to_zarr__parallel_temp_chunk_length(shared_datadir, is_path, tmp_pa assert ds["variant_id_mask"].shape == (19910,) assert ds["variant_position"].shape == (19910,) - assert ds["variant_allele"].dtype == "S48" - assert ds["variant_id"].dtype == "S1" + assert ds["variant_allele"].dtype == "O" + assert ds["variant_id"].dtype == "O" def test_vcf_to_zarr__parallel_temp_chunk_length_not_divisible( @@ -531,7 +545,7 @@ def test_vcf_to_zarr__mixed_ploidy_vcf( ) ds = load_dataset(output) - variant_dtype = "|S1" if regions else "O" + variant_dtype = "O" assert ds.attrs["contigs"] == ["CHR1", "CHR2", "CHR3"] assert_array_equal(ds["variant_contig"], [0, 0]) assert_array_equal(ds["variant_position"], [2, 7]) @@ -728,7 +742,7 @@ def test_vcf_to_zarr__parallel_with_fields(shared_datadir, tmp_path): assert_allclose(ds["variant_MQ"], [58.33, np.nan, 57.45]) assert ds["variant_MQ"].attrs["comment"] == "RMS Mapping Quality" - assert_array_equal(ds["call_PGT"], [[b"0|1"], [b""], [b"0|1"]]) + assert_array_equal(ds["call_PGT"], [["0|1"], [""], ["0|1"]]) assert ( ds["call_PGT"].attrs["comment"] == "Physical phasing haplotype information, describing how the alternate alleles are phased in relation to one another" diff --git a/sgkit/tests/test_vcfzarr_reader.py b/sgkit/tests/test_vcfzarr_reader.py index c9f9f9329..03a6638a9 
100644 --- a/sgkit/tests/test_vcfzarr_reader.py +++ b/sgkit/tests/test_vcfzarr_reader.py @@ -140,7 +140,7 @@ def test_vcfzarr_to_zarr( consolidated=consolidated, ) - ds = xr.open_zarr(output) + ds = xr.open_zarr(output, concat_characters=False) # Note that variant_allele values are byte strings, not unicode strings (unlike for read_vcfzarr) # We should make the two consistent.
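
A minimal usage sketch of the two-step workflow referenced in the updated vcf_to_zarr docstring (vcf_to_zarrs followed by concat_zarrs). The input path, output paths, part count, and the import location of concat_zarrs are illustrative assumptions, not taken from this change:

    from sgkit.io.vcf import partition_into_regions, vcf_to_zarrs
    from sgkit.io.vcf.vcf_reader import concat_zarrs  # assumed import path for the new function

    vcf_path = "sample.vcf.gz"  # hypothetical bgzipped, indexed VCF
    regions = partition_into_regions(vcf_path, num_parts=4)

    # Step 1: write one intermediate Zarr store per region.
    part_urls = vcf_to_zarrs(vcf_path, "parts_dir", regions)

    # Step 2: concatenate and rechunk the parts into a single store.
    # The default concat_algorithm=None uses the optimized concat added here;
    # concat_algorithm="xarray_internal" selects the Xarray-based path instead.
    concat_zarrs(part_urls, "combined.zarr")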