Skip to content

Commit

Permalink
added bpnet model (#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
Avsecz authored Aug 21, 2019
1 parent a49b9c1 commit fe3c463
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 0 deletions.
75 changes: 75 additions & 0 deletions BPNet-OSKN/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from bpnet.seqmodel import SeqModel
from keras.models import load_model
import numpy as np
import bpnet
import tensorflow as tf
from bpnet.functions import softmax
import keras.backend as K
import keras.layers as kl
from kipoi.model import BaseModel

def profile_contrib(p):
return kl.Lambda(lambda p:
K.mean(K.sum(K.stop_gradient(tf.nn.softmax(p, dim=-2)) * p, axis=-2), axis=-1)
)(p)


class BPNetOldSeqModel(BaseModel, SeqModel):

preact_tensor_names = ['reshape_2/Reshape:0',
'dense_1/BiasAdd:0',
'reshape_4/Reshape:0',
'dense_3/BiasAdd:0',
'reshape_6/Reshape:0',
'dense_5/BiasAdd:0',
'reshape_8/Reshape:0',
'dense_7/BiasAdd:0'
]

bottleneck_name = 'add_9/add:0'

target_names = ['Oct4/profile',
'Oct4/counts',
'Sox2/profile',
'Sox2/counts',
'Nanog/profile',
'Nanog/counts',
'Klf4/profile',
'Klf4/counts']

seqlen = 1000

tasks = ['Oct4', 'Sox2', 'Nanog', 'Klf4']

postproc_fns = [softmax, None] * 4

def __init__(self, model_file):
self.model_file = model_file
K.clear_session() # restart session
self.model = load_model(model_file, compile=False)
self.contrib_fns = {}

def predict_on_batch(self, seq):
preds = self.model.predict_on_batch({"seq": seq, **self.neutral_bias_inputs(len(seq), seqlen=seq.shape[1])})
pred_dict = {target: preds[i] for i, target in enumerate(self.target_names)}
return {task: softmax(pred_dict[f'{task}/profile']) * np.exp(pred_dict[f'{task}/counts'][:, np.newaxis])
for task in self.tasks}

def neutral_bias_inputs(self, length, seqlen):
"""Compile a set of neutral bias inputs
"""
return dict([('bias/' + target, np.zeros((length, seqlen, 4))
if target.endswith("/profile")
else np.zeros((length, 2)))
for target in self.target_names])

def get_intp_tensors(self, preact_only=True, graph=None):
if graph is None:
graph = tf.get_default_graph()
intp_targets = []
for head_name, tensor_name in zip(self.target_names, self.preact_tensor_names):
tensor = graph.get_tensor_by_name(tensor_name)
if head_name.endswith("/profile"):
tensor = profile_contrib(tensor)
intp_targets.append((head_name, tensor))
return intp_targets
67 changes: 67 additions & 0 deletions BPNet-OSKN/model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
defined_as: model.BPNetOldSeqModel
args:
model_file:
# TODO - put to Zenodo
url: 'http://mitra.stanford.edu/kundaje/avsec/chipnexus/paper/modisco-comparison/v2-output/nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE/model.calibrated.h5'
md5: bbe883baef261877bfad07d05feb627d

default_dataloader:
defined_as: kipoiseq.dataloaders.SeqIntervalDl
default_args:
auto_resize_len: 1000
ignore_targets: True

info:
authors:
- name: Ziga Avsec
github: avsecz
doc: BPNet model predicting the ChIP-nexus profiles of Oct4, Sox2, Nanog and Klf4
cite_as: TODO
trained_on: ChIP-nexus data in mm10. test chromosomes 1, 8, 9, validation chromosomes 2, 3, 4
license: MIT

dependencies:
channels:
- bioconda
- pytorch
- conda-forge
- defaults
conda:
- python=3.6
- bioconda::pybedtools>=0.7.10
- bioconda::bedtools>=2.27.1
- bioconda::pybigwig>=0.3.10
- bioconda::pysam>=0.14.0
- bioconda::genomelake>=0.1.4

- pytorch::pytorch # optional for data-loading
- cython
- h5py>=2.7.0
- numpy

- pandas>=0.23.0
- fastparquet
- python-snappy

- nb_conda
pip:
- tensorflow>=1.0
- git+https://github.com/kundajelab/DeepExplain.git
- bpnet[extras]
schema:
inputs:
shape: (1000, 4)
doc: "One-hot encoded DNA sequence."
targets:
Oct4:
shape: (1000,2)
doc: "Strand-specific ChIP-nexus data for Oct4."
Sox2:
shape: (1000,2)
doc: "Strand-specific ChIP-nexus data for Sox2."
Nanog:
shape: (1000,2)
doc: "Strand-specific ChIP-nexus data for Nanog."
Klf4:
shape: (1000,2)
doc: "Strand-specific ChIP-nexus data for Klf4."

0 comments on commit fe3c463

Please sign in to comment.