Skip to content

Commit

Permalink
add set_workdir to pymol plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
JinyuanSun committed Jun 11, 2023
1 parent c97ecfc commit 71076bc
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 56 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ cloudmol/__pycache__

*dist*

*build*
*build*

test.ipynb
4 changes: 2 additions & 2 deletions cloudmol/cloudmol.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class PymolFold():
def __init__(self, base_url: str = "http://region-8.seetacloud.com:42711/", abs_path: str = "PymolFold_workdir", verbose: bool = True):
self.BASE_URL = base_url
self.ABS_PATH = os.path.join(os.path.expanduser("~"), abs_path)
print(f"Results will be saved to {self.ABS_PATH}")
print(f"Results will be saved to {self.ABS_PATH} by default")
if not os.path.exists(self.ABS_PATH):
os.makedirs(self.ABS_PATH)
self.verbose = verbose
Expand All @@ -36,7 +36,7 @@ def set_base_url(self, url):

def set_path(self, path):
self.ABS_PATH = path

print(f"Results will be saved to {self.ABS_PATH}")

def query_pymolfold(self, sequence: str, num_recycle: int = 3, name: str = None):
num_recycle = int(num_recycle)
Expand Down
44 changes: 44 additions & 0 deletions cloudmol/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

def plot_ca_plddt(pdb_file, size=(5,3), dpi=120):
plddts = []
with open(pdb_file, "r") as f:
lines = f.readlines()
for line in lines:
if " CA " in line:
plddt = float(line[60:66])
plddts.append(plddt)
if max(plddts) <= 1.0:
y = np.array([plddt * 100 for plddt in plddts])
print("Guessing the scale is [0,1], we scale it to [0, 100]")
else:
y = np.array(plddts)
x = np.arange(len(y)) + 1

# Create color array based on conditions
colors = np.where(y > 90, 'blue',
np.where((y > 70) & (y <= 90), 'lightblue',
np.where((y > 50) & (y <= 70), 'yellow', 'orange')))

plt.figure(figsize=size, dpi=dpi)

# Create scatter plot with colored markers
plt.plot(x, y, color='black')
plt.scatter(x, y, color=colors, zorder=10, edgecolors='black')

plt.ylim(0, 100) # Make sure y axis is in range 0-100
plt.xlabel('Residue')
plt.ylabel('pLDDT')
plt.title('Predicted LDDT per residue')

# Create legend
legend_elements = [mpatches.Patch(color='blue', label='Very high'),
mpatches.Patch(color='lightblue', label='Confident'),
mpatches.Patch(color='yellow', label='Low'),
mpatches.Patch(color='orange', label='Very low')]
plt.legend(handles=legend_elements, title='Confidence', loc='upper left', bbox_to_anchor=(1, 1))

plt.tight_layout() # Make sure nothing gets cropped off
plt.show()
128 changes: 81 additions & 47 deletions cloudmol_demo.ipynb

Large diffs are not rendered by default.

18 changes: 12 additions & 6 deletions pf_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@
import json

BASE_URL = "http://region-8.seetacloud.com:42711/"
ESMFOLD_API = "https://api.esmatlas.com/foldSequence/v1/pdb/"
ABS_PATH = os.path.abspath("./")

def set_workdir(path):
global ABS_PATH
ABS_PATH = path
if ABS_PATH[0] == "~":
ABS_PATH = os.path.join(os.path.expanduser("~"), ABS_PATH[2:])
print(f"Results will be saved to {ABS_PATH}")

def set_base_url(url):
global BASE_URL
BASE_URL = url
Expand Down Expand Up @@ -47,7 +55,7 @@ def cal_plddt(pdb_string: str):
return sum(plddts) / len(plddts)


def query_pymolfold(sequence: str, num_recycle: int = 3, name: str = None):
def query_pymolfold(sequence: str, name: str = None, num_recycle: int = 3):
num_recycle = int(num_recycle)
data = {
'sequence': sequence,
Expand Down Expand Up @@ -93,9 +101,7 @@ def query_esmfold(sequence: str, name: str = None):
"Content-Type": "application/x-www-form-urlencoded",
}

response = requests.post(
"https://api.esmatlas.com/foldSequence/v1/pdb/", headers=headers, data=sequence
)
response = requests.post(ESMFOLD_API, headers=headers, data=sequence)
if not name:
name = sequence[:3] + sequence[-3:]
pdb_filename = os.path.join(ABS_PATH, name) + ".pdb"
Expand Down Expand Up @@ -142,7 +148,6 @@ def query_mpnn(path_to_pdb: str, fix_pos=None, chain=None, rm_aa=None, inverse=F

response = requests.post(
f"{BASE_URL}mpnn/", headers=headers, files=files, params=params)
# print(response.content.decode("utf-8"))
res = response.content.decode("utf-8")

d = json.loads(res)
Expand Down Expand Up @@ -214,7 +219,7 @@ def query_dms(path_to_pdb: str):
ofile.write('mutation,002,010,020,030,ensemble\n')
for name, s1, s2, s3, s4, s5 in zip(d['mutation'], d['002'], d['010'], d['020'], d['030'], d['ensemble']):
ofile.write(f'{name},{s1},{s2},{s3},{s4},{s5}\n')
p = os.path.join(os.getcwd(), 'dms_results.csv')
p = os.path.join(ABS_PATH, 'dms_results.csv')
print(f"Results save to '{p}'")


Expand Down Expand Up @@ -317,3 +322,4 @@ def dms(selection, name='./target_bb.pdb'):
cmd.extend("singlemut", singlemut)
cmd.extend("dms", dms)
cmd.extend("ls_fix", ls_fix)
cmd.extend("set_workdir", set_workdir)

0 comments on commit 71076bc

Please sign in to comment.