diff --git a/kipoiseq/extractors/__init__.py b/kipoiseq/extractors/__init__.py index 46e52d7..6999804 100644 --- a/kipoiseq/extractors/__init__.py +++ b/kipoiseq/extractors/__init__.py @@ -6,3 +6,4 @@ from .vcf_matching import * from .multi_interval import * from .protein import * +from .variant_combinations import * diff --git a/kipoiseq/extractors/variant_combinations.py b/kipoiseq/extractors/variant_combinations.py new file mode 100644 index 0000000..904f58e --- /dev/null +++ b/kipoiseq/extractors/variant_combinations.py @@ -0,0 +1,98 @@ +from typing import Iterable +from itertools import product +from kipoiseq import Interval, Variant +from kipoiseq.utils import alphabets +from kipoiseq.extractors import FastaStringExtractor +from kipoiseq.extractors.vcf_matching import pyranges_to_intervals + + +class VariantCombinator: + + def __init__(self, fasta_file: str, bed_file: str = None, + variant_type='snv', alphabet='DNA'): + if variant_type not in {'all', 'snv', 'in', 'del'}: + raise ValueError("variant_type should be one of " + "{'all', 'snv', 'in', 'del'}") + + self.bed_file = bed_file + self.fasta = fasta_file + self.fasta = FastaStringExtractor(fasta_file, force_upper=True) + self.variant_type = variant_type + self.alphabet = alphabets[alphabet] + + def combination_variants_snv(self, interval: Interval) -> Iterable[Variant]: + """Returns all the possible variants in the regions. + + interval: interval of variants + """ + seq = self.fasta.extract(interval) + for pos, ref in zip(range(interval.start, interval.end), seq): + pos = pos + 1 # 0 to 1 base + for alt in self.alphabet: + if ref != alt: + yield Variant(interval.chrom, pos, ref, alt) + + def combination_variants_insertion(self, interval, length=2) -> Iterable[Variant]: + """Returns all the possible variants in the regions. + + interval: interval of variants + length: insertions up to length + """ + if length < 2: + raise ValueError('length argument should be larger than 1') + + seq = self.fasta.extract(interval) + for pos, ref in zip(range(interval.start, interval.end), seq): + pos = pos + 1 # 0 to 1 base + for l in range(2, length + 1): + for alt in product(self.alphabet, repeat=l): + yield Variant(interval.chrom, pos, ref, ''.join(alt)) + + def combination_variants_deletion(self, interval, length=1) -> Iterable[Variant]: + """Returns all the possible variants in the regions. + interval: interval of variants + length: deletions up to length + """ + if length < 1 and length <= interval.width: + raise ValueError('length argument should be larger than 0' + ' and smaller than interval witdh') + + seq = self.fasta.extract(interval) + for i, pos in enumerate(range(interval.start, interval.end)): + pos = pos + 1 # 0 to 1 base + for j in range(1, length + 1): + if i + j <= len(seq): + yield Variant(interval.chrom, pos, seq[i:i + j], '') + + def combination_variants(self, interval, variant_type='snv', + in_length=2, del_length=2) -> Iterable[Variant]: + if variant_type in {'snv', 'all'}: + yield from self.combination_variants_snv(interval) + if variant_type in {'indel', 'in', 'all'}: + yield from self.combination_variants_insertion( + interval, length=in_length) + if variant_type in {'indel', 'del', 'all'}: + yield from self.combination_variants_deletion( + interval, length=del_length) + + def __iter__(self) -> Iterable[Variant]: + import pyranges as pr + + gr = pr.read_bed(self.bed_file) + gr = gr.merge(strand=False).sort() + + for interval in pyranges_to_intervals(gr): + yield from self.combination_variants(interval, self.variant_type) + + def to_vcf(self, path): + from cyvcf2 import Writer + header = '''##fileformat=VCFv4.2 +#CHROM POS ID REF ALT QUAL FILTER INFO +''' + writer = Writer.from_string(path, header) + + for v in self: + variant = writer.variant_from_string('\t'.join([ + v.chrom, str(v.pos), '.', v.ref, v.alt, '.', '.', '.' + ])) + writer.write_record(variant) diff --git a/kipoiseq/extractors/vcf_query.py b/kipoiseq/extractors/vcf_query.py index 0fdcdcb..dc7f613 100644 --- a/kipoiseq/extractors/vcf_query.py +++ b/kipoiseq/extractors/vcf_query.py @@ -1,3 +1,4 @@ +import csv import abc from itertools import islice from typing import Tuple, Iterable, List @@ -241,3 +242,40 @@ def to_vcf(self, path, remove_samples=False, clean_info=False): variant = writer.variant_from_string('\t'.join(variant)) writer.write_record(variant) + + def to_sample_csv(self, path, format_fields=None): + """ + Extract samples and FORMAT from vcf then save as csv file. + """ + format_fields = format_fields or list() + writer = None + + with open(path, 'w') as f: + + for variant in self: + variant_fields = str(variant.source).strip().split('\t') + + if writer is None: + # FORMAT field + fieldnames = ['variant', 'sample', + 'genotype'] + format_fields + format_fields = set(format_fields) + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + samples = self.vcf.samples + + values = dict(zip(samples, map( + lambda x: x.split(':'), variant_fields[9:]))) + fields = variant_fields[8].split(':') + + for sample, gt in self.vcf.get_samples(variant).items(): + row = dict() + row['variant'] = str(variant) + row['sample'] = sample + row['genotype'] = gt + + for k, v in zip(fields, values[sample]): + if k in format_fields: + row[k] = v + + writer.writerow(row) diff --git a/kipoiseq/transforms/functional.py b/kipoiseq/transforms/functional.py index df8d9d5..e40b4ee 100644 --- a/kipoiseq/transforms/functional.py +++ b/kipoiseq/transforms/functional.py @@ -85,7 +85,7 @@ def tokenize(seq, alphabet=DNA, neutral_alphabet=["N"]): neutral_alphabet = [neutral_alphabet] nchar = len(alphabet[0]) - for l in alphabet + neutral_alphabet: + for l in (*alphabet, *neutral_alphabet): assert len(l) == nchar assert len(seq) % nchar == 0 # since we are using striding diff --git a/kipoiseq/utils.py b/kipoiseq/utils.py index 717a114..84463e8 100644 --- a/kipoiseq/utils.py +++ b/kipoiseq/utils.py @@ -4,13 +4,11 @@ import numpy as np from six import string_types -# alphabets: -from kipoiseq import Variant -DNA = ["A", "C", "G", "T"] -RNA = ["A", "C", "G", "U"] -AMINO_ACIDS = ["A", "R", "N", "D", "B", "C", "E", "Q", "Z", "G", "H", - "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"] +DNA = ("A", "C", "G", "T") +RNA = ("A", "C", "G", "U") +AMINO_ACIDS = ("A", "R", "N", "D", "B", "C", "E", "Q", "Z", "G", "H", + "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V") alphabets = {"DNA": DNA, "RNA": RNA, @@ -38,7 +36,8 @@ def parse_dtype(dtype): try: return eval(dtype) except Exception as e: - raise ValueError("Unable to parse dtype: {}. \nException: {}".format(dtype, e)) + raise ValueError( + "Unable to parse dtype: {}. \nException: {}".format(dtype, e)) else: return dtype diff --git a/tests/dataloaders/test_sequence.py b/tests/dataloaders/test_sequence.py index 8d9882c..5588577 100644 --- a/tests/dataloaders/test_sequence.py +++ b/tests/dataloaders/test_sequence.py @@ -15,6 +15,7 @@ def fasta_file(): def intervals_file(): return "tests/data/sample_intervals.bed" + @pytest.fixture def intervals_file_strand(): return "tests/data/sample_interval_strand.bed" @@ -25,8 +26,9 @@ def intervals_file_strand(): def test_min_props(): # minimal set of properties that need to be specified on the object - min_set_props = ["output_schema", "type", "defined_as", "info", "args", "dependencies", "postprocessing", - "source", "source_dir"] + min_set_props = ["output_schema", "type", "defined_as", "info", "args", + "dependencies", "source", "source_dir"] + # TODO: "postprocessing" is this part of min_set_props? for Dl in [StringSeqIntervalDl, SeqIntervalDl]: props = dir(Dl) @@ -56,11 +58,14 @@ def test_fasta_based_dataset(intervals_file, fasta_file): vals = dl.load_all() assert vals['inputs'][0] == 'GT' + def test_use_strand(intervals_file_strand, fasta_file): - dl = StringSeqIntervalDl(intervals_file_strand, fasta_file, use_strand=True) + dl = StringSeqIntervalDl(intervals_file_strand, + fasta_file, use_strand=True) vals = dl.load_all() assert vals['inputs'][0] == 'AC' + def test_seq_dataset(intervals_file, fasta_file): dl = SeqIntervalDl(intervals_file, fasta_file) ret_val = dl[0] diff --git a/tests/extractors/test_variant_combinations.py b/tests/extractors/test_variant_combinations.py new file mode 100644 index 0000000..f94be0f --- /dev/null +++ b/tests/extractors/test_variant_combinations.py @@ -0,0 +1,74 @@ +import pytest +from conftest import example_intervals_bed, sample_5kb_fasta_file +import pyranges as pr +from kipoiseq import Interval +from kipoiseq.extractors import VariantCombinator, MultiSampleVCF + + +@pytest.fixture +def variant_combinator(): + return VariantCombinator(sample_5kb_fasta_file, example_intervals_bed) + + +def test_VariantCombinator_combination_variants(variant_combinator): + interval = Interval('chr1', 20, 30) + variants = list(variant_combinator.combination_variants(interval, 'snv')) + assert len(variants) == 30 + + interval = Interval('chr1', 20, 22) + variants = list(variant_combinator.combination_variants(interval, 'snv')) + assert variants[0].chrom == 'chr1' + assert variants[0].ref == 'A' + assert variants[0].alt == 'C' + assert variants[1].alt == 'G' + assert variants[2].alt == 'T' + + assert variants[3].ref == 'C' + assert variants[3].alt == 'A' + assert variants[4].alt == 'G' + assert variants[5].alt == 'T' + + interval = Interval('chr1', 20, 22) + variants = list(variant_combinator.combination_variants(interval, 'in')) + len(variants) == 32 + assert variants[0].ref == 'A' + assert variants[0].alt == 'AA' + assert variants[15].alt == 'TT' + + assert variants[16].ref == 'C' + assert variants[16].alt == 'AA' + assert variants[31].alt == 'TT' + + interval = Interval('chr1', 20, 22) + variants = list(variant_combinator.combination_variants( + interval, 'del', del_length=2)) + assert len(variants) == 3 + assert variants[0].ref == 'A' + assert variants[0].alt == '' + assert variants[1].ref == 'AC' + assert variants[1].alt == '' + assert variants[2].ref == 'C' + assert variants[2].alt == '' + + variants = list(variant_combinator.combination_variants( + interval, 'all', in_length=2, del_length=2)) + assert len(variants) == 6 + 32 + 3 + + +def test_VariantCombinator_iter(variant_combinator): + variants = list(variant_combinator) + df = pr.read_bed(example_intervals_bed).merge(strand=False).df + num_snv = (df['End'] - df['Start']).sum() * 3 + assert len(variants) == num_snv + assert len(variants) == len(set(variants)) + + +def test_VariantCombinator_to_vcf(tmpdir, variant_combinator): + output_vcf_file = str(tmpdir / 'output.vcf') + variant_combinator.to_vcf(output_vcf_file) + + vcf = MultiSampleVCF(output_vcf_file) + + df = pr.read_bed(example_intervals_bed).merge(strand=False).df + num_snv = (df['End'] - df['Start']).sum() * 3 + assert len(list(vcf)) == num_snv diff --git a/tests/extractors/test_vcf_query.py b/tests/extractors/test_vcf_query.py index 8a353dc..b068622 100644 --- a/tests/extractors/test_vcf_query.py +++ b/tests/extractors/test_vcf_query.py @@ -1,5 +1,6 @@ import pytest from conftest import vcf_file +import pandas as pd from kipoiseq.dataclasses import Variant, Interval from kipoiseq.extractors.vcf_seq import MultiSampleVCF from kipoiseq.extractors.vcf_query import * @@ -137,3 +138,39 @@ def test_VariantQueryable_to_vcf(tmp_path): vcf = MultiSampleVCF(path) assert len(vcf.samples) == 0 + + +def test_VariantQueryable_to_sample_csv(tmp_path): + vcf = MultiSampleVCF(vcf_file) + + variant_queryable = vcf.query_all() + + path = str(tmp_path / 'sample.csv') + variant_queryable.to_sample_csv(path) + + df = pd.read_csv(path) + df_expected = pd.DataFrame({ + 'variant': ['chr1:4:T>C', 'chr1:25:AACG>GA'], + 'sample': ['NA00003', 'NA00002'], + 'genotype': [3, 3] + }) + pd.testing.assert_frame_equal(df, df_expected) + + +def test_VariantQueryable_to_sample_csv_fields(tmp_path): + vcf = MultiSampleVCF(vcf_file) + + variant_queryable = vcf.query_all() + + path = str(tmp_path / 'sample.csv') + variant_queryable.to_sample_csv(path, ['GT', 'HQ']) + + df = pd.read_csv(path) + df_expected = pd.DataFrame({ + 'variant': ['chr1:4:T>C', 'chr1:25:AACG>GA'], + 'sample': ['NA00003', 'NA00002'], + 'genotype': [3, 3], + 'GT': ['1/1', '1/1'], + 'HQ': ['51,51', '10,10'] + }) + pd.testing.assert_frame_equal(df, df_expected)