diff --git a/src/DockQ/DockQ.py b/src/DockQ/DockQ.py index 6f9dca9..d48f41f 100755 --- a/src/DockQ/DockQ.py +++ b/src/DockQ/DockQ.py @@ -55,7 +55,10 @@ def parse_args(): "--short", default=False, action="store_true", help="Short output" ) parser.add_argument( - "--json", default=None, metavar="out.json", help="Write outputs to a chosen json file" + "--json", + default=None, + metavar="out.json", + help="Write outputs to a chosen json file", ) parser.add_argument( "--verbose", "-v", default=False, action="store_true", help="Verbose output" @@ -200,7 +203,13 @@ def calc_sym_corrected_lrmsd( sample_receptor, ref_receptor, receptor_alignment ) - sample_interface_atoms, ref_interface_atoms = subset_atoms(aligned_sample_receptor, aligned_ref_receptor, atom_types=BACKBONE_ATOMS, residue_subset=receptor_interface) + sample_interface_atoms, ref_interface_atoms = subset_atoms( + aligned_sample_receptor, + aligned_ref_receptor, + atom_types=BACKBONE_ATOMS, + residue_subset=receptor_interface, + what="receptor", + ) sample_ligand_atoms_ids = [atom.id for atom in sample_ligand.get_atoms()] sample_ligand_atoms_ele = [atom.element for atom in sample_ligand.get_atoms()] @@ -225,7 +234,9 @@ def calc_sym_corrected_lrmsd( # Set to align on receptor interface super_imposer = SVDSuperimposer() - super_imposer.set(np.asarray(ref_interface_atoms), np.asarray(sample_interface_atoms)) + super_imposer.set( + np.asarray(ref_interface_atoms), np.asarray(sample_interface_atoms) + ) super_imposer.run() rot, tran = super_imposer.get_rotran() @@ -318,11 +329,23 @@ def calc_DockQ( ref_res_distances, threshold=interface_threshold ** 2, ) - - sample_interface_atoms1, ref_interface_atoms1 = subset_atoms(aligned_sample_1, aligned_ref_1, atom_types=BACKBONE_ATOMS, residue_subset=interacting_pairs[0]) - sample_interface_atoms2, ref_interface_atoms2 = subset_atoms(aligned_sample_2, aligned_ref_2, atom_types=BACKBONE_ATOMS, residue_subset=interacting_pairs[1]) - - sample_interface_atoms = np.asarray(sample_interface_atoms1 + sample_interface_atoms2) + + sample_interface_atoms1, ref_interface_atoms1 = subset_atoms( + aligned_sample_1, + aligned_ref_1, + atom_types=BACKBONE_ATOMS, + residue_subset=interacting_pairs[0], + ) + sample_interface_atoms2, ref_interface_atoms2 = subset_atoms( + aligned_sample_2, + aligned_ref_2, + atom_types=BACKBONE_ATOMS, + residue_subset=interacting_pairs[1], + ) + + sample_interface_atoms = np.asarray( + sample_interface_atoms1 + sample_interface_atoms2 + ) ref_interface_atoms = np.asarray(ref_interface_atoms1 + ref_interface_atoms2) super_imposer = SVDSuperimposer() @@ -350,11 +373,18 @@ def calc_DockQ( ) receptor_atoms_native, receptor_atoms_sample = subset_atoms( - receptor_chains[0], receptor_chains[1], atom_types=BACKBONE_ATOMS - ) - ligand_atoms_native, ligand_atoms_sample = subset_atoms(ligand_chains[0], ligand_chains[1], atom_types=BACKBONE_ATOMS) + receptor_chains[0], + receptor_chains[1], + atom_types=BACKBONE_ATOMS, + what="receptor", + ) + ligand_atoms_native, ligand_atoms_sample = subset_atoms( + ligand_chains[0], ligand_chains[1], atom_types=BACKBONE_ATOMS, what="ligand" + ) # Set to align on receptor - super_imposer.set(np.asarray(receptor_atoms_native), np.asarray(receptor_atoms_sample)) + super_imposer.set( + np.asarray(receptor_atoms_native), np.asarray(receptor_atoms_sample) + ) super_imposer.run() rot, tran = super_imposer.get_rotran() @@ -486,13 +516,14 @@ def subset_atoms( ref_chain, atom_types, residue_subset=None, + what="", ): mod_atoms = [] ref_atoms = [] mod_residues = [res for res in mod_chain] ref_residues = [res for res in ref_chain] - + # remove duplicate residues residue_subset = set(residue_subset) if residue_subset else range(len(mod_residues)) @@ -610,12 +641,7 @@ def run_on_all_native_interfaces( ) info["chain_map"] = chain_map # diagnostics result_mapping["".join(chain_pair)] = info - total_dockq = sum( - [ - result["DockQ"] - for result in result_mapping.values() - ] - ) + total_dockq = sum([result["DockQ"] for result in result_mapping.values()]) return result_mapping, total_dockq @@ -661,21 +687,25 @@ def group_chains( try: qc = query_structure[query_chain] except KeyError: - logging.error(f"""The specified model chain {query_chain} is not found in the PDB structure. + logging.error( + f"""The specified model chain {query_chain} is not found in the PDB structure. This is possibly due to using the wrong chain identifier in --mapping, or forgetting to specify --small_molecule if this is a HETATM chain. If working with mmCIF files, make sure you use the right chain identifier. - """) + """ + ) print(traceback.format_exc()) sys.exit(1) try: rc = ref_structure[ref_chain] except KeyError: - logging.error(f"""The specified native chain {ref_chain} is not found in the PDB structure. + logging.error( + f"""The specified native chain {ref_chain} is not found in the PDB structure. This is possibly due to using the wrong chain identifier in --mapping, or forgetting to specify --small_molecule if this is a HETATM chain. If working with mmCIF files, make sure you use the right chain identifier. - """) + """ + ) het_qc = qc.is_het het_rc = rc.is_het @@ -701,9 +731,11 @@ def group_chains( ] if mismatch_dict: - logging.warning(f"""Some chains have a limited number of sequence mismatches and are treated as non-homologous. + logging.warning( + f"""Some chains have a limited number of sequence mismatches and are treated as non-homologous. Try increasing the --allowed_mismatches for the following: {", ".join(f"Model chain {c[1]}, native chain {c[0]}: {m} mismatches" for c, m in mismatch_dict.items())} -if they should be treated as homologous.""") +if they should be treated as homologous.""" + ) if chains_without_match: logging.error( @@ -711,7 +743,6 @@ def group_chains( ) sys.exit(1) - return chain_clusters, reverse_map @@ -724,7 +755,9 @@ def format_mapping(mapping_str, small_molecule=None): model_mapping, native_mapping = mapping_str.split(":") if not native_mapping: - logging.error("When using --mapping, native chains must be set (e.g. ABC:ABC or :ABC)") + logging.error( + "When using --mapping, native chains must be set (e.g. ABC:ABC or :ABC)" + ) sys.exit() else: # :ABC or *:ABC only use those natives chains, permute model chains @@ -782,8 +815,10 @@ def count_chain_combinations(chain_clusters): ] ) except ValueError: - logging.error("""Couldn't find a match between each model-native chain specified in the mapping. -Make sure that all chains in your model have a homologous chain in the native, or specify the right subset of chains with --mapping""") + logging.error( + """Couldn't find a match between each model-native chain specified in the mapping. +Make sure that all chains in your model have a homologous chain in the native, or specify the right subset of chains with --mapping""" + ) sys.exit() return number_of_combinations @@ -926,7 +961,9 @@ def main(): best_result, best_dockq = run_chain_map(best_mapping) if not best_result: - logging.error("Could not find interfaces in the native model. Please double check the inputs or select different chains with the --mapping flag.") + logging.error( + "Could not find interfaces in the native model. Please double check the inputs or select different chains with the --mapping flag." + ) sys.exit(1) info = dict() @@ -937,7 +974,7 @@ def main(): info["GlobalDockQ"] = best_dockq / len(best_result) info["best_mapping"] = best_mapping info["best_mapping_str"] = f"{format_mapping_string(best_mapping)}" - + if args.json: with open(args.json, "w") as fp: json.dump(info, fp) @@ -978,7 +1015,12 @@ def print_results( ) hetname = f" ({results['is_het']})" if results["is_het"] else "" score_str = " ".join( - [f"{item} {results[item]:.3f}" if item != "clashes" else f"{item} {results[item]}" for item in reported_measures] + [ + f"{item} {results[item]:.3f}" + if item != "clashes" + else f"{item} {results[item]}" + for item in reported_measures + ] ) print( f"{score_str} mapping {results['chain1']}{results['chain2']}:{chains[0]}{chains[1]}{hetname} {info['model']} {results['chain1']} {results['chain2']} -> {info['native']} {chains[0]} {chains[1]}" @@ -1009,7 +1051,12 @@ def print_results( print(f"\tModel chains: {results['chain1']}, {results['chain2']}") print( "\n".join( - [f"\t{item}: {results[item]:.3f}" if item != "clashes" else f"\t{item}: {results[item]}" for item in reported_measures] + [ + f"\t{item}: {results[item]:.3f}" + if item != "clashes" + else f"\t{item}: {results[item]}" + for item in reported_measures + ] ) )