Skip to content

Commit

Permalink
Added multi-threading support
Browse files Browse the repository at this point in the history
  • Loading branch information
bjornwallner committed Mar 23, 2024
1 parent d5d007b commit 796b927
Showing 1 changed file with 111 additions and 61 deletions.
172 changes: 111 additions & 61 deletions src/DockQ/DockQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import hashlib
import traceback
import itertools
import math
from functools import lru_cache, wraps
from argparse import ArgumentParser
from tqdm import tqdm
from parallelbar import progress_map

import Bio.PDB
import numpy as np
Expand Down Expand Up @@ -49,6 +51,9 @@ def parse_args():
parser.add_argument(
"--verbose", "-v", default=False, action="store_true", help="talk a lot!"
)
#parser.add_argument(
# "--progress_bar", "-v", default=False, action="store_true", help="talk a lot!"
#)
parser.add_argument(
"--use_CA",
"-ca",
Expand All @@ -62,6 +67,21 @@ def parse_args():
action="store_true",
help="Do not align native and model using sequence alignments, but use the numbering of residues instead",
)
parser.add_argument(
"--n_cpu",
default=32,
type=int,
metavar="n_cpu",
help="Number of cores to use",
)
parser.add_argument(
"--chunk_size",
default=64,
type=int,
metavar="chunk_size",
help="Size of chunks given to the cores",
)

