Skip to content

Commit

Permalink
blacken
Browse files Browse the repository at this point in the history
  • Loading branch information
clami66 committed Aug 28, 2024
1 parent 0243c72 commit 9c3afe3
Showing 1 changed file with 80 additions and 33 deletions.
113 changes: 80 additions & 33 deletions src/DockQ/DockQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def parse_args():
"--short", default=False, action="store_true", help="Short output"
)
parser.add_argument(
"--json", default=None, metavar="out.json", help="Write outputs to a chosen json file"
"--json",
default=None,
metavar="out.json",
help="Write outputs to a chosen json file",
)
parser.add_argument(
"--verbose", "-v", default=False, action="store_true", help="Verbose output"
Expand Down Expand Up @@ -200,7 +203,13 @@ def calc_sym_corrected_lrmsd(
sample_receptor, ref_receptor, receptor_alignment
)

sample_interface_atoms, ref_interface_atoms = subset_atoms(aligned_sample_receptor, aligned_ref_receptor, atom_types=BACKBONE_ATOMS, residue_subset=receptor_interface)
sample_interface_atoms, ref_interface_atoms = subset_atoms(
aligned_sample_receptor,
aligned_ref_receptor,
atom_types=BACKBONE_ATOMS,
residue_subset=receptor_interface,
what="receptor",
)

sample_ligand_atoms_ids = [atom.id for atom in sample_ligand.get_atoms()]
sample_ligand_atoms_ele = [atom.element for atom in sample_ligand.get_atoms()]
Expand All @@ -225,7 +234,9 @@ def calc_sym_corrected_lrmsd(

# Set to align on receptor interface
super_imposer = SVDSuperimposer()
super_imposer.set(np.asarray(ref_interface_atoms), np.asarray(sample_interface_atoms))
super_imposer.set(
np.asarray(ref_interface_atoms), np.asarray(sample_interface_atoms)
)
super_imposer.run()
rot, tran = super_imposer.get_rotran()

Expand Down Expand Up @@ -318,11 +329,23 @@ def calc_DockQ(
ref_res_distances,
threshold=interface_threshold ** 2,
)

sample_interface_atoms1, ref_interface_atoms1 = subset_atoms(aligned_sample_1, aligned_ref_1, atom_types=BACKBONE_ATOMS, residue_subset=interacting_pairs[0])
sample_interface_atoms2, ref_interface_atoms2 = subset_atoms(aligned_sample_2, aligned_ref_2, atom_types=BACKBONE_ATOMS, residue_subset=interacting_pairs[1])

sample_interface_atoms = np.asarray(sample_interface_atoms1 + sample_interface_atoms2)

sample_interface_atoms1, ref_interface_atoms1 = subset_atoms(
aligned_sample_1,
aligned_ref_1,
atom_types=BACKBONE_ATOMS,
residue_subset=interacting_pairs[0],
)
sample_interface_atoms2, ref_interface_atoms2 = subset_atoms(
aligned_sample_2,
aligned_ref_2,
atom_types=BACKBONE_ATOMS,
residue_subset=interacting_pairs[1],
)

sample_interface_atoms = np.asarray(
sample_interface_atoms1 + sample_interface_atoms2
)
ref_interface_atoms = np.asarray(ref_interface_atoms1 + ref_interface_atoms2)

super_imposer = SVDSuperimposer()
Expand Down Expand Up @@ -350,11 +373,18 @@ def calc_DockQ(
)

receptor_atoms_native, receptor_atoms_sample = subset_atoms(
receptor_chains[0], receptor_chains[1], atom_types=BACKBONE_ATOMS
)
ligand_atoms_native, ligand_atoms_sample = subset_atoms(ligand_chains[0], ligand_chains[1], atom_types=BACKBONE_ATOMS)
receptor_chains[0],
receptor_chains[1],
atom_types=BACKBONE_ATOMS,
what="receptor",
)
ligand_atoms_native, ligand_atoms_sample = subset_atoms(
ligand_chains[0], ligand_chains[1], atom_types=BACKBONE_ATOMS, what="ligand"
)
# Set to align on receptor
super_imposer.set(np.asarray(receptor_atoms_native), np.asarray(receptor_atoms_sample))
super_imposer.set(
np.asarray(receptor_atoms_native), np.asarray(receptor_atoms_sample)
)
super_imposer.run()

rot, tran = super_imposer.get_rotran()
Expand Down Expand Up @@ -486,13 +516,14 @@ def subset_atoms(
ref_chain,
atom_types,
residue_subset=None,
what="",
):
mod_atoms = []
ref_atoms = []

mod_residues = [res for res in mod_chain]
ref_residues = [res for res in ref_chain]

# remove duplicate residues
residue_subset = set(residue_subset) if residue_subset else range(len(mod_residues))

Expand Down Expand Up @@ -610,12 +641,7 @@ def run_on_all_native_interfaces(
)
info["chain_map"] = chain_map # diagnostics
result_mapping["".join(chain_pair)] = info
total_dockq = sum(
[
result["DockQ"]
for result in result_mapping.values()
]
)
total_dockq = sum([result["DockQ"] for result in result_mapping.values()])
return result_mapping, total_dockq


Expand Down Expand Up @@ -661,21 +687,25 @@ def group_chains(
try:
qc = query_structure[query_chain]
except KeyError:
logging.error(f"""The specified model chain {query_chain} is not found in the PDB structure.
logging.error(
f"""The specified model chain {query_chain} is not found in the PDB structure.
This is possibly due to using the wrong chain identifier in --mapping,
or forgetting to specify --small_molecule if this is a HETATM chain.
If working with mmCIF files, make sure you use the right chain identifier.
""")
"""
)
print(traceback.format_exc())
sys.exit(1)
try:
rc = ref_structure[ref_chain]
except KeyError:
logging.error(f"""The specified native chain {ref_chain} is not found in the PDB structure.
logging.error(
f"""The specified native chain {ref_chain} is not found in the PDB structure.
This is possibly due to using the wrong chain identifier in --mapping,
or forgetting to specify --small_molecule if this is a HETATM chain.
If working with mmCIF files, make sure you use the right chain identifier.
""")
"""
)
het_qc = qc.is_het
het_rc = rc.is_het

Expand All @@ -701,17 +731,18 @@ def group_chains(
]

if mismatch_dict:
logging.warning(f"""Some chains have a limited number of sequence mismatches and are treated as non-homologous.
logging.warning(
f"""Some chains have a limited number of sequence mismatches and are treated as non-homologous.
Try increasing the --allowed_mismatches for the following: {", ".join(f"Model chain {c[1]}, native chain {c[0]}: {m} mismatches" for c, m in mismatch_dict.items())}
if they should be treated as homologous.""")
if they should be treated as homologous."""
)

if chains_without_match:
logging.error(
f"For chains {chains_without_match} no identical corresponding chain was found between in the native."
)
sys.exit(1)


return chain_clusters, reverse_map


Expand All @@ -724,7 +755,9 @@ def format_mapping(mapping_str, small_molecule=None):

model_mapping, native_mapping = mapping_str.split(":")
if not native_mapping:
logging.error("When using --mapping, native chains must be set (e.g. ABC:ABC or :ABC)")
logging.error(
"When using --mapping, native chains must be set (e.g. ABC:ABC or :ABC)"
)
sys.exit()
else:
# :ABC or *:ABC only use those natives chains, permute model chains
Expand Down Expand Up @@ -782,8 +815,10 @@ def count_chain_combinations(chain_clusters):
]
)
except ValueError:
logging.error("""Couldn't find a match between each model-native chain specified in the mapping.
Make sure that all chains in your model have a homologous chain in the native, or specify the right subset of chains with --mapping""")
logging.error(
"""Couldn't find a match between each model-native chain specified in the mapping.
Make sure that all chains in your model have a homologous chain in the native, or specify the right subset of chains with --mapping"""
)
sys.exit()
return number_of_combinations

