Skip to content

Commit

Permalink
unify atom subsetting function
Browse files Browse the repository at this point in the history
  • Loading branch information
clami66 committed Aug 28, 2024
1 parent f0a77da commit c5986d4
Showing 1 changed file with 50 additions and 21 deletions.
71 changes: 50 additions & 21 deletions src/DockQ/DockQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,8 @@ def calc_sym_corrected_lrmsd(
aligned_sample_receptor, aligned_ref_receptor = get_aligned_residues(
sample_receptor, ref_receptor, receptor_alignment
)
# get a copy of each structure, then only keep backbone atoms
sample_interface_atoms, ref_interface_atoms = get_interface_atoms(
(receptor_interface, ()),
(aligned_sample_receptor, ()),
(aligned_ref_receptor, ()),
atom_types=BACKBONE_ATOMS,
)

sample_interface_atoms, ref_interface_atoms = subset_atoms(aligned_sample_receptor, aligned_ref_receptor, atom_types=BACKBONE_ATOMS, residue_subset=receptor_interface)

sample_ligand_atoms_ids = [atom.id for atom in sample_ligand.get_atoms()]
sample_ligand_atoms_ele = [atom.element for atom in sample_ligand.get_atoms()]
Expand All @@ -230,7 +225,7 @@ def calc_sym_corrected_lrmsd(

# Set to align on receptor interface
super_imposer = SVDSuperimposer()
super_imposer.set(ref_interface_atoms, sample_interface_atoms)
super_imposer.set(np.asarray(ref_interface_atoms), np.asarray(sample_interface_atoms))
super_imposer.run()
rot, tran = super_imposer.get_rotran()

Expand Down Expand Up @@ -323,13 +318,13 @@ def calc_DockQ(
ref_res_distances,
threshold=interface_threshold ** 2,
)
# get a copy of each structure, then only keep backbone atoms
sample_interface_atoms, ref_interface_atoms = get_interface_atoms(
interacting_pairs,
(aligned_sample_1, aligned_sample_2),
(aligned_ref_1, aligned_ref_2),
atom_types=BACKBONE_ATOMS,
)

sample_interface_atoms1, ref_interface_atoms1 = subset_atoms(aligned_sample_1, aligned_ref_1, atom_types=BACKBONE_ATOMS, residue_subset=interacting_pairs[0])
sample_interface_atoms2, ref_interface_atoms2 = subset_atoms(aligned_sample_2, aligned_ref_2, atom_types=BACKBONE_ATOMS, residue_subset=interacting_pairs[1])

sample_interface_atoms = np.asarray(sample_interface_atoms1 + sample_interface_atoms2)
ref_interface_atoms = np.asarray(ref_interface_atoms1 + ref_interface_atoms2)

super_imposer = SVDSuperimposer()
super_imposer.set(sample_interface_atoms, ref_interface_atoms)
super_imposer.run()
Expand All @@ -354,19 +349,19 @@ def calc_DockQ(
else ("ligand", "receptor")
)

receptor_atoms_native, receptor_atoms_sample = get_atoms_per_residue(
receptor_chains, what="receptor", atom_types=BACKBONE_ATOMS
receptor_atoms_native, receptor_atoms_sample = subset_atoms(
receptor_chains[0], receptor_chains[1], atom_types=BACKBONE_ATOMS
)
ligand_atoms_native, ligand_atoms_sample = get_atoms_per_residue(ligand_chains, what="ligand", atom_types=BACKBONE_ATOMS)
ligand_atoms_native, ligand_atoms_sample = subset_atoms(ligand_chains[0], ligand_chains[1], atom_types=BACKBONE_ATOMS)
# Set to align on receptor
super_imposer.set(receptor_atoms_native, receptor_atoms_sample)
super_imposer.set(np.asarray(receptor_atoms_native), np.asarray(receptor_atoms_sample))
super_imposer.run()

rot, tran = super_imposer.get_rotran()
rotated_sample_atoms = np.dot(ligand_atoms_sample, rot) + tran
rotated_sample_atoms = np.dot(np.asarray(ligand_atoms_sample), rot) + tran

lrms = super_imposer._rms(
ligand_atoms_native, rotated_sample_atoms
np.asarray(ligand_atoms_native), rotated_sample_atoms
) # using the private _rms function which does not superimpose

info = {}
Expand Down Expand Up @@ -563,6 +558,40 @@ def get_interface_atoms(
return np.asarray(mod_interface), np.asarray(ref_interface)


@lru_cache
def subset_atoms(
mod_chain,
ref_chain,
atom_types,
residue_subset=None,
):
mod_atoms = []
ref_atoms = []

mod_residues = [res for res in mod_chain]
ref_residues = [res for res in ref_chain]

# remove duplicate residues
residue_subset = set(residue_subset) if residue_subset else range(len(mod_residues))

for i in residue_subset:
mod_res_atoms = list(mod_residues[i].get_atoms())
ref_res_atoms = list(ref_residues[i].get_atoms())
mod_res_atoms_ids = [atom.id for atom in mod_res_atoms]
ref_res_atoms_ids = [atom.id for atom in ref_res_atoms]

for atom_type in atom_types:
try:
mod_i = mod_res_atoms_ids.index(atom_type)
ref_i = ref_res_atoms_ids.index(atom_type)
mod_atoms += [mod_res_atoms[mod_i].coord]
ref_atoms += [ref_res_atoms[ref_i].coord]
except:
continue

return mod_atoms, ref_atoms


@lru_cache
def run_on_chains(
model_chains,
Expand Down

0 comments on commit c5986d4

Please sign in to comment.