Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Force ordering of atoms according to backbone list #50

Merged
merged 5 commits into from
Aug 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 104 additions & 109 deletions src/DockQ/DockQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def parse_args():
"--short", default=False, action="store_true", help="Short output"
)
parser.add_argument(
"--json", default=None, metavar="out.json", help="Write outputs to a chosen json file"
"--json",
default=None,
metavar="out.json",
help="Write outputs to a chosen json file",
)
parser.add_argument(
"--verbose", "-v", default=False, action="store_true", help="Verbose output"
Expand Down Expand Up @@ -199,12 +202,13 @@ def calc_sym_corrected_lrmsd(
aligned_sample_receptor, aligned_ref_receptor = get_aligned_residues(
sample_receptor, ref_receptor, receptor_alignment
)
# get a copy of each structure, then only keep backbone atoms
sample_interface_atoms, ref_interface_atoms = get_interface_atoms(
(receptor_interface, ()),
(aligned_sample_receptor, ()),
(aligned_ref_receptor, ()),

sample_interface_atoms, ref_interface_atoms = subset_atoms(
aligned_sample_receptor,
aligned_ref_receptor,
atom_types=BACKBONE_ATOMS,
residue_subset=receptor_interface,
what="receptor",
)

sample_ligand_atoms_ids = [atom.id for atom in sample_ligand.get_atoms()]
Expand All @@ -230,7 +234,9 @@ def calc_sym_corrected_lrmsd(

# Set to align on receptor interface
super_imposer = SVDSuperimposer()
super_imposer.set(ref_interface_atoms, sample_interface_atoms)
super_imposer.set(
np.asarray(ref_interface_atoms), np.asarray(sample_interface_atoms)
)
super_imposer.run()
rot, tran = super_imposer.get_rotran()

Expand Down Expand Up @@ -323,13 +329,25 @@ def calc_DockQ(
ref_res_distances,
threshold=interface_threshold ** 2,
)
# get a copy of each structure, then only keep backbone atoms
sample_interface_atoms, ref_interface_atoms = get_interface_atoms(
interacting_pairs,
(aligned_sample_1, aligned_sample_2),
(aligned_ref_1, aligned_ref_2),

sample_interface_atoms1, ref_interface_atoms1 = subset_atoms(
aligned_sample_1,
aligned_ref_1,
atom_types=BACKBONE_ATOMS,
residue_subset=interacting_pairs[0],
)
sample_interface_atoms2, ref_interface_atoms2 = subset_atoms(
aligned_sample_2,
aligned_ref_2,
atom_types=BACKBONE_ATOMS,
residue_subset=interacting_pairs[1],
)

sample_interface_atoms = np.asarray(
sample_interface_atoms1 + sample_interface_atoms2
)
ref_interface_atoms = np.asarray(ref_interface_atoms1 + ref_interface_atoms2)

super_imposer = SVDSuperimposer()
super_imposer.set(sample_interface_atoms, ref_interface_atoms)
super_imposer.run()
Expand All @@ -354,23 +372,26 @@ def calc_DockQ(
else ("ligand", "receptor")
)

receptor_atoms_native, receptor_atoms_sample = np.asarray(
get_atoms_per_residue(
receptor_chains, what="receptor", atom_types=BACKBONE_ATOMS
)
receptor_atoms_native, receptor_atoms_sample = subset_atoms(
receptor_chains[0],
receptor_chains[1],
atom_types=BACKBONE_ATOMS,
what="receptor",
)
ligand_atoms_native, ligand_atoms_sample = np.asarray(
get_atoms_per_residue(ligand_chains, what="ligand", atom_types=BACKBONE_ATOMS)
ligand_atoms_native, ligand_atoms_sample = subset_atoms(
ligand_chains[0], ligand_chains[1], atom_types=BACKBONE_ATOMS, what="ligand"
)
# Set to align on receptor
super_imposer.set(receptor_atoms_native, receptor_atoms_sample)
super_imposer.set(
np.asarray(receptor_atoms_native), np.asarray(receptor_atoms_sample)
)
super_imposer.run()

rot, tran = super_imposer.get_rotran()
rotated_sample_atoms = np.dot(ligand_atoms_sample, rot) + tran
rotated_sample_atoms = np.dot(np.asarray(ligand_atoms_sample), rot) + tran

lrms = super_imposer._rms(
ligand_atoms_native, rotated_sample_atoms
np.asarray(ligand_atoms_native), rotated_sample_atoms
) # using the private _rms function which does not superimpose

info = {}
Expand Down Expand Up @@ -484,86 +505,44 @@ def list_atoms_per_residue(chain, what):
return np.array(n_atoms_per_residue).astype(int)


@lru_cache
def get_atoms_per_residue(
chains,
what,
atom_types=("CA", "C", "N", "O", "P"),
):
chain1, chain2 = chains
atoms1 = [
atom.coord
for res1, res2 in zip(chain1, chain2)
for atom in res1.get_atoms()
if atom.id in atom_types and atom.id in [a.id for a in res2.get_atoms()]
]

atoms2 = [
atom.coord
for res1, res2 in zip(chain1, chain2)
for atom in res2.get_atoms()
if atom.id in atom_types and atom.id in [a.id for a in res1.get_atoms()]
]
return atoms1, atoms2


def get_interacting_pairs(distances, threshold):
interacting_pairs = np.nonzero(np.asarray(distances) < threshold)
return tuple(interacting_pairs[0]), tuple(interacting_pairs[1])


@lru_cache
def get_interface_atoms(
interacting_pairs,
model_chains,
ref_chains,
atom_types=[],
def subset_atoms(
mod_chain,
ref_chain,
atom_types,
residue_subset=None,
what="",
):
ref_interface = []
mod_interface = []
mod_atoms = []
ref_atoms = []

ref_residues_group1 = [res for res in ref_chains[0]]
ref_residues_group2 = [res for res in ref_chains[1]]
mod_residues = [res for res in mod_chain]
ref_residues = [res for res in ref_chain]

mod_residues_group1 = [res for res in model_chains[0]]
mod_residues_group2 = [res for res in model_chains[1]]
# remove duplicate residues
interface_residues_group1 = set(interacting_pairs[0])
interface_residues_group2 = set(interacting_pairs[1])

for i in interface_residues_group1:
ref_atoms = [atom for atom in ref_residues_group1[i].get_atoms()]
mod_atoms = [atom for atom in mod_residues_group1[i].get_atoms()]
ref_atoms_ids = [atom.id for atom in ref_atoms]
mod_atoms_ids = [atom.id for atom in mod_atoms]
ref_interface += [
atom.coord
for atom in ref_atoms
if atom.id in atom_types and atom.id in mod_atoms_ids
]
mod_interface += [
atom.coord
for atom in mod_atoms
if atom.id in atom_types and atom.id in ref_atoms_ids
]
residue_subset = set(residue_subset) if residue_subset else range(len(mod_residues))

for j in interface_residues_group2:
ref_atoms = [atom for atom in ref_residues_group2[j].get_atoms()]
mod_atoms = [atom for atom in mod_residues_group2[j].get_atoms()]
ref_atoms_ids = [atom.id for atom in ref_atoms]
mod_atoms_ids = [atom.id for atom in mod_atoms]
ref_interface += [
atom.coord
for atom in ref_atoms
if atom.id in atom_types and atom.id in mod_atoms_ids
]
mod_interface += [
atom.coord
for atom in mod_atoms
if atom.id in atom_types and atom.id in ref_atoms_ids
]
for i in residue_subset:
mod_res_atoms = list(mod_residues[i].get_atoms())
ref_res_atoms = list(ref_residues[i].get_atoms())
mod_res_atoms_ids = [atom.id for atom in mod_res_atoms]
ref_res_atoms_ids = [atom.id for atom in ref_res_atoms]

return np.asarray(mod_interface), np.asarray(ref_interface)
for atom_type in atom_types:
try:
mod_i = mod_res_atoms_ids.index(atom_type)
ref_i = ref_res_atoms_ids.index(atom_type)
mod_atoms += [mod_res_atoms[mod_i].coord]
ref_atoms += [ref_res_atoms[ref_i].coord]
except:
continue

return mod_atoms, ref_atoms


@lru_cache
Expand Down Expand Up @@ -662,12 +641,7 @@ def run_on_all_native_interfaces(
)
info["chain_map"] = chain_map # diagnostics
result_mapping["".join(chain_pair)] = info
total_dockq = sum(
[
result["DockQ"]
for result in result_mapping.values()
]
)
total_dockq = sum([result["DockQ"] for result in result_mapping.values()])
return result_mapping, total_dockq


Expand Down Expand Up @@ -713,21 +687,25 @@ def group_chains(
try:
qc = query_structure[query_chain]
except KeyError:
logging.error(f"""The specified model chain {query_chain} is not found in the PDB structure.
logging.error(
f"""The specified model chain {query_chain} is not found in the PDB structure.
This is possibly due to using the wrong chain identifier in --mapping,
or forgetting to specify --small_molecule if this is a HETATM chain.
If working with mmCIF files, make sure you use the right chain identifier.
""")
"""
)
print(traceback.format_exc())
sys.exit(1)
try:
rc = ref_structure[ref_chain]
except KeyError:
logging.error(f"""The specified native chain {ref_chain} is not found in the PDB structure.
logging.error(
f"""The specified native chain {ref_chain} is not found in the PDB structure.
This is possibly due to using the wrong chain identifier in --mapping,
or forgetting to specify --small_molecule if this is a HETATM chain.
If working with mmCIF files, make sure you use the right chain identifier.
""")
"""
)
het_qc = qc.is_het
het_rc = rc.is_het

Expand All @@ -753,17 +731,18 @@ def group_chains(
]

if mismatch_dict:
logging.warning(f"""Some chains have a limited number of sequence mismatches and are treated as non-homologous.
logging.warning(
f"""Some chains have a limited number of sequence mismatches and are treated as non-homologous.
Try increasing the --allowed_mismatches for the following: {", ".join(f"Model chain {c[1]}, native chain {c[0]}: {m} mismatches" for c, m in mismatch_dict.items())}
if they should be treated as homologous.""")
if they should be treated as homologous."""
)

if chains_without_match:
logging.error(
f"For chains {chains_without_match} no identical corresponding chain was found between in the native."
)
sys.exit(1)


return chain_clusters, reverse_map


Expand All @@ -776,7 +755,9 @@ def format_mapping(mapping_str, small_molecule=None):

model_mapping, native_mapping = mapping_str.split(":")
if not native_mapping:
logging.error("When using --mapping, native chains must be set (e.g. ABC:ABC or :ABC)")
logging.error(
"When using --mapping, native chains must be set (e.g. ABC:ABC or :ABC)"
)
sys.exit()
else:
# :ABC or *:ABC only use those natives chains, permute model chains
Expand Down Expand Up @@ -834,8 +815,10 @@ def count_chain_combinations(chain_clusters):
]
)
except ValueError:
logging.error("""Couldn't find a match between each model-native chain specified in the mapping.
Make sure that all chains in your model have a homologous chain in the native, or specify the right subset of chains with --mapping""")
logging.error(
"""Couldn't find a match between each model-native chain specified in the mapping.
Make sure that all chains in your model have a homologous chain in the native, or specify the right subset of chains with --mapping"""
)
sys.exit()
return number_of_combinations

Expand Down Expand Up @@ -978,7 +961,9 @@ def main():
best_result, best_dockq = run_chain_map(best_mapping)

if not best_result:
logging.error("Could not find interfaces in the native model. Please double check the inputs or select different chains with the --mapping flag.")
logging.error(
"Could not find interfaces in the native model. Please double check the inputs or select different chains with the --mapping flag."
)
sys.exit(1)

info = dict()
Expand All @@ -989,7 +974,7 @@ def main():
info["GlobalDockQ"] = best_dockq / len(best_result)
info["best_mapping"] = best_mapping
info["best_mapping_str"] = f"{format_mapping_string(best_mapping)}"

if args.json:
with open(args.json, "w") as fp:
json.dump(info, fp)
Expand Down Expand Up @@ -1030,7 +1015,12 @@ def print_results(
)
hetname = f" ({results['is_het']})" if results["is_het"] else ""
score_str = " ".join(
[f"{item} {results[item]:.3f}" if item != "clashes" else f"{item} {results[item]}" for item in reported_measures]
[
f"{item} {results[item]:.3f}"
if item != "clashes"
else f"{item} {results[item]}"
for item in reported_measures
]
)
print(
f"{score_str} mapping {results['chain1']}{results['chain2']}:{chains[0]}{chains[1]}{hetname} {info['model']} {results['chain1']} {results['chain2']} -> {info['native']} {chains[0]} {chains[1]}"
Expand Down Expand Up @@ -1061,7 +1051,12 @@ def print_results(
print(f"\tModel chains: {results['chain1']}, {results['chain2']}")
print(
"\n".join(
[f"\t{item}: {results[item]:.3f}" if item != "clashes" else f"\t{item}: {results[item]}" for item in reported_measures]
[
f"\t{item}: {results[item]:.3f}"
if item != "clashes"
else f"\t{item}: {results[item]}"
for item in reported_measures
]
)
)

Expand Down
Loading