-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_alphafold.py
79 lines (61 loc) · 2.1 KB
/
test_alphafold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import warnings
warnings.filterwarnings("ignore")
import os
import numpy
import pandas
import prody
import yabul
import proteopt
import proteopt.alphafold
# Path of the "data" directory located next to this test module. The module
# path is passed through realpath first, so symlinks are resolved before the
# directory name is taken.
DATA_DIR = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    "data",
)
from .util import ALPHAFOLD_WEIGHTS_DIR
def test_basic():
    """Run a template-free prediction on a short peptide and sanity-check it.

    The model is built with ``max_length=16``, zero recycles and Amber
    relaxation disabled, then run on the 8-residue sequence "SIINFEKL".
    The test passes when:

    * the C-alpha sequence of the prediction equals the input sequence,
    * the sum of squared coordinates is positive (not all zeros), and
    * the mean of the "af2_ptm" data attached to the prediction is positive.
    """
    peptide = "SIINFEKL"
    settings = dict(
        data_dir=ALPHAFOLD_WEIGHTS_DIR,
        max_length=16,
        num_recycle=0,
        amber_relax=False,
    )
    folder = proteopt.alphafold.AlphaFold(**settings)

    prediction = folder.run(peptide)
    print(prediction)

    predicted_sequence = prediction.ca.getSequence()
    assert predicted_sequence == peptide

    coords = prediction.getCoords()
    assert (coords ** 2).sum() > 0

    ptm_values = prediction.getData("af2_ptm")
    assert ptm_values.mean() > 0
def test_prediction_with_template():
    """Predict 1MBN chain A using the same structure as a template.

    Both the template and the target are read from ``1MBN.pdb`` in
    ``DATA_DIR`` and restricted to chain A, residues 1 to 150; the template
    selection additionally leaves out residue 100. The target's C-alpha
    sequence is the model input. After superposing the prediction onto the
    target, the C-alpha RMSD must be below 2.0.

    The sequence/template alignment and its gap counts are only printed for
    diagnostics; nothing is asserted about them.
    """
    pdb_path = os.path.join(DATA_DIR, "1MBN.pdb")

    # The file is parsed once for the template and once for the target.
    template = prody.parsePDB(pdb_path).select(
        "chain A and resid 1 to 150 and resid != 100")
    target = prody.parsePDB(pdb_path).select(
        "protein chain A and resid 1 to 150")

    sequence = target.ca.getSequence()
    template_sequence = template.ca.getSequence()
    print("Target", len(target.ca), sequence)
    print("Template", len(template.ca), template_sequence)

    alignment = yabul.align_pair(sequence, template_sequence)
    print(alignment)
    print("Sequence gaps", sum(char == '-' for char in alignment.query))
    print("Template gaps", sum(char == '-' for char in alignment.reference))

    model_name = "model_1_ptm"
    num_recycle = 0
    folder = proteopt.alphafold.AlphaFold(
        data_dir=ALPHAFOLD_WEIGHTS_DIR,
        max_length=len(target.ca),
        num_recycle=num_recycle,
        model_name=model_name,
        amber_relax=False)
    prediction = folder.run(
        sequence,
        template=template,
        template_replace_sequence_with_gaps=False,
        template_mask_sidechains=False)

    # Superpose the prediction onto the target (using C-alpha atoms) before
    # measuring the deviation; apply() moves the whole prediction.
    superposition = prody.calcTransformation(prediction.ca, target.ca)
    superposition.apply(prediction)
    rmsd_ca = prody.calcRMSD(prediction.ca, target.ca)

    # Single-element unpacking: raises if "af2_ptm" holds more than one
    # distinct value across the prediction.
    (ptm,) = numpy.unique(prediction.getData("af2_ptm"))

    print(model_name, "recycles=", num_recycle, "rmsd=", rmsd_ca, "ptm=", ptm)
    assert rmsd_ca < 2.0