lowmem test
clami66 committed Apr 4, 2024
1 parent 9b6d998 commit b432622
Showing 2 changed files with 66 additions and 47 deletions.
4 changes: 4 additions & 0 deletions run_test.sh
@@ -22,6 +22,10 @@ diff test testdata/dimer_dimer.dockq
$binary examples/model.pdb examples/native.pdb --allowed_mismatches 1 > test
diff test testdata/model.dockq

# lowmem test
$binary examples/1EXB_r_l_b.model.pdb examples/1EXB_r_l_b.pdb --short > test
diff test testdata/1EXB.dockq

# Test various mapping strategies
$binary examples/1EXB_r_l_b.model.pdb examples/1EXB_r_l_b.pdb --short --mapping AB*:BA* > test
diff test testdata/1EXB_AB.BA.dockq
109 changes: 62 additions & 47 deletions src/DockQ/DockQ.py
@@ -161,18 +161,32 @@ def get_residue_distances(chain1, chain2, what, all_atom=True):
return model_res_distances


#@profile
# @profile
def calc_DockQ(
sample_chains,
ref_chains,
alignments,
capri_peptide=False,
low_memory=False,
):
atom_for_sup = ("CA", "C", "N", "O",
"P", "OP1", "OP2",
"O2'", "O3'", "O4'", "O5'",
"C1'", "C2'", "C3'", "C4'", "C5'")
atom_for_sup = (
"CA",
"C",
"N",
"O",
"P",
"OP1",
"OP2",
"O2'",
"O3'",
"O4'",
"O5'",
"C1'",
"C2'",
"C3'",
"C4'",
"C5'",
)
fnat_threshold = 4.0 if capri_peptide else 5.0
interface_threshold = 8.0 if capri_peptide else 10.0
clash_threshold = 2.0
@@ -552,16 +566,15 @@ def run_on_all_native_interfaces(
chain_map[chain_pair[0]],
chain_map[chain_pair[1]],
)
info['chain_map']=chain_map #diagonstics
info["chain_map"] = chain_map # diagonstics
result_mapping[chain_pair] = info
total_dockq = sum(
[
result["DockQ_F1" if optDockQF1 else "DockQ"]
for result in result_mapping.values()
]
)
return result_mapping,total_dockq

[
result["DockQ_F1" if optDockQF1 else "DockQ"]
for result in result_mapping.values()
]
)
return result_mapping, total_dockq


def load_PDB(path, chains=[], n_model=0):
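
For readers skimming the hunk above: run_on_all_native_interfaces ends by summing one score per native chain pair, switching to DockQ_F1 when optDockQF1 is set. A minimal standalone sketch of that aggregation, using a hand-made result_mapping in place of the dictionaries calc_DockQ actually produces (the scores below are invented):

# Sketch only: hand-made stand-ins for the per-interface dicts that
# calc_DockQ returns inside run_on_all_native_interfaces.
result_mapping = {
    ("A", "B"): {"DockQ": 0.50, "DockQ_F1": 0.500},
    ("A", "C"): {"DockQ": 0.25, "DockQ_F1": 0.375},
}

def total_score(result_mapping, optDockQF1=False):
    # Same aggregation as in the hunk above: one score per native chain pair.
    return sum(
        result["DockQ_F1" if optDockQF1 else "DockQ"]
        for result in result_mapping.values()
    )

print(total_score(result_mapping))                   # 0.75
print(total_score(result_mapping, optDockQF1=True))  # 0.875
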
Expand Down Expand Up @@ -600,7 +613,9 @@ def group_chains(
chain_clusters = {chain: [] for chain in ref_chains}

for query_chain, ref_chain in alignment_targets:
aln = align_chains(query_structure[query_chain], ref_structure[ref_chain], use_numbering=None)
aln = align_chains(
query_structure[query_chain], ref_structure[ref_chain], use_numbering=None
)
alignment = format_alignment(aln)
n_mismatches = alignment["matches"].count(".")
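
As context for the check above: group_chains counts a mismatch for every "." in the formatted alignment's match string and later compares that count to --allowed_mismatches. A toy illustration of just the counting step; the match string below and the use of "|" for identical positions are assumptions, since the real string comes from align_chains/format_alignment:

# Hypothetical match string: "|" assumed to mark identical positions,
# "." assumed to mark substitutions, mirroring the count(".") above.
matches = "||||.|||||.|||"
allowed_mismatches = 1

n_mismatches = matches.count(".")  # -> 2
if n_mismatches > allowed_mismatches:
    print(f"too many mismatches: {n_mismatches} > {allowed_mismatches}")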

@@ -683,7 +698,9 @@ def product_without_dupl(*args, repeat=1):

def count_chain_combinations(chain_clusters):
clusters = [tuple(li) for li in chain_clusters.values()]
number_of_combinations = np.prod([math.factorial(a) for a in Counter(clusters).values()])
number_of_combinations = np.prod(
[math.factorial(a) for a in Counter(clusters).values()]
)
return number_of_combinations
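
The reformatted count_chain_combinations multiplies k! for every group of k native chains whose candidate model chains form the same cluster. A small worked example under an assumed chain_clusters layout, where native chains A and B are interchangeable and C is unique, so 2! * 1! = 2 mappings have to be scored:

import math
from collections import Counter

import numpy as np

# Assumed clustering: native A and B both align to model chains A/B,
# while native C only aligns to model chain C.
chain_clusters = {"A": ["A", "B"], "B": ["A", "B"], "C": ["C"]}

clusters = [tuple(li) for li in chain_clusters.values()]
number_of_combinations = np.prod(
    [math.factorial(a) for a in Counter(clusters).values()]
)
print(number_of_combinations)  # 2! * 1! = 2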


@@ -702,7 +719,7 @@ def get_all_chain_maps(
if reverse_map:
chain_map.update(
{
mapping[i]: model_chain
mapping[i]: model_chain
for i, model_chain in enumerate(model_chains_to_combo)
}
)
@@ -715,27 +732,30 @@
)
yield (chain_map)
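
get_all_chain_maps yields one native-to-model chain map per allowed permutation, flipping the direction of the dictionary when reverse_map is set. A deliberately simplified sketch of the same idea built on plain itertools.permutations; it skips the duplicate pruning and mismatch bookkeeping of the real generator, and the chain labels are invented:

import itertools

def simple_chain_maps(native_chains, model_chains, reverse_map=False):
    # One candidate mapping per permutation of the model chains.
    for perm in itertools.permutations(model_chains):
        if reverse_map:
            yield {m: n for n, m in zip(native_chains, perm)}
        else:
            yield {n: m for n, m in zip(native_chains, perm)}

for chain_map in simple_chain_maps(["A", "B"], ["X", "Y"]):
    print(chain_map)  # {'A': 'X', 'B': 'Y'}, then {'A': 'Y', 'B': 'X'}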


def get_chain_map_from_dockq(result):
chain_map={}
for ch1,ch2 in result:
chain_map[ch1]=result[ch1,ch2]['chain1']
chain_map[ch2]=result[ch1,ch2]['chain2']
chain_map = {}
for ch1, ch2 in result:
chain_map[ch1] = result[ch1, ch2]["chain1"]
chain_map[ch2] = result[ch1, ch2]["chain2"]
return chain_map
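
get_chain_map_from_dockq simply flattens a per-pair result dictionary back into a chain-to-chain map. A tiny illustration with a hand-made result; the tuple keys and the "chain1"/"chain2" fields mirror the code above, while the chain labels themselves are invented:

# Hand-made result: two native pairs, each recording which model chains matched.
result = {
    ("A", "B"): {"chain1": "X", "chain2": "Y"},
    ("A", "C"): {"chain1": "X", "chain2": "Z"},
}

chain_map = {}
for ch1, ch2 in result:
    chain_map[ch1] = result[ch1, ch2]["chain1"]
    chain_map[ch2] = result[ch1, ch2]["chain2"]
print(chain_map)  # {'A': 'X', 'B': 'Y', 'C': 'Z'}
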
def get_best_mapping(result_mappings,optDockF1=False):
total_dockq=0


def get_best_mapping(result_mappings, optDockF1=False):
total_dockq = 0
for result_mapping in result_mappings:
total_dockq = sum(
[
result["DockQ_F1" if optDockQF1 else "DockQ"]
for result in result_mapping.values()
]
)
[
result["DockQ_F1" if optDockQF1 else "DockQ"]
for result in result_mapping.values()
]
)
if total_dockq > best_dockq:
best_dockq = total_dockq
return best_result, best_dockq
return best_result, best_dockq


#@profile
# @profile
def main():
args = parse_args()
initial_mapping, model_chains, native_chains = format_mapping(args.mapping)
@@ -771,7 +791,6 @@ def main():
native_chains_to_combo,
args.allowed_mismatches,
)


chain_maps = get_all_chain_maps(
chain_clusters,
@@ -781,13 +800,12 @@ def main():
native_chains_to_combo,
)
num_chain_combinations = count_chain_combinations(chain_clusters)
if num_chain_combinations==1 and not args.mapping: #A HACK count_chain_combinations does not work if there are different number of chains in native.
if (
num_chain_combinations == 1 and not args.mapping
): # A HACK count_chain_combinations does not work if there are different number of chains in native.
chain_maps, chain_maps_ = itertools.tee(chain_maps)
num_chain_combinations=sum(1 for _ in chain_maps_)
#print(num_chain_combinations,chain_clusters)
#print(list(chain_maps))
num_chain_combinations = sum(1 for _ in chain_maps_)

#sys.exit()
# copy iterator to use later
chain_maps, chain_maps2 = itertools.tee(chain_maps)
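
The HACK branch above needs the number of candidate mappings, but chain_maps is a generator, so it is duplicated with itertools.tee and one copy is exhausted just to count. A self-contained sketch of that counting trick, with a dummy generator standing in for get_all_chain_maps:

import itertools

def fake_chain_maps():
    # Dummy stand-in for get_all_chain_maps(): yields three mappings.
    yield {"A": "A"}
    yield {"A": "B"}
    yield {"A": "C"}

chain_maps = fake_chain_maps()
chain_maps, chain_maps_ = itertools.tee(chain_maps)
num_chain_combinations = sum(1 for _ in chain_maps_)  # consumes the copy only
print(num_chain_combinations)  # 3
print(next(chain_maps))        # the kept iterator is still intact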

@@ -802,10 +820,9 @@
)

if num_chain_combinations > 1:
#chunk_size = args.n_chunk
cpus=min(num_chain_combinations,args.n_cpu)
chunk_size=min(args.max_chunk,max(1,num_chain_combinations//cpus))
#print(cpus,chunk_size)
cpus = min(num_chain_combinations, args.n_cpu)
chunk_size = min(args.max_chunk, max(1, num_chain_combinations // cpus))

# for large num_chain_combinations it should be possible to divide the chain_maps in chunks
result_this_mappings = progress_map(
run_chain_map,
@@ -815,19 +832,17 @@ def main():
chunk_size=chunk_size,
)
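
The chunking above caps the per-worker chunk at args.max_chunk and never lets it drop below one mapping per chunk. A quick numeric check under assumed settings (2000 combinations, args.n_cpu = 8, args.max_chunk = 512):

num_chain_combinations = 2000
n_cpu = 8        # assumed value of args.n_cpu
max_chunk = 512  # assumed value of args.max_chunk

cpus = min(num_chain_combinations, n_cpu)                            # 8
chunk_size = min(max_chunk, max(1, num_chain_combinations // cpus))  # min(512, 250) = 250
print(cpus, chunk_size)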

for chain_map, (result_this_mapping,total_dockq) in zip(chain_maps2, result_this_mappings):
#print(chain_map,result_this_mapping)
for chain_map, (result_this_mapping, total_dockq) in zip(
chain_maps2, result_this_mappings
):

if total_dockq > best_dockq:
best_dockq = total_dockq
best_result = result_this_mapping
#best_mapping2=get_chain_map_from_dockq(best_result)
best_mapping = chain_map
#print(best_result)
#print(f"{format_mapping_string(best_mapping)}")
#print(f"{format_mapping_string(best_mapping2)}")

if low_memory: # retrieve the full output by rerunning the best chain mapping
best_result,total_dockq = run_on_all_native_interfaces(
best_result, total_dockq = run_on_all_native_interfaces(
model_structure,
native_structure,
best_mapping,
@@ -838,7 +853,7 @@

else: # skip multi-threading for single jobs (skip the bar basically)
best_mapping = next(chain_maps)
best_result,best_dockq = run_chain_map(best_mapping)
best_result, best_dockq = run_chain_map(best_mapping)
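
Putting the new flag together: during the mapping search every candidate is scored with low_memory=True, and only the winning mapping is rerun with low_memory=False to recover the full per-interface output. A schematic sketch of that two-pass pattern, with a dummy scorer standing in for run_on_all_native_interfaces and invented scores:

fake_scores = {("A", "B"): 0.34, ("B", "A"): 0.61}  # invented totals, one per mapping

def score_mapping(chain_map, low_memory):
    # Dummy stand-in for run_on_all_native_interfaces(model, native, chain_map, ...):
    # it returns (result_mapping, total_dockq); details are dropped in low-memory mode.
    total_dockq = fake_scores[tuple(chain_map.values())]
    result = None if low_memory else {"per_interface": "full DockQ breakdown here"}
    return result, total_dockq

chain_maps = [{"A": "A", "B": "B"}, {"A": "B", "B": "A"}]

# Pass 1: score every candidate mapping cheaply.
best_dockq, best_mapping = -1.0, None
for chain_map in chain_maps:
    _, total_dockq = score_mapping(chain_map, low_memory=True)
    if total_dockq > best_dockq:
        best_dockq, best_mapping = total_dockq, chain_map

# Pass 2: rerun only the winner to retrieve the full per-interface output.
best_result, best_dockq = score_mapping(best_mapping, low_memory=False)
print(best_mapping, best_dockq, best_result)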

info = dict()
info["model"] = args.model
