diff --git a/src/DockQ/DockQ.py b/src/DockQ/DockQ.py index 37bb944..36b2b11 100755 --- a/src/DockQ/DockQ.py +++ b/src/DockQ/DockQ.py @@ -5,9 +5,11 @@ import hashlib import traceback import itertools +import math from functools import lru_cache, wraps from argparse import ArgumentParser from tqdm import tqdm +from parallelbar import progress_map import Bio.PDB import numpy as np @@ -49,6 +51,9 @@ def parse_args(): parser.add_argument( "--verbose", "-v", default=False, action="store_true", help="talk a lot!" ) + #parser.add_argument( + # "--progress_bar", "-v", default=False, action="store_true", help="talk a lot!" + #) parser.add_argument( "--use_CA", "-ca", @@ -62,6 +67,21 @@ def parse_args(): action="store_true", help="Do not align native and model using sequence alignments, but use the numbering of residues instead", ) + parser.add_argument( + "--n_cpu", + default=32, + type=int, + metavar="n_cpu", + help="Number of cores to use", + ) + parser.add_argument( + "--chunk_size", + default=64, + type=int, + metavar="chunk_size", + help="Size of chunks given to the cores", + ) + parser.add_argument( "--optDockQF1", default=False, @@ -570,6 +590,7 @@ def calc_DockQ( alignments, use_CA_only=False, capri_peptide=False, + low_memory=False, ): atom_for_sup = ("CA", "C", "N", "O", "P") if not use_CA_only else ("CA", "P") fnat_threshold = 4.0 if capri_peptide else 5.0 @@ -670,11 +691,13 @@ def calc_DockQ( ) # using the private _rms function which does not superimpose info = {} - + info["DockQ_F1"] = dockq_formula( f1(nat_correct, nonnat_count, nat_total), irms, Lrms ) info["DockQ"] = dockq_formula(fnat, irms, Lrms) + if low_memory: + return info info["irms"] = irms info["Lrms"] = Lrms info["fnat"] = fnat @@ -924,6 +947,7 @@ def run_on_chains( no_needle=False, use_CA_only=False, capri_peptide=False, + low_memory=False, ): # realign each model chain against the corresponding native chain alignments = [] @@ -942,6 +966,7 @@ def run_on_chains( alignments=tuple(alignments), 
def count_chain_combinations(chain_clusters):
    """Count how many keys share each cluster and the resulting number of combinations.

    Every value of ``chain_clusters`` is converted to a tuple; keys whose
    tuples are equal are counted together. The number of combinations is the
    product of ``factorial(group size)`` over all distinct tuples.

    Returns a 2-tuple ``(number_of_combinations, counts)``:
      * ``number_of_combinations`` is an exact Python ``int`` (``1`` when
        ``chain_clusters`` is empty);
      * ``counts`` maps each distinct cluster tuple to the number of keys
        that share it.
    """
    counts = {}
    for cluster in chain_clusters.values():
        cluster_key = tuple(cluster)
        counts[cluster_key] = counts.get(cluster_key, 0) + 1

    # Multiply with arbitrary-precision Python ints. A fixed-width NumPy
    # product wraps around silently once the result exceeds 64 bits, which
    # happens quickly with factorials (e.g. 12! ** 3).
    number_of_combinations = 1
    for group_size in counts.values():
        number_of_combinations *= math.factorial(group_size)

    return number_of_combinations, counts
def get_all_mappings(
    model_structure,
    native_structure,
    model_chains,
    native_chains,
    initial_mapping,
    allowed_mismatches=0,
):
    """Build the list of candidate chain maps to evaluate.

    ``initial_mapping`` (native chain id -> model chain id) is treated as
    fixed: its model chains and native chains are excluded from the search.
    The remaining chains are clustered with ``group_chains`` and every
    combination produced by ``product_without_dupl`` over the non-empty
    clusters is merged into a fresh copy of ``initial_mapping``.

    Returns a list of dicts, one per combination, in the order the
    combinations are generated.
    """
    fixed_model_chains = initial_mapping.values()
    free_model_chains = [
        chain for chain in model_chains if chain not in fixed_model_chains
    ]
    free_native_chains = [
        chain for chain in native_chains if chain not in initial_mapping
    ]

    chain_clusters, reverse_map = group_chains(
        model_structure,
        native_structure,
        free_model_chains,
        free_native_chains,
        allowed_mismatches,
    )
    non_empty_clusters = [
        cluster for cluster in chain_clusters.values() if cluster
    ]

    chain_maps = []
    for combo in product_without_dupl(*non_empty_clusters):
        # Start every candidate from the user-fixed pairs.
        candidate = dict(initial_mapping)
        if reverse_map:
            # Here the combination entries become the keys and the free
            # model chains the values.
            candidate.update(
                (combo[idx], model_chain)
                for idx, model_chain in enumerate(free_model_chains)
            )
        else:
            # Here the free native chains are the keys and the combination
            # entries the values.
            candidate.update(
                (native_chain, combo[idx])
                for idx, native_chain in enumerate(free_native_chains)
            )
        chain_maps.append(candidate)
    return chain_maps


