Skip to content

Commit

Permalink
bug fix #139 and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
MuhammedHasan committed Jan 28, 2019
1 parent 0b45364 commit ddb6455
Show file tree
Hide file tree
Showing 22 changed files with 272 additions and 859 deletions.
161 changes: 8 additions & 153 deletions MMSplice/deltaLogitPSI/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
import numpy as np
from kipoi.data import SampleIterator

import pickle
from pyfaidx import Fasta
from cyvcf2 import VCF
from concise.preprocessing import encodeDNA
import warnings

from mmsplice import MMSplice
from mmsplice.vcf_dataloader import GenerateExonIntervalTree, VariantInterval, get_var_side
from mmsplice.vcf_dataloader import SplicingVCFDataloader as BaseSplicingVCFDataloader

model = MMSplice(
exon_cut_l=0,
Expand All @@ -20,148 +11,12 @@
donor_exon_len=5,
donor_intron_len=13)

class SplicingVCFDataloader(SampleIterator):
"""
Load genome annotation (gtf) file along with a vcf file, return wt sequence and mut sequence.
Args:
gtf: gtf file or pickled gtf IntervalTree.
fasta_file: file path; Genome sequence
vcf_file: file path; vcf file with variants to score
"""

def __init__(self,
             gtf_file,
             fasta_file,
             vcf_file,
             split_seq=False,
             encode=True,
             exon_cut_l=0,
             exon_cut_r=0,
             acceptor_intron_cut=6,
             donor_intron_cut=6,
             acceptor_intron_len=50,
             acceptor_exon_len=3,
             donor_exon_len=5,
             donor_intron_len=13,
             **kwargs
             ):
    """Set up the exon lookup, the genome handle and the variant generator.

    Args:
        gtf_file: path to a GTF annotation, or to a pickle of a previously
            built exon interval tree.
        fasta_file: path to the genome FASTA. A non-string value is stored
            unchanged (assumed to be an already opened genome object --
            TODO confirm with callers).
        vcf_file: path to the VCF file with the variants to score.
        split_seq: if true, ``__next__`` passes each sequence through
            ``split``.
        encode: if true, ``split`` returns ``encodeDNA`` output instead of
            plain strings.
        exon_cut_l, exon_cut_r, acceptor_intron_cut, donor_intron_cut,
        acceptor_intron_len, acceptor_exon_len, donor_exon_len,
        donor_intron_len: window sizes read by ``split``.
        **kwargs: forwarded to ``GenerateExonIntervalTree``.
    """
    # Try the pickled interval tree first. Unpickling something that is not
    # a pickle can raise many different exception types, so any ordinary
    # failure means "parse gtf_file as a GTF annotation instead".
    # NOTE(security): pickle.load can execute arbitrary code; only pass
    # files from a trusted source.
    try:
        with open(gtf_file, 'rb') as f:
            self.exons = pickle.load(f)
    except Exception:
        self.exons = GenerateExonIntervalTree(gtf_file, **kwargs)

    import six
    if isinstance(fasta_file, six.string_types):
        fasta = Fasta(fasta_file, as_raw=False)
    else:
        # Avoid an unbound local when the caller hands in a non-path value.
        fasta = fasta_file
    self.fasta = fasta

    # Generator function: nothing is read from the VCF until iteration.
    self.ssGenerator = self.spliceSiteGenerator(vcf_file, self.exons)

    self.encode = encode
    self.split_seq = split_seq
    self.exon_cut_l = exon_cut_l
    self.exon_cut_r = exon_cut_r
    self.acceptor_intron_cut = acceptor_intron_cut
    self.donor_intron_cut = donor_intron_cut
    self.acceptor_intron_len = acceptor_intron_len
    self.acceptor_exon_len = acceptor_exon_len
    self.donor_exon_len = donor_exon_len
    self.donor_intron_len = donor_intron_len

