Skip to content

Commit

Permalink
Issue a warning if the number of alt alleles exceeds the maximum spec…
Browse files Browse the repository at this point in the history
…ified
  • Loading branch information
tomwhite committed Jul 6, 2021
1 parent a7cb326 commit 439ed89
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 19 deletions.
5 changes: 5 additions & 0 deletions sgkit/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ def zarrs_to_dataset(
ds[variable_name] = ds[variable_name].astype(f"S{max_length}")
del ds.attrs[attr]

if "max_alt_alleles_seen" in datasets[0].attrs:
ds.attrs["max_alt_alleles_seen"] = max(
ds.attrs["max_alt_alleles_seen"] for ds in datasets
)

return ds


Expand Down
3 changes: 2 additions & 1 deletion sgkit/io/vcf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
try:
from ..utils import zarrs_to_dataset
from .vcf_partition import partition_into_regions
from .vcf_reader import vcf_to_zarr, vcf_to_zarrs
from .vcf_reader import MaxAltAllelesExceededWarning, vcf_to_zarr, vcf_to_zarrs

__all__ = [
"MaxAltAllelesExceededWarning",
"partition_into_regions",
"vcf_to_zarr",
"vcf_to_zarrs",
Expand Down
20 changes: 20 additions & 0 deletions sgkit/io/vcf/vcf_reader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import itertools
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
Expand All @@ -21,6 +22,7 @@
import xarray as xr
from cyvcf2 import VCF, Variant

from sgkit.io.dataset import load_dataset
from sgkit.io.utils import zarrs_to_dataset
from sgkit.io.vcf import partition_into_regions
from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename
Expand All @@ -34,6 +36,12 @@
)


class MaxAltAllelesExceededWarning(UserWarning):
"""Warning when the number of alt alleles exceeds the maximum specified."""

pass


@contextmanager
def open_vcf(path: PathType) -> Iterator[VCF]:
"""A context manager for opening a VCF file."""
Expand Down Expand Up @@ -256,6 +264,7 @@ def vcf_to_zarr_sequential(
# Remember max lengths of variable-length strings
max_variant_id_length = 0
max_variant_allele_length = 0
max_alt_alleles_seen = 0

# Iterate through variants in batches of chunk_length

Expand Down Expand Up @@ -303,6 +312,7 @@ def vcf_to_zarr_sequential(
variant_position[i] = variant.POS

alleles = [variant.REF] + variant.ALT
max_alt_alleles_seen = max(max_alt_alleles_seen, len(variant.ALT))
if len(alleles) > n_allele:
alleles = alleles[:n_allele]
elif len(alleles) < n_allele:
Expand Down Expand Up @@ -359,6 +369,7 @@ def vcf_to_zarr_sequential(
if add_str_max_length_attrs:
ds.attrs["max_length_variant_id"] = max_variant_id_length
ds.attrs["max_length_variant_allele"] = max_variant_allele_length
ds.attrs["max_alt_alleles_seen"] = max_alt_alleles_seen

if first_variants_chunk:
# Enforce uniform chunks in the variants dimension
Expand Down Expand Up @@ -705,6 +716,15 @@ def vcf_to_zarr(
field_defs=field_defs,
)

# Issue a warning if max_alt_alleles caused data to be dropped
ds = load_dataset(output)
max_alt_alleles_seen = ds.attrs["max_alt_alleles_seen"]
if max_alt_alleles_seen > max_alt_alleles:
warnings.warn(
f"Some alternate alleles were dropped, since actual max value {max_alt_alleles_seen} exceeded max_alt_alleles setting of {max_alt_alleles}.",
MaxAltAllelesExceededWarning,
)


def count_variants(path: PathType, region: Optional[str] = None) -> int:
"""Count the number of variants in a VCF file."""
Expand Down
79 changes: 61 additions & 18 deletions sgkit/tests/io/vcf/test_vcf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from numpy.testing import assert_allclose, assert_array_equal

from sgkit import load_dataset
from sgkit.io.vcf import partition_into_regions, vcf_to_zarr
from sgkit.io.vcf import (
MaxAltAllelesExceededWarning,
partition_into_regions,
vcf_to_zarr,
)

from .utils import path_for_test

Expand Down Expand Up @@ -96,30 +100,35 @@ def test_vcf_to_zarr__max_alt_alleles(shared_datadir, is_path, tmp_path):
path = path_for_test(shared_datadir, "sample.vcf.gz", is_path)
output = tmp_path.joinpath("vcf.zarr").as_posix()

vcf_to_zarr(path, output, chunk_length=5, chunk_width=2, max_alt_alleles=1)
ds = xr.open_zarr(output) # type: ignore[no-untyped-call]
with pytest.warns(MaxAltAllelesExceededWarning):
vcf_to_zarr(path, output, chunk_length=5, chunk_width=2, max_alt_alleles=1)
ds = xr.open_zarr(output) # type: ignore[no-untyped-call]

# extra alt alleles are silently dropped
assert_array_equal(
ds["variant_allele"],
[
["A", "C"],
["A", "G"],
["G", "A"],
["T", "A"],
["A", "G"],
["T", ""],
["G", "GA"],
["T", ""],
["AC", "A"],
],
)
# extra alt alleles are dropped
assert_array_equal(
ds["variant_allele"],
[
["A", "C"],
["A", "G"],
["G", "A"],
["T", "A"],
["A", "G"],
["T", ""],
["G", "GA"],
["T", ""],
["AC", "A"],
],
)

# the maximum number of alt alleles actually seen is stored as an attribute
assert ds.attrs["max_alt_alleles_seen"] == 3


@pytest.mark.parametrize(
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__large_vcf(shared_datadir, is_path, tmp_path):
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
output = tmp_path.joinpath("vcf.zarr").as_posix()
Expand Down Expand Up @@ -157,6 +166,7 @@ def test_vcf_to_zarr__plain_vcf_with_no_index(shared_datadir, tmp_path):
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__mutable_mapping(shared_datadir, is_path):
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
output: MutableMapping[str, bytes] = {}
Expand All @@ -182,6 +192,7 @@ def test_vcf_to_zarr__mutable_mapping(shared_datadir, is_path):
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path):
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
Expand All @@ -208,6 +219,7 @@ def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path):
"is_path",
[False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__parallel_temp_chunk_length(shared_datadir, is_path, tmp_path):
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
Expand Down Expand Up @@ -296,6 +308,7 @@ def test_vcf_to_zarr__parallel_partitioned_by_size(shared_datadir, is_path, tmp_
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__multiple(shared_datadir, is_path, tmp_path):
paths = [
path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
Expand Down Expand Up @@ -323,6 +336,7 @@ def test_vcf_to_zarr__multiple(shared_datadir, is_path, tmp_path):
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__multiple_partitioned(shared_datadir, is_path, tmp_path):
paths = [
path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
Expand Down Expand Up @@ -352,6 +366,7 @@ def test_vcf_to_zarr__multiple_partitioned(shared_datadir, is_path, tmp_path):
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__multiple_partitioned_by_size(shared_datadir, is_path, tmp_path):
paths = [
path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
Expand Down Expand Up @@ -398,6 +413,31 @@ def test_vcf_to_zarr__mutiple_partitioned_invalid_regions(
vcf_to_zarr(paths, output, regions=regions, chunk_length=5_000)


@pytest.mark.parametrize(
"is_path",
[True, False],
)
def test_vcf_to_zarr__multiple_max_alt_alleles(shared_datadir, is_path, tmp_path):
paths = [
path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path),
]
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()

with pytest.warns(MaxAltAllelesExceededWarning):
vcf_to_zarr(
paths,
output,
target_part_size="40KB",
chunk_length=5_000,
max_alt_alleles=1,
)
ds = xr.open_zarr(output) # type: ignore[no-untyped-call]

# the maximum number of alt alleles actually seen is stored as an attribute
assert ds.attrs["max_alt_alleles_seen"] == 7


@pytest.mark.parametrize(
"ploidy,mixed_ploidy,truncate_calls,regions",
[
Expand Down Expand Up @@ -560,6 +600,7 @@ def test_vcf_to_zarr__fields(shared_datadir, tmp_path):
assert ds["call_DP"].attrs["comment"] == "Read Depth"


@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__parallel_with_fields(shared_datadir, tmp_path):
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz")
output = tmp_path.joinpath("vcf.zarr").as_posix()
Expand Down Expand Up @@ -616,6 +657,7 @@ def test_vcf_to_zarr__field_defs(shared_datadir, tmp_path):
assert "comment" not in ds["variant_DP"].attrs


@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__field_number_A(shared_datadir, tmp_path):
path = path_for_test(shared_datadir, "sample.vcf.gz")
output = tmp_path.joinpath("vcf.zarr").as_posix()
Expand Down Expand Up @@ -649,6 +691,7 @@ def test_vcf_to_zarr__field_number_A(shared_datadir, tmp_path):
)


@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__field_number_R(shared_datadir, tmp_path):
path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz")
output = tmp_path.joinpath("vcf.zarr").as_posix()
Expand Down
4 changes: 4 additions & 0 deletions sgkit/tests/io/vcf/test_vcf_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def test_default_fields(shared_datadir, tmpdir):
sg_vcfzarr_path = create_sg_vcfzarr(shared_datadir, tmpdir)
sg_ds = sg.load_dataset(str(sg_vcfzarr_path))
sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel
del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel

assert_identical(allel_ds, sg_ds)

Expand Down Expand Up @@ -107,6 +108,7 @@ def test_DP_field(shared_datadir, tmpdir):
)
sg_ds = sg.load_dataset(str(sg_vcfzarr_path))
sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel
del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel

assert_identical(allel_ds, sg_ds)

Expand All @@ -120,6 +122,7 @@ def test_DP_field(shared_datadir, tmpdir):
("CEUTrio.20.21.gatk3.4.g.vcf.bgz", ["calldata/PL"], ["FORMAT/PL"]),
],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_all_fields(
shared_datadir, tmpdir, vcf_file, allel_exclude_fields, sgkit_exclude_fields
):
Expand Down Expand Up @@ -159,6 +162,7 @@ def test_all_fields(
)
sg_ds = sg.load_dataset(str(sg_vcfzarr_path))
sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel
del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel

# scikit-allel only records contigs for which there are actual variants,
# whereas sgkit records contigs from the header
Expand Down

0 comments on commit 439ed89

Please sign in to comment.