-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
142 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from bpnet.seqmodel import SeqModel | ||
from keras.models import load_model | ||
import numpy as np | ||
import bpnet | ||
import tensorflow as tf | ||
from bpnet.functions import softmax | ||
import keras.backend as K | ||
import keras.layers as kl | ||
from kipoi.model import BaseModel | ||
|
||
def profile_contrib(p):
    """Reduce a profile pre-activation tensor to a single scalar per example.

    The tensor is wrapped in a Keras ``Lambda`` layer that:

    1. computes a softmax of ``p`` along axis ``-2`` and blocks gradients
       through it (``K.stop_gradient``), so the softmax acts as a constant
       weighting,
    2. multiplies those weights element-wise with ``p`` and sums over axis
       ``-2``,
    3. averages the result over the last axis.

    NOTE(review): presumably ``p`` is laid out as (batch, position, strand),
    making axis ``-2`` the sequence positions - confirm against the model.

    Args:
        p: Keras/TensorFlow tensor with at least two non-batch axes.

    Returns:
        The output tensor of the ``Lambda`` layer applied to ``p``.
    """
    weighted_profile = kl.Lambda(
        lambda logits: K.mean(
            K.sum(
                K.stop_gradient(tf.nn.softmax(logits, dim=-2)) * logits,
                axis=-2,
            ),
            axis=-1,
        )
    )
    return weighted_profile(p)
