diff --git a/sgkit/io/utils.py b/sgkit/io/utils.py index f11331c54..ae8162c67 100644 --- a/sgkit/io/utils.py +++ b/sgkit/io/utils.py @@ -104,6 +104,11 @@ def zarrs_to_dataset( ds[variable_name] = ds[variable_name].astype(f"S{max_length}") del ds.attrs[attr] + if "max_alt_alleles_seen" in datasets[0].attrs: + ds.attrs["max_alt_alleles_seen"] = max( + ds.attrs["max_alt_alleles_seen"] for ds in datasets + ) + return ds diff --git a/sgkit/io/vcf/__init__.py b/sgkit/io/vcf/__init__.py index 4d93e0fd2..a6eb71e7a 100644 --- a/sgkit/io/vcf/__init__.py +++ b/sgkit/io/vcf/__init__.py @@ -3,9 +3,10 @@ try: from ..utils import zarrs_to_dataset from .vcf_partition import partition_into_regions - from .vcf_reader import vcf_to_zarr, vcf_to_zarrs + from .vcf_reader import MaxAltAllelesExceededWarning, vcf_to_zarr, vcf_to_zarrs __all__ = [ + "MaxAltAllelesExceededWarning", "partition_into_regions", "vcf_to_zarr", "vcf_to_zarrs", diff --git a/sgkit/io/vcf/vcf_reader.py b/sgkit/io/vcf/vcf_reader.py index 183114f37..e2185a422 100644 --- a/sgkit/io/vcf/vcf_reader.py +++ b/sgkit/io/vcf/vcf_reader.py @@ -1,5 +1,6 @@ import functools import itertools +import warnings from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path @@ -21,6 +22,7 @@ import xarray as xr from cyvcf2 import VCF, Variant +from sgkit.io.dataset import load_dataset from sgkit.io.utils import zarrs_to_dataset from sgkit.io.vcf import partition_into_regions from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename @@ -34,6 +36,12 @@ ) +class MaxAltAllelesExceededWarning(UserWarning): + """Warning when the number of alt alleles exceeds the maximum specified.""" + + pass + + @contextmanager def open_vcf(path: PathType) -> Iterator[VCF]: """A context manager for opening a VCF file.""" @@ -256,6 +264,7 @@ def vcf_to_zarr_sequential( # Remember max lengths of variable-length strings max_variant_id_length = 0 max_variant_allele_length = 0 + max_alt_alleles_seen = 0 # Iterate through variants in batches of chunk_length @@ -303,6 +312,7 @@ def vcf_to_zarr_sequential( variant_position[i] = variant.POS alleles = [variant.REF] + variant.ALT + max_alt_alleles_seen = max(max_alt_alleles_seen, len(variant.ALT)) if len(alleles) > n_allele: alleles = alleles[:n_allele] elif len(alleles) < n_allele: @@ -359,6 +369,7 @@ def vcf_to_zarr_sequential( if add_str_max_length_attrs: ds.attrs["max_length_variant_id"] = max_variant_id_length ds.attrs["max_length_variant_allele"] = max_variant_allele_length + ds.attrs["max_alt_alleles_seen"] = max_alt_alleles_seen if first_variants_chunk: # Enforce uniform chunks in the variants dimension @@ -705,6 +716,15 @@ def vcf_to_zarr( field_defs=field_defs, ) + # Issue a warning if max_alt_alleles caused data to be dropped + ds = load_dataset(output) + max_alt_alleles_seen = ds.attrs["max_alt_alleles_seen"] + if max_alt_alleles_seen > max_alt_alleles: + warnings.warn( + f"Some alternate alleles were dropped, since actual max value {max_alt_alleles_seen} exceeded max_alt_alleles setting of {max_alt_alleles}.", + MaxAltAllelesExceededWarning, + ) + def count_variants(path: PathType, region: Optional[str] = None) -> int: """Count the number of variants in a VCF file.""" diff --git a/sgkit/tests/io/vcf/test_vcf_reader.py b/sgkit/tests/io/vcf/test_vcf_reader.py index e9187ba90..a78c33628 100644 --- a/sgkit/tests/io/vcf/test_vcf_reader.py +++ b/sgkit/tests/io/vcf/test_vcf_reader.py @@ -6,7 +6,11 @@ from numpy.testing import assert_allclose, assert_array_equal from sgkit import load_dataset -from sgkit.io.vcf import partition_into_regions, vcf_to_zarr +from sgkit.io.vcf import ( + MaxAltAllelesExceededWarning, + partition_into_regions, + vcf_to_zarr, +) from .utils import path_for_test @@ -96,30 +100,35 @@ def test_vcf_to_zarr__max_alt_alleles(shared_datadir, is_path, tmp_path): path = path_for_test(shared_datadir, "sample.vcf.gz", is_path) output = tmp_path.joinpath("vcf.zarr").as_posix() - vcf_to_zarr(path, output, chunk_length=5, chunk_width=2, max_alt_alleles=1) - ds = xr.open_zarr(output) # type: ignore[no-untyped-call] + with pytest.warns(MaxAltAllelesExceededWarning): + vcf_to_zarr(path, output, chunk_length=5, chunk_width=2, max_alt_alleles=1) + ds = xr.open_zarr(output) # type: ignore[no-untyped-call] - # extra alt alleles are silently dropped - assert_array_equal( - ds["variant_allele"], - [ - ["A", "C"], - ["A", "G"], - ["G", "A"], - ["T", "A"], - ["A", "G"], - ["T", ""], - ["G", "GA"], - ["T", ""], - ["AC", "A"], - ], - ) + # extra alt alleles are dropped + assert_array_equal( + ds["variant_allele"], + [ + ["A", "C"], + ["A", "G"], + ["G", "A"], + ["T", "A"], + ["A", "G"], + ["T", ""], + ["G", "GA"], + ["T", ""], + ["AC", "A"], + ], + ) + + # the maximum number of alt alleles actually seen is stored as an attribute + assert ds.attrs["max_alt_alleles_seen"] == 3 @pytest.mark.parametrize( "is_path", [True, False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__large_vcf(shared_datadir, is_path, tmp_path): path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) output = tmp_path.joinpath("vcf.zarr").as_posix() @@ -157,6 +166,7 @@ def test_vcf_to_zarr__plain_vcf_with_no_index(shared_datadir, tmp_path): "is_path", [True, False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__mutable_mapping(shared_datadir, is_path): path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) output: MutableMapping[str, bytes] = {} @@ -182,6 +192,7 @@ def test_vcf_to_zarr__mutable_mapping(shared_datadir, is_path): "is_path", [True, False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path): path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) output = tmp_path.joinpath("vcf_concat.zarr").as_posix() @@ -208,6 +219,7 @@ def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path): "is_path", [False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__parallel_temp_chunk_length(shared_datadir, is_path, tmp_path): path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) output = tmp_path.joinpath("vcf_concat.zarr").as_posix() @@ -296,6 +308,7 @@ def test_vcf_to_zarr__parallel_partitioned_by_size(shared_datadir, is_path, tmp_ "is_path", [True, False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__multiple(shared_datadir, is_path, tmp_path): paths = [ path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), @@ -323,6 +336,7 @@ def test_vcf_to_zarr__multiple(shared_datadir, is_path, tmp_path): "is_path", [True, False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__multiple_partitioned(shared_datadir, is_path, tmp_path): paths = [ path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), @@ -352,6 +366,7 @@ def test_vcf_to_zarr__multiple_partitioned(shared_datadir, is_path, tmp_path): "is_path", [True, False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__multiple_partitioned_by_size(shared_datadir, is_path, tmp_path): paths = [ path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), @@ -398,6 +413,31 @@ def test_vcf_to_zarr__mutiple_partitioned_invalid_regions( vcf_to_zarr(paths, output, regions=regions, chunk_length=5_000) +@pytest.mark.parametrize( + "is_path", + [True, False], +) +def test_vcf_to_zarr__multiple_max_alt_alleles(shared_datadir, is_path, tmp_path): + paths = [ + path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), + path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path), + ] + output = tmp_path.joinpath("vcf_concat.zarr").as_posix() + + with pytest.warns(MaxAltAllelesExceededWarning): + vcf_to_zarr( + paths, + output, + target_part_size="40KB", + chunk_length=5_000, + max_alt_alleles=1, + ) + ds = xr.open_zarr(output) # type: ignore[no-untyped-call] + + # the maximum number of alt alleles actually seen is stored as an attribute + assert ds.attrs["max_alt_alleles_seen"] == 7 + + @pytest.mark.parametrize( "ploidy,mixed_ploidy,truncate_calls,regions", [ @@ -560,6 +600,7 @@ def test_vcf_to_zarr__fields(shared_datadir, tmp_path): assert ds["call_DP"].attrs["comment"] == "Read Depth" +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__parallel_with_fields(shared_datadir, tmp_path): path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz") output = tmp_path.joinpath("vcf.zarr").as_posix() @@ -616,6 +657,7 @@ def test_vcf_to_zarr__field_defs(shared_datadir, tmp_path): assert "comment" not in ds["variant_DP"].attrs +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__field_number_A(shared_datadir, tmp_path): path = path_for_test(shared_datadir, "sample.vcf.gz") output = tmp_path.joinpath("vcf.zarr").as_posix() @@ -649,6 +691,7 @@ def test_vcf_to_zarr__field_number_A(shared_datadir, tmp_path): ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__field_number_R(shared_datadir, tmp_path): path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz") output = tmp_path.joinpath("vcf.zarr").as_posix() diff --git a/sgkit/tests/io/vcf/test_vcf_roundtrip.py b/sgkit/tests/io/vcf/test_vcf_roundtrip.py index a81a893ab..47494235b 100644 --- a/sgkit/tests/io/vcf/test_vcf_roundtrip.py +++ b/sgkit/tests/io/vcf/test_vcf_roundtrip.py @@ -79,6 +79,7 @@ def test_default_fields(shared_datadir, tmpdir): sg_vcfzarr_path = create_sg_vcfzarr(shared_datadir, tmpdir) sg_ds = sg.load_dataset(str(sg_vcfzarr_path)) sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel + del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel assert_identical(allel_ds, sg_ds) @@ -107,6 +108,7 @@ def test_DP_field(shared_datadir, tmpdir): ) sg_ds = sg.load_dataset(str(sg_vcfzarr_path)) sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel + del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel assert_identical(allel_ds, sg_ds) @@ -120,6 +122,7 @@ def test_DP_field(shared_datadir, tmpdir): ("CEUTrio.20.21.gatk3.4.g.vcf.bgz", ["calldata/PL"], ["FORMAT/PL"]), ], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_all_fields( shared_datadir, tmpdir, vcf_file, allel_exclude_fields, sgkit_exclude_fields ): @@ -159,6 +162,7 @@ def test_all_fields( ) sg_ds = sg.load_dataset(str(sg_vcfzarr_path)) sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel + del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel # scikit-allel only records contigs for which there are actual variants, # whereas sgkit records contigs from the header