def run_on_all_native_interfaces_multi(args):
    """Call ``run_on_all_native_interfaces`` with ``args`` unpacked as positionals.

    Gives a single-argument entry point for map-style callers that hand over
    one packed argument tuple per job.
    """
    result = run_on_all_native_interfaces(*args)
    return result


def func(args):
    """Print ``args`` to standard output; returns ``None``."""
    print(args)
initial_mapping.keys()] - chain_clusters, reverse_map = group_chains( + + # model_chains_to_combo = [mc for mc in model_chains if mc not in initial_mapping.values()] + # native_chains_to_combo = [nc for nc in native_chains if nc not in initial_mapping.keys()] + + # chain_clusters, reverse_map = group_chains( + # model_structure, + # native_structure, + # model_chains_to_combo, + # native_chains_to_combo, + # args.allowed_mismatches, + # ) + + + #all_mappings = product_without_dupl( + # *[cluster for cluster in chain_clusters.values() if cluster] + #) + chain_maps=get_all_mappings( model_structure, native_structure, - model_chains_to_combo, - native_chains_to_combo, + model_chains, + native_chains, + initial_mapping, args.allowed_mismatches, ) - - all_mappings = product_without_dupl( - *[cluster for cluster in chain_clusters.values() if cluster] - ) -# else: -# all_mappings = itertools.product( -# *[itertools.permutations(chains) for chains in set([tuple(ch) for ch in chain_clusters.values()])]) - - #print(len(list(all_mappings))) - - #sys.exit() - #print(len(list(all_mappings))) - #print([cluster for cluster in chain_clusters.values() if cluster]) - - # remove mappings where the same model chain is present more than once - # only if the mapping is supposed to be 1-1 - #if len(model_chains) == len(native_chains): - # all_mappings = [ - # element for element in all_mappings if len(set(element)) == len(element) - # ] - def progressbar(l): - return tqdm(l,desc='Chain combinations') if len(l) > 1 else l - - for mapping in progressbar(list(all_mappings)): - chain_map = {key:value for key, value in initial_mapping.items()} - if reverse_map: - chain_map.update({ - mapping[i]: model_chain for i, model_chain in enumerate(model_chains_to_combo) - }) - else: - chain_map.update({ - native_chain: mapping[i] for i, native_chain in enumerate(native_chains_to_combo) - }) - result_this_mapping = run_on_all_native_interfaces( - model_structure, - native_structure, - 
chain_map=chain_map, - no_needle=args.no_needle, - use_CA_only=args.use_CA, - capri_peptide=args.capri_peptide, - ) + low_memory=len(chain_maps) > 100 + chain_map_args=[(model_structure,native_structure,chain_map,args.no_needle,args.use_CA,args.capri_peptide,low_memory) for chain_map in chain_maps] + if len(chain_maps)>1: + chunk_size=len(chain_maps) // args.n_cpu + result_this_mappings=progress_map(run_on_all_native_interfaces_multi,chain_map_args, n_cpu=args.n_cpu, chunk_size=chunk_size) + else: #skip multi-threading for single jobs (skip the bar basically) + result_this_mappings=[run_on_all_native_interfaces(*chain_map_arg) for chain_map_arg in chain_map_args] + for chain_map,result_this_mapping in zip(chain_maps,result_this_mappings): total_dockq = sum( [result["DockQ_F1" if args.optDockQF1 else "DockQ"] for result in result_this_mapping.values()] ) - if total_dockq > best_dockq: best_dockq = total_dockq best_result = result_this_mapping best_mapping = chain_map - + if low_memory: #retrieve the full output by reruning the best chain mapping + best_result=run_on_all_native_interfaces( + model_structure, + native_structure, + best_mapping, + args.no_needle, + args.use_CA, + args.capri_peptide, + low_memory=False) + info["model"] = args.model info["native"] = args.native info["best_dockq"] = best_dockq @@ -1216,7 +1266,7 @@ def print_results(info, short=False, verbose=False, capri_peptide=False): print( f"Total DockQ over {len(info['best_result'])} native interfaces: {info['GlobalDockQ']:.3f} with {info['best_mapping_str']} model:native mapping" ) - print(info["best_result"]) + # print(info["best_result"]) for chains, results in info["best_result"].items(): print( f"DockQ{capri_peptide_str} {results['DockQ']:.3f} DockQ_F1 {results['DockQ_F1']:.3f} Fnat {results['fnat']:.3f} iRMS {results['irms']:.3f} LRMS {results['Lrms']:.3f} Fnonnat {results['fnonnat']:.3f} clashes {results['clashes']} mapping {results['chain1']}{results['chain2']}:{chains[0]}{chains[1]} 
{info['model']} {results['chain1']} {results['chain2']} -> {info['native']} {chains[0]} {chains[1]}"