@staticmethod
def spliceSiteGenerator(vcf_file, exonTree, variant_filter=True):
    """Yield ``(exon, variant)`` pairs for every variant overlapping an exon.

    Args:
        vcf_file: path handed to ``VCF`` to read the variants.
        exonTree: exon lookup; must provide
            ``intersect(interval, ignore_strand=True)`` whose results expose
            an ``interval`` attribute.
        variant_filter: if true, records with a truthy ``FILTER`` value are
            skipped.

    Yields:
        Tuples ``(match, variant)`` where ``match`` is the overlapping exon
        interval and ``variant`` is the result of
        ``iv.to_Variant(match.strand, side)``.
    """
    variants = VCF(vcf_file)
    for var in variants:
        if variant_filter and var.FILTER:
            # `continue` actually skips the record; a bare `next` expression
            # is a no-op and would let filtered variants through.
            continue
        iv = VariantInterval.from_Variant(var)

        matches = map(lambda x: x.interval,
                      exonTree.intersect(iv, ignore_strand=True))

        for match in matches:
            side = get_var_side((
                var.POS,
                var.REF,
                var.ALT,
                match.Exon_Start,
                match.Exon_End,
                match.strand
            ))
            # Use a separate name so the VCF record in `var` stays intact
            # for any further exon matches of the same variant.
            variant = iv.to_Variant(match.strand, side)
            yield match, variant

def __iter__(self):
    """Return the loader itself as the iterator object."""
    return self

class SplicingVCFDataloader(BaseSplicingVCFDataloader):
def __next__(self):
    """Score the next (exon, variant) pair from ``self.ssGenerator``.

    Returns:
        dict with two keys: ``inputs`` is the ``.values`` of
        ``model.predict(mutated) - model.predict(reference)``, and
        ``metadata`` holds ``ranges``, ``variant``, ``ExonInterval`` and
        ``annotation`` for the exon and variant.
    """
    ss, var = next(self.ssGenerator)

    ref_seq = ss.get_seq(self.fasta).upper()
    alt_seq = ss.get_mut_seq(self.fasta, var).upper()
    if self.split_seq:
        ref_seq = self.split(ref_seq, ss.overhang)
        alt_seq = self.split(alt_seq, ss.overhang)

    ref_inputs = {
        'seq': ref_seq,
        'intronl_len': ss.overhang[0],
        'intronr_len': ss.overhang[1],
    }
    alt_inputs = {
        'seq': alt_seq,
        'intronl_len': ss.overhang[0],
        'intronr_len': ss.overhang[1],
    }

    # Mutated prediction is evaluated first, then the reference one.
    delta = model.predict(alt_inputs) - model.predict(ref_inputs)

    return {
        'inputs': delta.values,
        'metadata': {
            'ranges': ss.grange,
            'variant': var.to_dict,
            # stored in dict form so that np collate will work
            'ExonInterval': ss.to_dict,
            'annotation': str(ss),
        },
    }

def batch_predict_iter(self, **kwargs):
    """Iterate over the batches produced by ``self.batch_iter``.

    Each batch is passed through unchanged (the whole batch, not only its
    ``"inputs"`` entry).

    Args:
        **kwargs: Arguments passed to self.batch_iter(**kwargs)
    """
    batches = self.batch_iter(**kwargs)
    return (batch for batch in batches)