Expand Down Expand Up @@ -926,7 +961,9 @@ def main():
best_result, best_dockq = run_chain_map(best_mapping)

if not best_result:
logging.error("Could not find interfaces in the native model. Please double check the inputs or select different chains with the --mapping flag.")
logging.error(
"Could not find interfaces in the native model. Please double check the inputs or select different chains with the --mapping flag."
)
sys.exit(1)

info = dict()
Expand All @@ -937,7 +974,7 @@ def main():
info["GlobalDockQ"] = best_dockq / len(best_result)
info["best_mapping"] = best_mapping
info["best_mapping_str"] = f"{format_mapping_string(best_mapping)}"

if args.json:
with open(args.json, "w") as fp:
json.dump(info, fp)
Expand Down Expand Up @@ -978,7 +1015,12 @@ def print_results(
)
hetname = f" ({results['is_het']})" if results["is_het"] else ""
score_str = " ".join(
[f"{item} {results[item]:.3f}" if item != "clashes" else f"{item} {results[item]}" for item in reported_measures]
[
f"{item} {results[item]:.3f}"
if item != "clashes"
else f"{item} {results[item]}"
for item in reported_measures
]
)
print(
f"{score_str} mapping {results['chain1']}{results['chain2']}:{chains[0]}{chains[1]}{hetname} {info['model']} {results['chain1']} {results['chain2']} -> {info['native']} {chains[0]} {chains[1]}"
Expand Down Expand Up @@ -1009,7 +1051,12 @@ def print_results(
print(f"\tModel chains: {results['chain1']}, {results['chain2']}")
print(
"\n".join(
[f"\t{item}: {results[item]:.3f}" if item != "clashes" else f"\t{item}: {results[item]}" for item in reported_measures]
[
f"\t{item}: {results[item]:.3f}"
if item != "clashes"
else f"\t{item}: {results[item]}"
for item in reported_measures
]
)
)

Expand Down

0 comments on commit 9c3afe3

Please sign in to comment.