Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/PeptoneInc/ADOPT
Browse files Browse the repository at this point in the history
  • Loading branch information
CFisicaro committed Nov 16, 2021
2 parents 39f18ef + b72ac3c commit 74c65f0
Show file tree
Hide file tree
Showing 25 changed files with 693 additions and 273 deletions.
1 change: 1 addition & 0 deletions .github/workflows/linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,5 @@ jobs:
IGNORE_GENERATED_FILES: true
VALIDATE_PYTHON_BLACK: false
VALIDATE_PYTHON_ISORT: false
FILTER_REGEX_EXCLUDE: /esm

3 changes: 2 additions & 1 deletion adopt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"@generated"
from . import constants, utils
from .data import CheZod
from .inference import ZScorePred
from .training import DisorderPred
from .viz import get_multi_attention
from .transformer import MultiHead
from .version import version as __version__
14 changes: 14 additions & 0 deletions adopt/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,17 @@
"train_on_cleared_1325_cv_residue_split": "cleared_residue_cv",
"train_on_cleared_1325_cv_sequence_split": "cleared_sequence_cv",
}

# Maps the human-readable structural class of a residue to its short label.
structure_dict = {
    "Fully disordered": "FDIS",
    "Partially disordered": "PDIS",
    "Structured": "STRUCT",
    "Flexible loops": "FLEX",
}

# Hex color ("#RRGGBB") associated with each short structural label; the keys
# are exactly the values of ``structure_dict``.
# NOTE(review): presumably consumed by the visualization code - confirm.
res_colors = {
    "FDIS": "#FF3349",
    "PDIS": "#FFD433",
    "STRUCT": "#33C4FF",
    # Was "##fc9ce7": the doubled "#" is not a valid hex color.
    "FLEX": "#fc9ce7",
}
1 change: 1 addition & 0 deletions adopt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import pandas as pd

from adopt import constants, utils


Expand Down
146 changes: 79 additions & 67 deletions adopt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,80 +6,91 @@
import getopt
import os
import sys

import numpy as np
import pandas as pd
import torch

from adopt import constants, utils


class ZScorePred:
    """Predict Z-scores from ESM representations with a lasso model stored as ONNX.

    The ONNX file is selected from the model type and the training strategy:
    ``../models/lasso_<model_type>_<strategies_dict[strategy]>.onnx``.
    NOTE(review): the path is relative to the current working directory.
    """

    def __init__(self, strategy, model_type):
        """Store the configuration and derive the ONNX model path.

        Parameters
        ----------
        strategy
            Key of ``constants.strategies_dict``; a ``KeyError`` is raised
            for an unknown key.
        model_type
            Name of the representation model; this code special-cases
            ``"esm-msa"`` and ``"combined"``.
        """
        self.strategy = strategy
        self.model_type = model_type
        self.onnx_model = (
            "../models/lasso_"
            + self.model_type
            + "_"
            + constants.strategies_dict[self.strategy]
            + ".onnx"
        )

    def get_z_score(self, representation):
        """Return the predicted Z-scores for one in-memory representation.

        ``representation`` is squeezed and converted to a NumPy array before
        being passed to ``utils.get_onnx_model_preds``; the returned
        predictions are concatenated into a single array.
        """
        predicted_z_scores = utils.get_onnx_model_preds(
            self.onnx_model, representation.squeeze().numpy()
        )
        return np.concatenate(predicted_z_scores)

    @staticmethod
    def _load_representation(repr_dir, index, layer):
        """Load layer ``layer`` of ``<repr_dir>/<index>.pt`` as a detached CPU tensor."""
        saved = torch.load(os.path.join(repr_dir, index + ".pt"))
        return saved["representations"][layer].clone().cpu().detach()

    def get_z_score_from_fasta(
        self,
        inference_fasta_path,
        inference_repr_path,
        predicted_z_scores_path,
    ):
        """Predict Z-scores for every saved representation and write them as JSON.

        Parameters
        ----------
        inference_fasta_path
            FASTA file, read with ``utils.fasta_to_df``.
        inference_repr_path
            Directory holding one sub-directory of ``<id>.pt`` files per
            model type (``esm-1v`` and ``esm-1b`` are both read for the
            ``"combined"`` model type).
        predicted_z_scores_path
            Destination of the JSON output (``orient="records"``).
        """
        df_fasta = utils.fasta_to_df(inference_fasta_path)

        # The "combined" model enumerates the esm-1v directory and pairs each
        # file with the file of the same name in the esm-1b directory.
        is_combined = self.model_type == "combined"
        primary_model = "esm-1v" if is_combined else self.model_type
        repr_path = os.path.join(inference_repr_path, primary_model)
        esm1b_repr_path = os.path.join(inference_repr_path, "esm-1b")

        # Identifier = file name up to the first dot.
        indexes = [file.split(".")[0] for file in os.listdir(repr_path)]

        predicted_z_scores = []
        for ix in indexes:
            if self.model_type == "esm-msa":
                # The MSA representations are read from layer 12.
                repr_esm = self._load_representation(repr_path, ix, 12)
            elif is_combined:
                repr_esm1v = self._load_representation(repr_path, ix, 33)
                repr_esm1b = self._load_representation(esm1b_repr_path, ix, 33)
                # Feature order is esm-1v first, then esm-1b.
                repr_esm = torch.cat([repr_esm1v, repr_esm1b], 1)
            else:
                repr_esm = self._load_representation(repr_path, ix, 33)

            z_scores = utils.get_onnx_model_preds(self.onnx_model, repr_esm.numpy())
            predicted_z_scores.append(np.concatenate(z_scores))

        df_z = pd.DataFrame({"brmid": indexes, "z_scores": predicted_z_scores})
        df_results = df_fasta.join(df_z.set_index("brmid"), on="brmid")
        df_results.to_json(predicted_z_scores_path, orient="records")