def split(self, x, overhang):
    """Cut a sequence into the five regions around an exon.

    Args:
        x: sequence string laid out as left flank + exon + right flank.
        overhang: pair ``(intronl_len, intronr_len)`` with the lengths of
            the left and right flanks contained in ``x``.

    Returns:
        dict with the keys ``acceptor_intron``, ``acceptor``, ``exon``,
        ``donor`` and ``donor_intron``. Values are ``encodeDNA`` output when
        ``self.encode`` is true, otherwise plain strings.

    Warns:
        UserWarning: when the two bases after the exon are not "GT", or the
            two bases before the exon are not "AG".
    """
    intronl_len, intronr_len = overhang

    # Pad with N when the left flank is not longer than the acceptor window.
    lackl = self.acceptor_intron_len - intronl_len
    if lackl >= 0:
        x = "N" * (lackl + 1) + x
        intronl_len += lackl + 1

    # Same on the right. The extra base keeps
    # -intronr_len + donor_intron_len negative; an end index of 0 would
    # produce an empty donor slice.
    lackr = self.donor_intron_len - intronr_len
    if lackr >= 0:
        x = x + "N" * (lackr + 1)
        intronr_len += lackr + 1

    acceptor_intron = x[:intronl_len - self.acceptor_intron_cut]
    acceptor = x[(intronl_len - self.acceptor_intron_len):
                 (intronl_len + self.acceptor_exon_len)]
    exon = x[(intronl_len + self.exon_cut_l):
             (-intronr_len - self.exon_cut_r)]
    donor = x[(-intronr_len - self.donor_exon_len):
              (-intronr_len + self.donor_intron_len)]
    donor_intron = x[-intronr_len + self.donor_intron_cut:]

    if donor[self.donor_exon_len:self.donor_exon_len + 2] != "GT":
        warnings.warn("None GT donor", UserWarning)
    if acceptor[self.acceptor_intron_len - 2:self.acceptor_intron_len] != "AG":
        # This check is on the acceptor site, so the message names it.
        warnings.warn("None AG acceptor", UserWarning)

    pieces = {
        "acceptor_intron": acceptor_intron,
        "acceptor": acceptor,
        "exon": exon,
        "donor": donor,
        "donor_intron": donor_intron
    }
    if self.encode:
        return {name: encodeDNA([piece]) for name, piece in pieces.items()}
    return pieces
super_out = super().__next__()
return {
'inputs': (model.predict(super_out['inputs_mut']) -
model.predict(super_out['inputs'])).values,
'metadata': super_out['metadata']
}
9 changes: 6 additions & 3 deletions MMSplice/deltaLogitPSI/dataloader.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
type: SampleIterator
defined_as: dataloader.SplicingVCFDataloader
args:
gtf_file:
gtf:
doc: path to the GTF file required by the models (Ensembl)
example:
url: https://sandbox.zenodo.org/record/248604/files/test.gtf?download=1
Expand All @@ -19,6 +19,9 @@ args:
split_seq:
doc: Whether to split the sequence in the dataloader
optional: True
variant_filter:
doc: If set to True (default), variants whose `FILTER` field is anything other than `PASS` will be filtered out.
optional: True
encode:
doc: If the sequence is split, whether to one-hot encode it
optional: True
Expand Down Expand Up @@ -73,7 +76,7 @@ dependencies:
- python=3.5
pip:
- kipoi
- mmsplice
- mmsplice>=0.2.7
output_schema:
inputs:
shape: (5, )
Expand Down Expand Up @@ -157,4 +160,4 @@ output_schema:
doc: genomic end position of the retrieved sequence
annotation:
type: str
doc: retrieved sequence name
doc: retrieved sequence name
12 changes: 3 additions & 9 deletions MMSplice/deltaLogitPSI/model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
from kipoi.model import BaseModel
import numpy as np
from mmsplice import LINEAR_MODEL
from mmsplice.utils.postproc import transform

# Model to predict delta logit PSI

class MMSpliceModel(BaseModel):
    """Model to predict delta logit PSI."""

    def __init__(self):
        # Keep the linear model on the instance under the `model` attribute.
        self.model = LINEAR_MODEL

    def predict_on_batch(self, inputs):
        """Return the linear model's prediction for one batch.

        Args:
            inputs: batch passed to ``transform(inputs, False)`` before
                prediction.

        Returns:
            Whatever ``self.model.predict`` returns for the transformed
            batch.
        """
        # Single exit point; a second return after this one would be dead
        # code.
        return self.model.predict(transform(inputs, False))
1 change: 1 addition & 0 deletions MMSplice/deltaLogitPSI/model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies:
- numpy
pip:
- scikit-learn
- mmsplice>=0.2.7
schema:
inputs:
shape: (5, )
Expand Down
Loading

0 comments on commit ddb6455

Please sign in to comment.