parser.add_argument(
"--optDockQF1",
default=False,
Expand Down Expand Up @@ -570,6 +590,7 @@ def calc_DockQ(
alignments,
use_CA_only=False,
capri_peptide=False,
low_memory=False,
):
atom_for_sup = ("CA", "C", "N", "O", "P") if not use_CA_only else ("CA", "P")
fnat_threshold = 4.0 if capri_peptide else 5.0
Expand Down Expand Up @@ -670,11 +691,13 @@ def calc_DockQ(
) # using the private _rms function which does not superimpose

info = {}

info["DockQ_F1"] = dockq_formula(
f1(nat_correct, nonnat_count, nat_total), irms, Lrms
)
info["DockQ"] = dockq_formula(fnat, irms, Lrms)
if low_memory:
return info
info["irms"] = irms
info["Lrms"] = Lrms
info["fnat"] = fnat
Expand Down Expand Up @@ -924,6 +947,7 @@ def run_on_chains(
no_needle=False,
use_CA_only=False,
capri_peptide=False,
low_memory=False,
):
# realign each model chain against the corresponding native chain
alignments = []
Expand All @@ -942,6 +966,7 @@ def run_on_chains(
alignments=tuple(alignments),
use_CA_only=use_CA_only,
capri_peptide=capri_peptide,
low_memory=False,
)
return info

Expand All @@ -953,6 +978,7 @@ def run_on_all_native_interfaces(
no_needle=False,
use_CA_only=False,
capri_peptide=False,
low_memory=False,
):
"""Given a native-model chain map, finds all non-null native interfaces
and runs DockQ for each native-model pair of interfaces"""
Expand All @@ -976,6 +1002,7 @@ def run_on_all_native_interfaces(
no_needle=no_needle,
use_CA_only=use_CA_only,
capri_peptide=capri_peptide,
low_memory=False,
)
if info:
info["chain1"], info["chain2"] = (
Expand Down Expand Up @@ -1080,7 +1107,10 @@ def format_mapping(mapping_str):
def format_mapping_string(chain_mapping):
chain1 = ""
chain2 = ""
mapping = sorted([(b, a) for a, b in chain_mapping.items()])

#mapping = sorted([(b, a) for a, b in chain_mapping.items()])
#Sorting might change LRMSD since the definition of receptor/ligand for equal length depends on order
mapping = [(b, a) for a, b in chain_mapping.items()]
for (
model_chain,
native_chain,
Expand All @@ -1090,7 +1120,6 @@ def format_mapping_string(chain_mapping):

return f"{chain1}:{chain2}"

import math
def product_without_dupl(*args, repeat=1):
pools = [tuple(pool) for pool in args] * repeat
result = [[]]
Expand All @@ -1099,32 +1128,63 @@ def product_without_dupl(*args, repeat=1):
#result = set(list(map(lambda x: tuple(sorted(x)), result))) # to remove symmetric duplicates
for prod in result:
yield tuple(prod)
def chain_combinations(chain_clusters):
def count_chain_combinations(chain_clusters):
    """Count how many distinct chain-mapping combinations exist.

    Chains whose cluster (tuple of interchangeable chain ids) is identical
    are permutable among themselves, so the total number of mappings is the
    product of factorial(multiplicity) over each distinct cluster.

    Parameters:
        chain_clusters: dict mapping a chain id -> list of chain ids it can
            be matched with (as produced by group_chains).

    Returns:
        (number_of_combinations, counts) where counts maps each distinct
        cluster tuple to its multiplicity.
    """
    counts = {}
    for chains in chain_clusters.values():
        cluster = tuple(chains)
        counts[cluster] = counts.get(cluster, 0) + 1
    # math.prod keeps exact Python-int arithmetic; np.prod would coerce to
    # int64 and silently overflow once any multiplicity exceeds 20 (21! > 2**63).
    number_of_combinations = math.prod(math.factorial(n) for n in counts.values())
    return (number_of_combinations, counts)
#set(chain_clusters.values())

combos=itertools.product(*[itertools.permutations(chains) for chains in set([tuple(ch) for ch in chain_clusters.values()])])


print(len(list(combos)))
sys.exit()
number_of_combinations=np.prod([math.factorial(a) for a in counts.values()])
print(number_of_combinations)

def get_all_mappings(
    model_structure, native_structure, model_chains, native_chains, initial_mapping, allowed_mismatches=0
):
    """Enumerate all candidate native->model chain mappings.

    Chains already fixed by ``initial_mapping`` are kept as-is; the remaining
    chains are grouped into interchangeable clusters via ``group_chains`` and
    every duplicate-free combination from ``product_without_dupl`` is turned
    into a full mapping dict.

    Returns:
        list of dicts, each a complete chain mapping extending initial_mapping.
    """
    # Chains not pinned down by the user-supplied initial mapping.
    unmapped_model = [mc for mc in model_chains if mc not in initial_mapping.values()]
    unmapped_native = [nc for nc in native_chains if nc not in initial_mapping.keys()]

    chain_clusters, reverse_map = group_chains(
        model_structure,
        native_structure,
        unmapped_model,
        unmapped_native,
        allowed_mismatches,
    )

    nonempty_clusters = [cluster for cluster in chain_clusters.values() if cluster]

    chain_maps = []
    for combo in product_without_dupl(*nonempty_clusters):
        chain_map = dict(initial_mapping)
        # reverse_map flips which side of the pairing the combo elements
        # belong to (see group_chains).
        if reverse_map:
            for i, model_chain in enumerate(unmapped_model):
                chain_map[combo[i]] = model_chain
        else:
            for i, native_chain in enumerate(unmapped_native):
                chain_map[native_chain] = combo[i]
        chain_maps.append(chain_map)
    return chain_maps

def run_on_all_native_interfaces_multi(args):
    """Unpack a tuple of positional arguments and forward them to
    run_on_all_native_interfaces.

    Adapter for parallelbar's progress_map, which hands each work item to the
    worker as a single object, while run_on_all_native_interfaces expects
    separate positional arguments.
    """
    return run_on_all_native_interfaces(*args)

def func(args):
    # NOTE(review): leftover debug helper that only echoes its argument;
    # appears unused in this diff — consider removing.
    print(args)
#@profile
def main():
args = parse_args()
initial_mapping, model_chains, native_chains = format_mapping(args.mapping)
model_structure = load_PDB(args.model, chains=model_chains)
native_structure = load_PDB(args.native, chains=native_chains)


#check user-given chains are in the structures
model_chains = [c.id for c in model_structure] if not model_chains else model_chains
native_chains = (
[c.id for c in native_structure] if not native_chains else native_chains
Expand All @@ -1138,68 +1198,58 @@ def main():
best_dockq = -1
best_result = None
best_mapping = None
model_chains_to_combo = [mc for mc in model_chains if mc not in initial_mapping.values()]
native_chains_to_combo = [nc for nc in native_chains if nc not in initial_mapping.keys()]

chain_clusters, reverse_map = group_chains(

# model_chains_to_combo = [mc for mc in model_chains if mc not in initial_mapping.values()]
# native_chains_to_combo = [nc for nc in native_chains if nc not in initial_mapping.keys()]

# chain_clusters, reverse_map = group_chains(
# model_structure,
# native_structure,
# model_chains_to_combo,
# native_chains_to_combo,
# args.allowed_mismatches,
# )


#all_mappings = product_without_dupl(
# *[cluster for cluster in chain_clusters.values() if cluster]
#)
chain_maps=get_all_mappings(
model_structure,
native_structure,
model_chains_to_combo,
native_chains_to_combo,
model_chains,
native_chains,
initial_mapping,
args.allowed_mismatches,
)


all_mappings = product_without_dupl(
*[cluster for cluster in chain_clusters.values() if cluster]
)
# else:
# all_mappings = itertools.product(
# *[itertools.permutations(chains) for chains in set([tuple(ch) for ch in chain_clusters.values()])])

#print(len(list(all_mappings)))

#sys.exit()
#print(len(list(all_mappings)))
#print([cluster for cluster in chain_clusters.values() if cluster])

# remove mappings where the same model chain is present more than once
# only if the mapping is supposed to be 1-1
#if len(model_chains) == len(native_chains):
# all_mappings = [
# element for element in all_mappings if len(set(element)) == len(element)
# ]
def progressbar(l):
return tqdm(l,desc='Chain combinations') if len(l) > 1 else l

for mapping in progressbar(list(all_mappings)):
chain_map = {key:value for key, value in initial_mapping.items()}
if reverse_map:
chain_map.update({
mapping[i]: model_chain for i, model_chain in enumerate(model_chains_to_combo)
})
else:
chain_map.update({
native_chain: mapping[i] for i, native_chain in enumerate(native_chains_to_combo)
})
result_this_mapping = run_on_all_native_interfaces(
model_structure,
native_structure,
chain_map=chain_map,
no_needle=args.no_needle,
use_CA_only=args.use_CA,
capri_peptide=args.capri_peptide,
)

low_memory=len(chain_maps) > 100
chain_map_args=[(model_structure,native_structure,chain_map,args.no_needle,args.use_CA,args.capri_peptide,low_memory) for chain_map in chain_maps]
if len(chain_maps)>1:
chunk_size=len(chain_maps) // args.n_cpu
result_this_mappings=progress_map(run_on_all_native_interfaces_multi,chain_map_args, n_cpu=args.n_cpu, chunk_size=chunk_size)
else: #skip multi-threading for single jobs (skip the bar basically)
result_this_mappings=[run_on_all_native_interfaces(*chain_map_arg) for chain_map_arg in chain_map_args]
for chain_map,result_this_mapping in zip(chain_maps,result_this_mappings):
total_dockq = sum(
[result["DockQ_F1" if args.optDockQF1 else "DockQ"] for result in result_this_mapping.values()]
)

if total_dockq > best_dockq:
best_dockq = total_dockq
best_result = result_this_mapping
best_mapping = chain_map

if low_memory: #retrieve the full output by reruning the best chain mapping
best_result=run_on_all_native_interfaces(
model_structure,
native_structure,
best_mapping,
args.no_needle,
args.use_CA,
args.capri_peptide,
low_memory=False)

info["model"] = args.model
info["native"] = args.native
info["best_dockq"] = best_dockq
Expand All @@ -1216,7 +1266,7 @@ def print_results(info, short=False, verbose=False, capri_peptide=False):
print(
f"Total DockQ over {len(info['best_result'])} native interfaces: {info['GlobalDockQ']:.3f} with {info['best_mapping_str']} model:native mapping"
)
print(info["best_result"])
# print(info["best_result"])
for chains, results in info["best_result"].items():
print(
f"DockQ{capri_peptide_str} {results['DockQ']:.3f} DockQ_F1 {results['DockQ_F1']:.3f} Fnat {results['fnat']:.3f} iRMS {results['irms']:.3f} LRMS {results['Lrms']:.3f} Fnonnat {results['fnonnat']:.3f} clashes {results['clashes']} mapping {results['chain1']}{results['chain2']}:{chains[0]}{chains[1]} {info['model']} {results['chain1']} {results['chain2']} -> {info['native']} {chains[0]} {chains[1]}"
Expand Down

0 comments on commit 796b927

Please sign in to comment.