diff --git a/src/DockQ/DockQ.py b/src/DockQ/DockQ.py index de47f54..d48f41f 100755 --- a/src/DockQ/DockQ.py +++ b/src/DockQ/DockQ.py @@ -55,7 +55,10 @@ def parse_args(): "--short", default=False, action="store_true", help="Short output" ) parser.add_argument( - "--json", default=None, metavar="out.json", help="Write outputs to a chosen json file" + "--json", + default=None, + metavar="out.json", + help="Write outputs to a chosen json file", ) parser.add_argument( "--verbose", "-v", default=False, action="store_true", help="Verbose output" @@ -199,12 +202,13 @@ def calc_sym_corrected_lrmsd( aligned_sample_receptor, aligned_ref_receptor = get_aligned_residues( sample_receptor, ref_receptor, receptor_alignment ) - # get a copy of each structure, then only keep backbone atoms - sample_interface_atoms, ref_interface_atoms = get_interface_atoms( - (receptor_interface, ()), - (aligned_sample_receptor, ()), - (aligned_ref_receptor, ()), + + sample_interface_atoms, ref_interface_atoms = subset_atoms( + aligned_sample_receptor, + aligned_ref_receptor, atom_types=BACKBONE_ATOMS, + residue_subset=receptor_interface, + what="receptor", ) sample_ligand_atoms_ids = [atom.id for atom in sample_ligand.get_atoms()] @@ -230,7 +234,9 @@ def calc_sym_corrected_lrmsd( # Set to align on receptor interface super_imposer = SVDSuperimposer() - super_imposer.set(ref_interface_atoms, sample_interface_atoms) + super_imposer.set( + np.asarray(ref_interface_atoms), np.asarray(sample_interface_atoms) + ) super_imposer.run() rot, tran = super_imposer.get_rotran() @@ -323,13 +329,25 @@ def calc_DockQ( ref_res_distances, threshold=interface_threshold ** 2, ) - # get a copy of each structure, then only keep backbone atoms - sample_interface_atoms, ref_interface_atoms = get_interface_atoms( - interacting_pairs, - (aligned_sample_1, aligned_sample_2), - (aligned_ref_1, aligned_ref_2), + + sample_interface_atoms1, ref_interface_atoms1 = subset_atoms( + aligned_sample_1, + aligned_ref_1, + atom_types=BACKBONE_ATOMS, + residue_subset=interacting_pairs[0], + ) + sample_interface_atoms2, ref_interface_atoms2 = subset_atoms( + aligned_sample_2, + aligned_ref_2, atom_types=BACKBONE_ATOMS, + residue_subset=interacting_pairs[1], ) + + sample_interface_atoms = np.asarray( + sample_interface_atoms1 + sample_interface_atoms2 + ) + ref_interface_atoms = np.asarray(ref_interface_atoms1 + ref_interface_atoms2) + super_imposer = SVDSuperimposer() super_imposer.set(sample_interface_atoms, ref_interface_atoms) super_imposer.run() @@ -354,23 +372,26 @@ def calc_DockQ( else ("ligand", "receptor") ) - receptor_atoms_native, receptor_atoms_sample = np.asarray( - get_atoms_per_residue( - receptor_chains, what="receptor", atom_types=BACKBONE_ATOMS - ) + receptor_atoms_native, receptor_atoms_sample = subset_atoms( + receptor_chains[0], + receptor_chains[1], + atom_types=BACKBONE_ATOMS, + what="receptor", ) - ligand_atoms_native, ligand_atoms_sample = np.asarray( - get_atoms_per_residue(ligand_chains, what="ligand", atom_types=BACKBONE_ATOMS) + ligand_atoms_native, ligand_atoms_sample = subset_atoms( + ligand_chains[0], ligand_chains[1], atom_types=BACKBONE_ATOMS, what="ligand" ) # Set to align on receptor - super_imposer.set(receptor_atoms_native, receptor_atoms_sample) + super_imposer.set( + np.asarray(receptor_atoms_native), np.asarray(receptor_atoms_sample) + ) super_imposer.run() rot, tran = super_imposer.get_rotran() - rotated_sample_atoms = np.dot(ligand_atoms_sample, rot) + tran + rotated_sample_atoms = np.dot(np.asarray(ligand_atoms_sample), rot) + tran lrms = super_imposer._rms( - ligand_atoms_native, rotated_sample_atoms + np.asarray(ligand_atoms_native), rotated_sample_atoms ) # using the private _rms function which does not superimpose info = {} @@ -484,86 +505,44 @@ def list_atoms_per_residue(chain, what): return np.array(n_atoms_per_residue).astype(int) -@lru_cache -def get_atoms_per_residue( - chains, - what, - atom_types=("CA", "C", "N", "O", "P"), -): - chain1, chain2 = chains - atoms1 = [ - atom.coord - for res1, res2 in zip(chain1, chain2) - for atom in res1.get_atoms() - if atom.id in atom_types and atom.id in [a.id for a in res2.get_atoms()] - ] - - atoms2 = [ - atom.coord - for res1, res2 in zip(chain1, chain2) - for atom in res2.get_atoms() - if atom.id in atom_types and atom.id in [a.id for a in res1.get_atoms()] - ] - return atoms1, atoms2 - - def get_interacting_pairs(distances, threshold): interacting_pairs = np.nonzero(np.asarray(distances) < threshold) return tuple(interacting_pairs[0]), tuple(interacting_pairs[1]) @lru_cache -def get_interface_atoms( - interacting_pairs, - model_chains, - ref_chains, - atom_types=[], +def subset_atoms( + mod_chain, + ref_chain, + atom_types, + residue_subset=None, + what="", ): - ref_interface = [] - mod_interface = [] + mod_atoms = [] + ref_atoms = [] - ref_residues_group1 = [res for res in ref_chains[0]] - ref_residues_group2 = [res for res in ref_chains[1]] + mod_residues = [res for res in mod_chain] + ref_residues = [res for res in ref_chain] - mod_residues_group1 = [res for res in model_chains[0]] - mod_residues_group2 = [res for res in model_chains[1]] # remove duplicate residues - interface_residues_group1 = set(interacting_pairs[0]) - interface_residues_group2 = set(interacting_pairs[1]) - - for i in interface_residues_group1: - ref_atoms = [atom for atom in ref_residues_group1[i].get_atoms()] - mod_atoms = [atom for atom in mod_residues_group1[i].get_atoms()] - ref_atoms_ids = [atom.id for atom in ref_atoms] - mod_atoms_ids = [atom.id for atom in mod_atoms] - ref_interface += [ - atom.coord - for atom in ref_atoms - if atom.id in atom_types and atom.id in mod_atoms_ids - ] - mod_interface += [ - atom.coord - for atom in mod_atoms - if atom.id in atom_types and atom.id in ref_atoms_ids - ] + residue_subset = set(residue_subset) if residue_subset else range(len(mod_residues)) - for j in interface_residues_group2: - ref_atoms = [atom for atom in ref_residues_group2[j].get_atoms()] - mod_atoms = [atom for atom in mod_residues_group2[j].get_atoms()] - ref_atoms_ids = [atom.id for atom in ref_atoms] - mod_atoms_ids = [atom.id for atom in mod_atoms] - ref_interface += [ - atom.coord - for atom in ref_atoms - if atom.id in atom_types and atom.id in mod_atoms_ids - ] - mod_interface += [ - atom.coord - for atom in mod_atoms - if atom.id in atom_types and atom.id in ref_atoms_ids - ] + for i in residue_subset: + mod_res_atoms = list(mod_residues[i].get_atoms()) + ref_res_atoms = list(ref_residues[i].get_atoms()) + mod_res_atoms_ids = [atom.id for atom in mod_res_atoms] + ref_res_atoms_ids = [atom.id for atom in ref_res_atoms] - return np.asarray(mod_interface), np.asarray(ref_interface) + for atom_type in atom_types: + try: + mod_i = mod_res_atoms_ids.index(atom_type) + ref_i = ref_res_atoms_ids.index(atom_type) + mod_atoms += [mod_res_atoms[mod_i].coord] + ref_atoms += [ref_res_atoms[ref_i].coord] + except: + continue + + return mod_atoms, ref_atoms @lru_cache @@ -662,12 +641,7 @@ def run_on_all_native_interfaces( ) info["chain_map"] = chain_map # diagnostics result_mapping["".join(chain_pair)] = info - total_dockq = sum( - [ - result["DockQ"] - for result in result_mapping.values() - ] - ) + total_dockq = sum([result["DockQ"] for result in result_mapping.values()]) return result_mapping, total_dockq @@ -713,21 +687,25 @@ def group_chains( try: qc = query_structure[query_chain] except KeyError: - logging.error(f"""The specified model chain {query_chain} is not found in the PDB structure. + logging.error( + f"""The specified model chain {query_chain} is not found in the PDB structure. This is possibly due to using the wrong chain identifier in --mapping, or forgetting to specify --small_molecule if this is a HETATM chain. If working with mmCIF files, make sure you use the right chain identifier. - """) + """ + ) print(traceback.format_exc()) sys.exit(1) try: rc = ref_structure[ref_chain] except KeyError: - logging.error(f"""The specified native chain {ref_chain} is not found in the PDB structure. + logging.error( + f"""The specified native chain {ref_chain} is not found in the PDB structure. This is possibly due to using the wrong chain identifier in --mapping, or forgetting to specify --small_molecule if this is a HETATM chain. If working with mmCIF files, make sure you use the right chain identifier. - """) + """ + ) het_qc = qc.is_het het_rc = rc.is_het @@ -753,9 +731,11 @@ def group_chains( ] if mismatch_dict: - logging.warning(f"""Some chains have a limited number of sequence mismatches and are treated as non-homologous. + logging.warning( + f"""Some chains have a limited number of sequence mismatches and are treated as non-homologous. Try increasing the --allowed_mismatches for the following: {", ".join(f"Model chain {c[1]}, native chain {c[0]}: {m} mismatches" for c, m in mismatch_dict.items())} -if they should be treated as homologous.""") +if they should be treated as homologous.""" + ) if chains_without_match: logging.error( @@ -763,7 +743,6 @@ def group_chains( ) sys.exit(1) - return chain_clusters, reverse_map @@ -776,7 +755,9 @@ def format_mapping(mapping_str, small_molecule=None): model_mapping, native_mapping = mapping_str.split(":") if not native_mapping: - logging.error("When using --mapping, native chains must be set (e.g. ABC:ABC or :ABC)") + logging.error( + "When using --mapping, native chains must be set (e.g. ABC:ABC or :ABC)" + ) sys.exit() else: # :ABC or *:ABC only use those natives chains, permute model chains @@ -834,8 +815,10 @@ def count_chain_combinations(chain_clusters): ] ) except ValueError: - logging.error("""Couldn't find a match between each model-native chain specified in the mapping. -Make sure that all chains in your model have a homologous chain in the native, or specify the right subset of chains with --mapping""") + logging.error( + """Couldn't find a match between each model-native chain specified in the mapping. +Make sure that all chains in your model have a homologous chain in the native, or specify the right subset of chains with --mapping""" + ) sys.exit() return number_of_combinations @@ -978,7 +961,9 @@ def main(): best_result, best_dockq = run_chain_map(best_mapping) if not best_result: - logging.error("Could not find interfaces in the native model. Please double check the inputs or select different chains with the --mapping flag.") + logging.error( + "Could not find interfaces in the native model. Please double check the inputs or select different chains with the --mapping flag." + ) sys.exit(1) info = dict() @@ -989,7 +974,7 @@ def main(): info["GlobalDockQ"] = best_dockq / len(best_result) info["best_mapping"] = best_mapping info["best_mapping_str"] = f"{format_mapping_string(best_mapping)}" - + if args.json: with open(args.json, "w") as fp: json.dump(info, fp) @@ -1030,7 +1015,12 @@ def print_results( ) hetname = f" ({results['is_het']})" if results["is_het"] else "" score_str = " ".join( - [f"{item} {results[item]:.3f}" if item != "clashes" else f"{item} {results[item]}" for item in reported_measures] + [ + f"{item} {results[item]:.3f}" + if item != "clashes" + else f"{item} {results[item]}" + for item in reported_measures + ] ) print( f"{score_str} mapping {results['chain1']}{results['chain2']}:{chains[0]}{chains[1]}{hetname} {info['model']} {results['chain1']} {results['chain2']} -> {info['native']} {chains[0]} {chains[1]}" @@ -1061,7 +1051,12 @@ def print_results( print(f"\tModel chains: {results['chain1']}, {results['chain2']}") print( "\n".join( - [f"\t{item}: {results[item]:.3f}" if item != "clashes" else f"\t{item}: {results[item]}" for item in reported_measures] + [ + f"\t{item}: {results[item]:.3f}" + if item != "clashes" + else f"\t{item}: {results[item]}" + for item in reported_measures + ] ) )