diff --git a/src/DockQ/DockQ.py b/src/DockQ/DockQ.py index 98dc38f..44f1316 100755 --- a/src/DockQ/DockQ.py +++ b/src/DockQ/DockQ.py @@ -199,13 +199,8 @@ def calc_sym_corrected_lrmsd( aligned_sample_receptor, aligned_ref_receptor = get_aligned_residues( sample_receptor, ref_receptor, receptor_alignment ) - # get a copy of each structure, then only keep backbone atoms - sample_interface_atoms, ref_interface_atoms = get_interface_atoms( - (receptor_interface, ()), - (aligned_sample_receptor, ()), - (aligned_ref_receptor, ()), - atom_types=BACKBONE_ATOMS, - ) + + sample_interface_atoms, ref_interface_atoms = subset_atoms(aligned_sample_receptor, aligned_ref_receptor, atom_types=BACKBONE_ATOMS, residue_subset=receptor_interface) sample_ligand_atoms_ids = [atom.id for atom in sample_ligand.get_atoms()] sample_ligand_atoms_ele = [atom.element for atom in sample_ligand.get_atoms()] @@ -230,7 +225,7 @@ def calc_sym_corrected_lrmsd( # Set to align on receptor interface super_imposer = SVDSuperimposer() - super_imposer.set(ref_interface_atoms, sample_interface_atoms) + super_imposer.set(np.asarray(ref_interface_atoms), np.asarray(sample_interface_atoms)) super_imposer.run() rot, tran = super_imposer.get_rotran() @@ -323,13 +318,13 @@ def calc_DockQ( ref_res_distances, threshold=interface_threshold ** 2, ) - # get a copy of each structure, then only keep backbone atoms - sample_interface_atoms, ref_interface_atoms = get_interface_atoms( - interacting_pairs, - (aligned_sample_1, aligned_sample_2), - (aligned_ref_1, aligned_ref_2), - atom_types=BACKBONE_ATOMS, - ) + + sample_interface_atoms1, ref_interface_atoms1 = subset_atoms(aligned_sample_1, aligned_ref_1, atom_types=BACKBONE_ATOMS, residue_subset=interacting_pairs[0]) + sample_interface_atoms2, ref_interface_atoms2 = subset_atoms(aligned_sample_2, aligned_ref_2, atom_types=BACKBONE_ATOMS, residue_subset=interacting_pairs[1]) + + sample_interface_atoms = np.asarray(sample_interface_atoms1 + sample_interface_atoms2) + ref_interface_atoms = np.asarray(ref_interface_atoms1 + ref_interface_atoms2) + super_imposer = SVDSuperimposer() super_imposer.set(sample_interface_atoms, ref_interface_atoms) super_imposer.run() @@ -354,19 +349,19 @@ def calc_DockQ( else ("ligand", "receptor") ) - receptor_atoms_native, receptor_atoms_sample = get_atoms_per_residue( - receptor_chains, what="receptor", atom_types=BACKBONE_ATOMS + receptor_atoms_native, receptor_atoms_sample = subset_atoms( + receptor_chains[0], receptor_chains[1], atom_types=BACKBONE_ATOMS ) - ligand_atoms_native, ligand_atoms_sample = get_atoms_per_residue(ligand_chains, what="ligand", atom_types=BACKBONE_ATOMS) + ligand_atoms_native, ligand_atoms_sample = subset_atoms(ligand_chains[0], ligand_chains[1], atom_types=BACKBONE_ATOMS) # Set to align on receptor - super_imposer.set(receptor_atoms_native, receptor_atoms_sample) + super_imposer.set(np.asarray(receptor_atoms_native), np.asarray(receptor_atoms_sample)) super_imposer.run() rot, tran = super_imposer.get_rotran() - rotated_sample_atoms = np.dot(ligand_atoms_sample, rot) + tran + rotated_sample_atoms = np.dot(np.asarray(ligand_atoms_sample), rot) + tran lrms = super_imposer._rms( - ligand_atoms_native, rotated_sample_atoms + np.asarray(ligand_atoms_native), rotated_sample_atoms ) # using the private _rms function which does not superimpose info = {} @@ -563,6 +558,40 @@ def get_interface_atoms( return np.asarray(mod_interface), np.asarray(ref_interface) +@lru_cache +def subset_atoms( + mod_chain, + ref_chain, + atom_types, + residue_subset=None, +): + mod_atoms = [] + ref_atoms = [] + + mod_residues = [res for res in mod_chain] + ref_residues = [res for res in ref_chain] + + # remove duplicate residues + residue_subset = set(residue_subset) if residue_subset else range(len(mod_residues)) + + for i in residue_subset: + mod_res_atoms = list(mod_residues[i].get_atoms()) + ref_res_atoms = list(ref_residues[i].get_atoms()) + mod_res_atoms_ids = [atom.id for atom in mod_res_atoms] + ref_res_atoms_ids = [atom.id for atom in ref_res_atoms] + + for atom_type in atom_types: + try: + mod_i = mod_res_atoms_ids.index(atom_type) + ref_i = ref_res_atoms_ids.index(atom_type) + mod_atoms += [mod_res_atoms[mod_i].coord] + ref_atoms += [ref_res_atoms[ref_i].coord] + except: + continue + + return mod_atoms, ref_atoms + + @lru_cache def run_on_chains( model_chains,