From c81209d392d051272b031d6a356c3eae5b7c797d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Wallner?= Date: Mon, 25 Mar 2024 12:49:04 +0100 Subject: [PATCH] changed argument lists to argument iterators --- src/DockQ/DockQ.py | 101 +++++++++++++++++++++++++-------------------- 1 file changed, 57 insertions(+), 44 deletions(-) diff --git a/src/DockQ/DockQ.py b/src/DockQ/DockQ.py index 9278f78..e1f9c7b 100755 --- a/src/DockQ/DockQ.py +++ b/src/DockQ/DockQ.py @@ -6,7 +6,7 @@ import traceback import itertools import math -from functools import lru_cache, wraps +from functools import lru_cache, wraps, partial from argparse import ArgumentParser from tqdm import tqdm from parallelbar import progress_map @@ -761,7 +761,7 @@ def run_on_all_native_interfaces( capri_peptide=capri_peptide, low_memory=False, ) - if info: + if info and not low_memory: info["chain1"], info["chain2"] = ( chain_map[chain_pair[0]], chain_map[chain_pair[1]], @@ -885,6 +885,7 @@ def product_without_dupl(*args, repeat=1): #result = set(list(map(lambda x: tuple(sorted(x)), result))) # to remove symmetric duplicates for prod in result: yield tuple(prod) + def count_chain_combinations(chain_clusters): counts={} for chain in chain_clusters: @@ -893,29 +894,18 @@ def count_chain_combinations(chain_clusters): counts[chains]=0 counts[chains]+=1 number_of_combinations=np.prod([math.factorial(a) for a in counts.values()]) + return number_of_combinations #combos=itertools.product(*[itertools.permutations(chains) for chains in set([tuple(ch) for ch in chain_clusters.values()])]) - return(number_of_combinations,counts) + + #return(number_of_combinations,counts) #set(chain_clusters.values()) -def get_all_mappings( - model_structure, native_structure, model_chains, native_chains,initial_mapping,allowed_mismatches=0 -): - model_chains_to_combo = [mc for mc in model_chains if mc not in initial_mapping.values()] - native_chains_to_combo = [nc for nc in native_chains if nc not in initial_mapping.keys()] - - chain_clusters, reverse_map = group_chains( - model_structure, - native_structure, - model_chains_to_combo, - native_chains_to_combo, - allowed_mismatches, - ) +def get_all_chain_maps(chain_clusters,initial_mapping,reverse_map,model_chains_to_combo,native_chains_to_combo): all_mappings = product_without_dupl( *[cluster for cluster in chain_clusters.values() if cluster] ) - chain_maps=[] for mapping in all_mappings: chain_map = {key:value for key, value in initial_mapping.items()} if reverse_map: @@ -926,11 +916,7 @@ def get_all_mappings( chain_map.update({ native_chain: mapping[i] for i, native_chain in enumerate(native_chains_to_combo) }) - chain_maps.append(chain_map) - return chain_maps - -def run_on_all_native_interfaces_multi(args): - return run_on_all_native_interfaces(*args) + yield(chain_map) #@profile @@ -955,35 +941,62 @@ def main(): best_result = None best_mapping = None + model_chains_to_combo = [mc for mc in model_chains if mc not in initial_mapping.values()] + native_chains_to_combo = [nc for nc in native_chains if nc not in initial_mapping.keys()] - - - - chain_maps=get_all_mappings( + chain_clusters, reverse_map = group_chains( model_structure, native_structure, - model_chains, - native_chains, - initial_mapping, + model_chains_to_combo, + native_chains_to_combo, args.allowed_mismatches, ) + num_chain_combinations=count_chain_combinations(chain_clusters) + chain_maps=get_all_chain_maps(chain_clusters,initial_mapping,reverse_map,model_chains_to_combo,native_chains_to_combo) + + low_memory=num_chain_combinations>100 + run_chain_map=partial(run_on_all_native_interfaces, + model_structure, + native_structure, + no_align=args.no_align, + use_CA_only=args.use_CA, + capri_peptide=args.capri_peptide, + low_memory=low_memory) ##args: chain_map + + if num_chain_combinations>1: + #chunk_size=max(1,num_chain_combinations // args.n_cpu) + #I suspect large chunk_size will result in large input arguments to the workers. + chuck_size=128 + + #for large num_chain_combinations it should be possible to divide the chain_maps in chunks + result_this_mappings=progress_map(run_chain_map,chain_maps, total=num_chain_combinations,n_cpu=args.n_cpu, chunk_size=chunk_size) + #get a fresh iterator + chain_maps=get_all_chain_maps(chain_clusters,initial_mapping,reverse_map,model_chains_to_combo,native_chains_to_combo) + for chain_map,result_this_mapping in zip(chain_maps,result_this_mappings): + total_dockq = sum( + [result["DockQ_F1" if args.optDockQF1 else "DockQ"] for result in result_this_mapping.values()] + ) + if total_dockq > best_dockq: + best_dockq = total_dockq + best_result = result_this_mapping + best_mapping = chain_map + + - low_memory=len(chain_maps) > 100 - chain_map_args=[(model_structure,native_structure,chain_map,args.no_align,args.use_CA,args.capri_peptide,low_memory) for chain_map in chain_maps] - if len(chain_maps)>1: - chunk_size=max(1,len(chain_maps) // args.n_cpu) - result_this_mappings=progress_map(run_on_all_native_interfaces_multi,chain_map_args, n_cpu=args.n_cpu, chunk_size=chunk_size) else: #skip multi-threading for single jobs (skip the bar basically) - result_this_mappings=[run_on_all_native_interfaces(*chain_map_arg) for chain_map_arg in chain_map_args] - for chain_map,result_this_mapping in zip(chain_maps,result_this_mappings): - total_dockq = sum( - [result["DockQ_F1" if args.optDockQF1 else "DockQ"] for result in result_this_mapping.values()] - ) - if total_dockq > best_dockq: - best_dockq = total_dockq - best_result = result_this_mapping - best_mapping = chain_map + # result_this_mappings=[run_chain_map(chain_map) for chain_map in chain_maps] + for chain_maps in chain_maps: + result_this_mapping=run_chain_map(chain_map) + total_dockq = sum( + [result["DockQ_F1" if args.optDockQF1 else "DockQ"] for result in result_this_mapping.values()] + ) + if total_dockq > best_dockq: + best_dockq = total_dockq + best_result = result_this_mapping + best_mapping = chain_map + + if low_memory: #retrieve the full output by reruning the best chain mapping best_result=run_on_all_native_interfaces( model_structure, @@ -1010,7 +1023,7 @@ def print_results(info, short=False, verbose=False, capri_peptide=False): print( f"Total DockQ over {len(info['best_result'])} native interfaces: {info['GlobalDockQ']:.3f} with {info['best_mapping_str']} model:native mapping" ) - # print(info["best_result"]) + print(info["best_result"]) for chains, results in info["best_result"].items(): print( f"DockQ{capri_peptide_str} {results['DockQ']:.3f} DockQ_F1 {results['DockQ_F1']:.3f} Fnat {results['fnat']:.3f} iRMS {results['irms']:.3f} LRMS {results['Lrms']:.3f} Fnonnat {results['fnonnat']:.3f} clashes {results['clashes']} mapping {results['chain1']}{results['chain2']}:{chains[0]}{chains[1]} {info['model']} {results['chain1']} {results['chain2']} -> {info['native']} {chains[0]} {chains[1]}"