def get_z_score(
    strategy,
    model_type,
    inference_fasta_path,
    inference_repr_path,
    predicted_z_scores_path,
):
    """Backward-compatible wrapper around ``ZScorePred.get_z_score_from_fasta``."""
    ZScorePred(strategy, model_type).get_z_score_from_fasta(
        inference_fasta_path, inference_repr_path, predicted_z_scores_path
    )


def main(argv):
Expand Down Expand Up @@ -144,8 +155,9 @@ def main(argv):
elif opt in ("-p", "--pred_z_scores_file"):
pred_z_scores_file = arg

get_z_score(
train_strategy, model_type, infer_fasta_file, infer_repr_dir, pred_z_scores_file
z_score_pred = ZScorePred(train_strategy, model_type)
z_score_pred.get_z_score_from_fasta(
infer_fasta_file, infer_repr_dir, pred_z_scores_file
)


Expand Down
53 changes: 53 additions & 0 deletions adopt/transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2021 Peptone.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys

import torch

from adopt import constants, utils


class MultiHead:
    """Run an ESM model on one sequence and expose its attentions/representations."""

    def __init__(self, model_type, sequence, brmid):
        """Store the inputs and build the single-entry batch ``[(brmid, sequence)]``."""
        self.model_type = model_type
        self.sequence = sequence
        self.brmid = brmid
        self.data = [(self.brmid, self.sequence)]

    @staticmethod
    def _abort_unknown_model_type():
        """Print the supported model types and terminate with exit status 2."""
        print("The model types are:")
        print(*constants.model_types, sep="\n")
        sys.exit(2)

    def get_attention(self):
        """Return ``(attention, tokens)`` for the stored sequence.

        ``tokens`` is the sequence split into single characters. The
        ``"attentions"`` output has its first two axes swapped and the first
        and last token positions removed from its last two axes.
        NOTE(review): presumably (batch, layers, heads, L, L) ->
        (layers, batch, heads, L, L) - confirm against the ESM output.

        Exits the process (status 2) when ``model_type`` is not listed in
        ``constants.model_types``.
        """
        if self.model_type not in constants.model_types:
            self._abort_unknown_model_type()
        results = utils.get_model_and_alphabet(self.model_type, self.data)

        tokens = list(self.sequence)
        attention = results["attentions"].permute(1, 0, 2, 3, 4)
        # remove first and last token (<cls> and <sep>)
        attention = attention[:, :, :, 1:-1, 1:-1]
        return attention, tokens

    def get_representation(self):
        """Return ``(representation, tokens)`` for the stored sequence.

        Layer-33 representations are used. For ``"combined"`` the esm-1v and
        esm-1b representations are concatenated along the last axis.

        Exits the process (status 2) for an unsupported ``model_type``.
        """
        # "combined" must be tested first: if it is also listed in
        # constants.model_types, the generic branch would try to index the
        # two-element result like a dict.
        if self.model_type == "combined":
            results_esm1b, results_esm1v = utils.get_model_and_alphabet(
                self.model_type, self.data
            )
            representation_esm1b = results_esm1b["representations"][33]
            representation_esm1v = results_esm1v["representations"][33]
            # esm-1v first, then esm-1b: the same feature order that
            # get_z_score_from_fasta feeds to the ONNX models.
            representation = torch.cat((representation_esm1v, representation_esm1b), -1)
        elif self.model_type in constants.model_types:
            results = utils.get_model_and_alphabet(self.model_type, self.data)
            representation = results["representations"][33]
        else:
            self._abort_unknown_model_type()

        tokens = list(self.sequence)
        # remove first and last token (<cls> and <sep>)
        representation = representation[:, 1:-1, :]
        return representation, tokens
46 changes: 41 additions & 5 deletions adopt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

import esm
from adopt import constants


"@generated"
# Drop the missing values (entries whose z-score is 999) when the drop_missing flag is True.
def pedestrian_input(indexes, df, path, z_col="z-score", msa=False, drop_missing=True):
zeds = []
Expand Down Expand Up @@ -91,13 +92,48 @@ def get_onnx_model_preds(model_name, input_data):
return pred_onx


def get_esm_output(model, alphabet, data):
    """Run ``model`` on ``data`` and return its raw output.

    Parameters
    ----------
    model
        Callable ESM model; invoked with ``repr_layers=[33]`` and
        ``return_contacts=True``.
    alphabet
        Object providing ``get_batch_converter()``.
    data
        Batch passed to the batch converter, e.g. ``[(brmid, sequence)]``.

    Returns
    -------
    The value returned by ``model``; callers read its ``"representations"``
    and ``"attentions"`` entries.
    """
    batch_converter = alphabet.get_batch_converter()
    # Only the token tensor is needed; labels and raw strings are discarded.
    _, _, batch_tokens = batch_converter(data)

    # Extract per-residue representations without tracking gradients.
    with torch.no_grad():
        return model(batch_tokens, repr_layers=[33], return_contacts=True)


def get_esm_attention(model, alphabet, sequence, brmid):
    """Backward-compatible wrapper: run ``get_esm_output`` on one sequence."""
    return get_esm_output(model, alphabet, [(brmid, sequence)])


def get_model_and_alphabet(model_type, data):
    """Load the pretrained ESM model(s) for ``model_type`` and run them on ``data``.

    Parameters
    ----------
    model_type
        ``"esm-1b"``, ``"esm-1v"`` or ``"combined"``.
    data
        Batch forwarded unchanged to ``get_esm_output``.

    Returns
    -------
    The output of ``get_esm_output`` for a single model, or the list
    ``[results_esm1b, results_esm1v]`` for ``"combined"``.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported values. Previously any
        other value was silently treated as ``"combined"``.
    """
    if model_type == "esm-1b":
        model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
        return get_esm_output(model, alphabet, data)

    if model_type == "esm-1v":
        model, alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1()
        return get_esm_output(model, alphabet, data)

    if model_type == "combined":
        model_esm1b, alphabet_esm1b = esm.pretrained.esm1b_t33_650M_UR50S()
        model_esm1v, alphabet_esm1v = esm.pretrained.esm1v_t33_650M_UR90S_1()
        results_esm1b = get_esm_output(model_esm1b, alphabet_esm1b, data)
        results_esm1v = get_esm_output(model_esm1v, alphabet_esm1v, data)
        return [results_esm1b, results_esm1v]

    # NOTE: "esm-msa" is not supported here yet.
    raise ValueError(
        f"Unsupported model type {model_type!r}; "
        "expected 'esm-1b', 'esm-1v' or 'combined'"
    )


def get_residue_class(predicted_z_scores):
    """Label every predicted Z-score with a structural class.

    Returns one dict per score, in input order, with the keys ``"label"``
    (a value of ``constants.structure_dict``), ``"start"`` (the position of
    the score) and ``"end"`` (``start + 1``).

    Classes: score < 3 is fully disordered, 3 <= score < 8 is partially
    disordered, score >= 11 is structured, and anything else is a flexible
    loop.
    """

    def _label_for(score):
        # Comparisons are evaluated in this exact order; the final return is
        # the fall-through class.
        if score < 3:
            return constants.structure_dict["Fully disordered"]
        if 3 <= score < 8:
            return constants.structure_dict["Partially disordered"]
        if score >= 11:
            return constants.structure_dict["Structured"]
        return constants.structure_dict["Flexible loops"]

    return [
        {"label": _label_for(score), "start": position, "end": position + 1}
        for position, score in enumerate(predicted_z_scores)
    ]
43 changes: 0 additions & 43 deletions adopt/viz.py

This file was deleted.

7 changes: 3 additions & 4 deletions esm/esm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .version import version as __version__ # noqa

from .data import Alphabet, BatchConverter, FastaBatchedDataset # noqa
from .model import ProteinBertModel, MSATransformer # noqa
from . import pretrained # noqa
from .data import Alphabet, BatchConverter, FastaBatchedDataset # noqa
from .model import MSATransformer, ProteinBertModel # noqa
from .version import version as __version__ # noqa
Loading

0 comments on commit 74c65f0

Please sign in to comment.