|
||
|
||
class BPNetOldSeqModel(BaseModel, SeqModel):
    """Kipoi model wrapping a saved Keras BPNet with four tasks.

    The Keras model is loaded from an HDF5 file and queried with the
    sequence plus all-zero ("neutral") bias inputs. The raw outputs are
    mapped onto ``target_names`` by position.
    """

    # Graph tensor names; paired positionally with ``target_names`` in
    # ``get_intp_tensors`` (one entry per target, same order).
    preact_tensor_names = [
        'reshape_2/Reshape:0',
        'dense_1/BiasAdd:0',
        'reshape_4/Reshape:0',
        'dense_3/BiasAdd:0',
        'reshape_6/Reshape:0',
        'dense_5/BiasAdd:0',
        'reshape_8/Reshape:0',
        'dense_7/BiasAdd:0',
    ]

    bottleneck_name = 'add_9/add:0'

    # Order matters: index ``i`` here corresponds to output ``i`` of the
    # Keras model in ``predict_on_batch``.
    target_names = [
        'Oct4/profile',
        'Oct4/counts',
        'Sox2/profile',
        'Sox2/counts',
        'Nanog/profile',
        'Nanog/counts',
        'Klf4/profile',
        'Klf4/counts',
    ]

    seqlen = 1000

    tasks = ['Oct4', 'Sox2', 'Nanog', 'Klf4']

    # One (profile, counts) post-processing pair per task.
    postproc_fns = [softmax, None] * len(tasks)

    def __init__(self, model_file):
        """Load the Keras model stored at ``model_file``.

        Args:
            model_file: path to the saved Keras model (passed to
                ``keras.models.load_model`` with ``compile=False``).
        """
        self.model_file = model_file
        self.contrib_fns = {}
        # Start from a fresh Keras session before loading.
        # NOTE(review): presumably required so that the loaded layers get
        # the tensor names hard-coded above - confirm.
        K.clear_session()
        self.model = load_model(model_file, compile=False)

    def predict_on_batch(self, seq):
        """Predict per-task profiles scaled by the predicted counts.

        Args:
            seq: batch of input sequences; ``len(seq)`` is used as the batch
                size and ``seq.shape[1]`` as the sequence length.

        Returns:
            dict mapping each task name to
            ``softmax(profile) * exp(counts[:, np.newaxis])``.
        """
        batch_size = len(seq)
        model_inputs = {"seq": seq}
        model_inputs.update(
            self.neutral_bias_inputs(batch_size, seqlen=seq.shape[1])
        )
        raw_outputs = self.model.predict_on_batch(model_inputs)

        # Label the positional model outputs with their target names.
        by_target = {}
        for idx, target in enumerate(self.target_names):
            by_target[target] = raw_outputs[idx]

        scaled = {}
        for task in self.tasks:
            profile = softmax(by_target[f'{task}/profile'])
            counts = by_target[f'{task}/counts']
            scaled[task] = profile * np.exp(counts[:, np.newaxis])
        return scaled

    def neutral_bias_inputs(self, length, seqlen):
        """Build all-zero bias inputs for every target.

        Args:
            length: number of examples in the batch.
            seqlen: sequence length used for the profile bias arrays.

        Returns:
            dict keyed by ``'bias/<target>'``; ``*/profile`` targets get a
            zero array of shape ``(length, seqlen, 4)``, every other target
            a zero array of shape ``(length, 2)``.
        """
        bias_inputs = {}
        for target in self.target_names:
            if target.endswith("/profile"):
                shape = (length, seqlen, 4)
            else:
                shape = (length, 2)
            bias_inputs['bias/' + target] = np.zeros(shape)
        return bias_inputs

    def get_intp_tensors(self, preact_only=True, graph=None):
        """Collect the tensors used as interpretation targets.

        Args:
            preact_only: accepted for interface compatibility; not used by
                this implementation.
            graph: TensorFlow graph to look tensors up in. Defaults to the
                current default graph.

        Returns:
            list of ``(target_name, tensor)`` tuples in ``target_names``
            order. Tensors of ``*/profile`` targets are passed through
            ``profile_contrib``; the others are returned as found.
        """
        active_graph = tf.get_default_graph() if graph is None else graph

        def _lookup(head_name, tensor_name):
            found = active_graph.get_tensor_by_name(tensor_name)
            if head_name.endswith("/profile"):
                return profile_contrib(found)
            return found

        return [
            (head_name, _lookup(head_name, tensor_name))
            for head_name, tensor_name in zip(self.target_names,
                                              self.preact_tensor_names)
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Kipoi model description for the BPNetOldSeqModel class defined in model.py.
defined_as: model.BPNetOldSeqModel
args:
  model_file:
    # TODO - put to Zenodo
    url: 'http://mitra.stanford.edu/kundaje/avsec/chipnexus/paper/modisco-comparison/v2-output/nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE/model.calibrated.h5'
    md5: bbe883baef261877bfad07d05feb627d

default_dataloader:
  defined_as: kipoiseq.dataloaders.SeqIntervalDl
  default_args:
    # Must match the sequence length declared in the schema below.
    auto_resize_len: 1000
    ignore_targets: True

info:
  authors:
    - name: Ziga Avsec
      github: avsecz
  doc: BPNet model predicting the ChIP-nexus profiles of Oct4, Sox2, Nanog and Klf4
  cite_as: TODO
  trained_on: ChIP-nexus data in mm10. test chromosomes 1, 8, 9, validation chromosomes 2, 3, 4
  license: MIT

dependencies:
  channels:
    - bioconda
    - pytorch
    - conda-forge
    - defaults
  conda:
    - python=3.6
    - bioconda::pybedtools>=0.7.10
    - bioconda::bedtools>=2.27.1
    - bioconda::pybigwig>=0.3.10
    - bioconda::pysam>=0.14.0
    - bioconda::genomelake>=0.1.4

    - pytorch::pytorch  # optional for data-loading
    - cython
    - h5py>=2.7.0
    - numpy

    - pandas>=0.23.0
    - fastparquet
    - python-snappy

    - nb_conda
  pip:
    # The model code calls TF1-only APIs (tf.get_default_graph and
    # tf.nn.softmax(..., dim=...)), so TensorFlow 2.x must be excluded.
    - tensorflow>=1.0,<2.0
    - git+https://github.com/kundajelab/DeepExplain.git
    - bpnet[extras]
schema:
  inputs:
    shape: (1000, 4)
    doc: "One-hot encoded DNA sequence."
  targets:
    Oct4:
      shape: (1000,2)
      doc: "Strand-specific ChIP-nexus data for Oct4."
    Sox2:
      shape: (1000,2)
      doc: "Strand-specific ChIP-nexus data for Sox2."
    Nanog:
      shape: (1000,2)
      doc: "Strand-specific ChIP-nexus data for Nanog."
    Klf4:
      shape: (1000,2)
      doc: "Strand-specific ChIP-nexus data for Klf4."