Skip to content

Commit

Permalink
variant combinator (#94)
Browse files Browse the repository at this point in the history
* variant combinator

* variant combinator vcf

* bug fix upper case

* sample from vcf

* format fields bug fix

* sort_intervals

* update on testcase

Co-authored-by: M. Hasan Celik <[email protected]>
  • Loading branch information
MuhammedHasan and M. Hasan Celik authored Feb 10, 2022
1 parent 8893483 commit e234baf
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 11 deletions.
1 change: 1 addition & 0 deletions kipoiseq/extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .vcf_matching import *
from .multi_interval import *
from .protein import *
from .variant_combinations import *
98 changes: 98 additions & 0 deletions kipoiseq/extractors/variant_combinations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from typing import Iterable
from itertools import product
from kipoiseq import Interval, Variant
from kipoiseq.utils import alphabets
from kipoiseq.extractors import FastaStringExtractor
from kipoiseq.extractors.vcf_matching import pyranges_to_intervals


class VariantCombinator:
    """Enumerate candidate variants (SNVs, insertions, deletions) in intervals.

    The reference sequence is read through ``FastaStringExtractor`` with
    ``force_upper=True``. Intervals are either passed explicitly to the
    ``combination_variants*`` methods or read from ``bed_file`` when the
    object is iterated.
    """

    def __init__(self, fasta_file: str, bed_file: str = None,
                 variant_type='snv', alphabet='DNA'):
        """
        fasta_file: path of the reference fasta file
        bed_file: optional bed file with the regions used by ``__iter__``
          and ``to_vcf``
        variant_type: one of 'all', 'snv', 'indel', 'in', 'del'
        alphabet: key into ``kipoiseq.utils.alphabets``

        Raises ValueError if ``variant_type`` is not supported.
        """
        # 'indel' is accepted because `combination_variants` handles it.
        if variant_type not in {'all', 'snv', 'indel', 'in', 'del'}:
            raise ValueError("variant_type should be one of "
                             "{'all', 'snv', 'indel', 'in', 'del'}")

        self.bed_file = bed_file
        self.fasta = FastaStringExtractor(fasta_file, force_upper=True)
        self.variant_type = variant_type
        self.alphabet = alphabets[alphabet]

    def combination_variants_snv(self, interval: Interval) -> Iterable[Variant]:
        """Yields every single nucleotide variant in the interval.

        interval: interval of variants
        """
        seq = self.fasta.extract(interval)
        for pos, ref in zip(range(interval.start, interval.end), seq):
            pos = pos + 1  # 0 to 1 base
            for alt in self.alphabet:
                if ref != alt:
                    yield Variant(interval.chrom, pos, ref, alt)

    def combination_variants_insertion(self, interval, length=2) -> Iterable[Variant]:
        """Yields, per position, every alphabet string of length 2..length.

        The yielded alt is the raw combination of alphabet letters; it is
        not prefixed with the reference base.

        interval: interval of variants
        length: insertions up to length

        Raises ValueError (on first iteration) if ``length`` < 2.
        """
        if length < 2:
            raise ValueError('length argument should be larger than 1')

        seq = self.fasta.extract(interval)
        for pos, ref in zip(range(interval.start, interval.end), seq):
            pos = pos + 1  # 0 to 1 base
            for alt_length in range(2, length + 1):
                for alt in product(self.alphabet, repeat=alt_length):
                    yield Variant(interval.chrom, pos, ref, ''.join(alt))

    def combination_variants_deletion(self, interval, length=1) -> Iterable[Variant]:
        """Yields every deletion of 1..length bases that fits the interval.

        Deletions which would run past the end of the interval are skipped
        instead of raising, so ``length`` may exceed the interval width.

        interval: interval of variants
        length: deletions up to length

        Raises ValueError (on first iteration) if ``length`` < 1.
        """
        # The check runs before any sequence access so that a bad argument
        # is always reported as ValueError.
        if length < 1:
            raise ValueError('length argument should be larger than 0')

        seq = self.fasta.extract(interval)
        for i, pos in enumerate(range(interval.start, interval.end)):
            pos = pos + 1  # 0 to 1 base
            for j in range(1, length + 1):
                if i + j <= len(seq):
                    yield Variant(interval.chrom, pos, seq[i:i + j], '')

    def combination_variants(self, interval, variant_type='snv',
                             in_length=2, del_length=2) -> Iterable[Variant]:
        """Yields the variants of the requested type(s) for one interval.

        interval: interval of variants
        variant_type: 'snv', 'in', 'del', 'indel' (insertions and
          deletions) or 'all'. Any other value yields nothing.
        in_length: insertions up to this length
        del_length: deletions up to this length
        """
        if variant_type in {'snv', 'all'}:
            yield from self.combination_variants_snv(interval)
        if variant_type in {'indel', 'in', 'all'}:
            yield from self.combination_variants_insertion(
                interval, length=in_length)
        if variant_type in {'indel', 'del', 'all'}:
            yield from self.combination_variants_deletion(
                interval, length=del_length)

    def __iter__(self) -> Iterable[Variant]:
        """Yields variants of ``self.variant_type`` for the bed file regions.

        Regions are merged (ignoring strand) and sorted before variants
        are generated.

        Raises ValueError (on first iteration) if no bed file was given.
        """
        if self.bed_file is None:
            raise ValueError('bed_file is required to iterate over variants')

        import pyranges as pr

        gr = pr.read_bed(self.bed_file)
        gr = gr.merge(strand=False).sort()

        for interval in pyranges_to_intervals(gr):
            yield from self.combination_variants(interval, self.variant_type)

    def to_vcf(self, path):
        """Writes all variants produced by iterating this object as vcf.

        path: path of the output vcf file
        """
        from cyvcf2 import Writer

        # Column names are joined with explicit tabs as required by vcf.
        header = '##fileformat=VCFv4.2\n' + '\t'.join([
            '#CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO'
        ]) + '\n'
        writer = Writer.from_string(path, header)

        # NOTE(review): deletions are yielded with an empty alt, which ends
        # up as an empty ALT column here -- confirm that is acceptable.
        try:
            for v in self:
                variant = writer.variant_from_string('\t'.join([
                    v.chrom, str(v.pos), '.', v.ref, v.alt, '.', '.', '.'
                ]))
                writer.write_record(variant)
        finally:
            # Close explicitly so records are flushed even on error.
            writer.close()
38 changes: 38 additions & 0 deletions kipoiseq/extractors/vcf_query.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import csv
import abc
from itertools import islice
from typing import Tuple, Iterable, List
Expand Down Expand Up @@ -241,3 +242,40 @@ def to_vcf(self, path, remove_samples=False, clean_info=False):
variant = writer.variant_from_string('\t'.join(variant))

writer.write_record(variant)

def to_sample_csv(self, path, format_fields=None):
    """
    Extract samples and FORMAT from vcf then save as csv file.

    One row is written per (variant, sample) pair returned by
    `self.vcf.get_samples(variant)`. The header row is always written,
    even if there are no variants.

    Args:
      path: path of the output csv file.
      format_fields: optional iterable of FORMAT keys (e.g. 'GT') to add
        as extra columns after 'variant', 'sample' and 'genotype'.
    """
    format_fields = list(format_fields or [])
    fieldnames = ['variant', 'sample', 'genotype'] + format_fields
    requested = set(format_fields)
    samples = self.vcf.samples

    # newline='' is what the csv module expects of files it writes to.
    with open(path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()

        for variant in self:
            variant_fields = str(variant.source).strip().split('\t')

            # Column 9 (index 8) holds the FORMAT keys and the following
            # columns hold the per-sample values; both can be missing.
            if len(variant_fields) > 8:
                fields = variant_fields[8].split(':')
            else:
                fields = []
            values = dict(zip(
                samples, (col.split(':') for col in variant_fields[9:])))

            for sample, gt in self.vcf.get_samples(variant).items():
                row = {
                    'variant': str(variant),
                    'sample': sample,
                    'genotype': gt,
                }

                for k, v in zip(fields, values.get(sample, ())):
                    if k in requested:
                        row[k] = v

                writer.writerow(row)
2 changes: 1 addition & 1 deletion kipoiseq/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def tokenize(seq, alphabet=DNA, neutral_alphabet=["N"]):
neutral_alphabet = [neutral_alphabet]

nchar = len(alphabet[0])
for l in alphabet + neutral_alphabet:
for l in (*alphabet, *neutral_alphabet):
assert len(l) == nchar
assert len(seq) % nchar == 0 # since we are using striding

Expand Down
13 changes: 6 additions & 7 deletions kipoiseq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
import numpy as np
from six import string_types

# alphabets:
from kipoiseq import Variant

DNA = ["A", "C", "G", "T"]
RNA = ["A", "C", "G", "U"]
AMINO_ACIDS = ["A", "R", "N", "D", "B", "C", "E", "Q", "Z", "G", "H",
"I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]
DNA = ("A", "C", "G", "T")
RNA = ("A", "C", "G", "U")
AMINO_ACIDS = ("A", "R", "N", "D", "B", "C", "E", "Q", "Z", "G", "H",
"I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V")

alphabets = {"DNA": DNA,
"RNA": RNA,
Expand Down Expand Up @@ -38,7 +36,8 @@ def parse_dtype(dtype):
try:
return eval(dtype)
except Exception as e:
raise ValueError("Unable to parse dtype: {}. \nException: {}".format(dtype, e))
raise ValueError(
"Unable to parse dtype: {}. \nException: {}".format(dtype, e))
else:
return dtype

Expand Down
11 changes: 8 additions & 3 deletions tests/dataloaders/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def fasta_file():
def intervals_file():
return "tests/data/sample_intervals.bed"


@pytest.fixture
def intervals_file_strand():
return "tests/data/sample_interval_strand.bed"
Expand All @@ -25,8 +26,9 @@ def intervals_file_strand():

def test_min_props():
# minimal set of properties that need to be specified on the object
min_set_props = ["output_schema", "type", "defined_as", "info", "args", "dependencies", "postprocessing",
"source", "source_dir"]
min_set_props = ["output_schema", "type", "defined_as", "info", "args",
"dependencies", "source", "source_dir"]
# TODO: "postprocessing" is this part of min_set_props?

for Dl in [StringSeqIntervalDl, SeqIntervalDl]:
props = dir(Dl)
Expand Down Expand Up @@ -56,11 +58,14 @@ def test_fasta_based_dataset(intervals_file, fasta_file):
vals = dl.load_all()
assert vals['inputs'][0] == 'GT'


def test_use_strand(intervals_file_strand, fasta_file):
dl = StringSeqIntervalDl(intervals_file_strand, fasta_file, use_strand=True)
dl = StringSeqIntervalDl(intervals_file_strand,
fasta_file, use_strand=True)
vals = dl.load_all()
assert vals['inputs'][0] == 'AC'


def test_seq_dataset(intervals_file, fasta_file):
dl = SeqIntervalDl(intervals_file, fasta_file)
ret_val = dl[0]
Expand Down
74 changes: 74 additions & 0 deletions tests/extractors/test_variant_combinations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest
from conftest import example_intervals_bed, sample_5kb_fasta_file
import pyranges as pr
from kipoiseq import Interval
from kipoiseq.extractors import VariantCombinator, MultiSampleVCF


@pytest.fixture
def variant_combinator():
    """VariantCombinator over the sample fasta and the example bed file."""
    combinator = VariantCombinator(
        sample_5kb_fasta_file, example_intervals_bed)
    return combinator


def test_VariantCombinator_combination_variants(variant_combinator):
    """Checks snv, insertion, deletion and combined variant generation."""
    # snv: 3 alternative bases for each of the 10 positions
    interval = Interval('chr1', 20, 30)
    variants = list(variant_combinator.combination_variants(interval, 'snv'))
    assert len(variants) == 30

    interval = Interval('chr1', 20, 22)
    variants = list(variant_combinator.combination_variants(interval, 'snv'))
    assert variants[0].chrom == 'chr1'
    assert variants[0].ref == 'A'
    assert variants[0].alt == 'C'
    assert variants[1].alt == 'G'
    assert variants[2].alt == 'T'

    assert variants[3].ref == 'C'
    assert variants[3].alt == 'A'
    assert variants[4].alt == 'G'
    assert variants[5].alt == 'T'

    # insertions: 4 ** 2 alternatives for each of the 2 positions
    interval = Interval('chr1', 20, 22)
    variants = list(variant_combinator.combination_variants(interval, 'in'))
    assert len(variants) == 32
    assert variants[0].ref == 'A'
    assert variants[0].alt == 'AA'
    assert variants[15].alt == 'TT'

    assert variants[16].ref == 'C'
    assert variants[16].alt == 'AA'
    assert variants[31].alt == 'TT'

    # deletions: the 2-base deletion at the last position does not fit
    interval = Interval('chr1', 20, 22)
    variants = list(variant_combinator.combination_variants(
        interval, 'del', del_length=2))
    assert len(variants) == 3
    assert variants[0].ref == 'A'
    assert variants[0].alt == ''
    assert variants[1].ref == 'AC'
    assert variants[1].alt == ''
    assert variants[2].ref == 'C'
    assert variants[2].alt == ''

    variants = list(variant_combinator.combination_variants(
        interval, 'all', in_length=2, del_length=2))
    assert len(variants) == 6 + 32 + 3


def test_VariantCombinator_iter(variant_combinator):
    """Iteration yields 3 unique snvs per base of the merged bed regions."""
    produced = list(variant_combinator)

    merged = pr.read_bed(example_intervals_bed).merge(strand=False).df
    widths = merged['End'] - merged['Start']
    expected_count = widths.sum() * 3

    assert len(produced) == expected_count
    assert len(produced) == len(set(produced))


def test_VariantCombinator_to_vcf(tmpdir, variant_combinator):
    """The written vcf holds 3 snvs per base of the merged bed regions."""
    vcf_path = str(tmpdir / 'output.vcf')
    variant_combinator.to_vcf(vcf_path)

    written = MultiSampleVCF(vcf_path)

    merged = pr.read_bed(example_intervals_bed).merge(strand=False).df
    widths = merged['End'] - merged['Start']
    expected_count = widths.sum() * 3

    records = list(written)
    assert len(records) == expected_count
37 changes: 37 additions & 0 deletions tests/extractors/test_vcf_query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from conftest import vcf_file
import pandas as pd
from kipoiseq.dataclasses import Variant, Interval
from kipoiseq.extractors.vcf_seq import MultiSampleVCF
from kipoiseq.extractors.vcf_query import *
Expand Down Expand Up @@ -137,3 +138,39 @@ def test_VariantQueryable_to_vcf(tmp_path):

vcf = MultiSampleVCF(path)
assert len(vcf.samples) == 0


def test_VariantQueryable_to_sample_csv(tmp_path):
    """Without format fields only variant, sample and genotype are saved."""
    csv_path = str(tmp_path / 'sample.csv')

    queryable = MultiSampleVCF(vcf_file).query_all()
    queryable.to_sample_csv(csv_path)

    expected = pd.DataFrame({
        'variant': ['chr1:4:T>C', 'chr1:25:AACG>GA'],
        'sample': ['NA00003', 'NA00002'],
        'genotype': [3, 3]
    })
    observed = pd.read_csv(csv_path)
    pd.testing.assert_frame_equal(observed, expected)


def test_VariantQueryable_to_sample_csv_fields(tmp_path):
    """Requested FORMAT fields are saved as additional csv columns."""
    csv_path = str(tmp_path / 'sample.csv')
    requested_fields = ['GT', 'HQ']

    queryable = MultiSampleVCF(vcf_file).query_all()
    queryable.to_sample_csv(csv_path, requested_fields)

    expected = pd.DataFrame({
        'variant': ['chr1:4:T>C', 'chr1:25:AACG>GA'],
        'sample': ['NA00003', 'NA00002'],
        'genotype': [3, 3],
        'GT': ['1/1', '1/1'],
        'HQ': ['51,51', '10,10']
    })
    observed = pd.read_csv(csv_path)
    pd.testing.assert_frame_equal(observed, expected)

0 comments on commit e234baf

Please sign in to comment.