Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

various fixes for handling ensembles with selections as ref #1807

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion prody/atomic/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def sortAtoms(atoms, label, reverse=False):
return AtomMap(ag, sort, acsi)


def sliceAtoms(atoms, select):
def sliceAtoms(atoms, select, allowSame=False):
"""Slice *atoms* using the selection defined by *select*.

:arg atoms: atoms to be selected from
Expand All @@ -297,6 +297,8 @@ def sliceAtoms(atoms, select):
"""

if atoms == select:
if allowSame:
return atoms._getSubset('all'), atoms.all
raise ValueError('atoms and select arguments are the same')

try:
Expand Down
51 changes: 43 additions & 8 deletions prody/ensemble/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def getAtoms(self, selected=True):
return self._atoms
return self._atoms[self._indices]

def setAtoms(self, atoms):
def setAtoms(self, atoms, allowSame=False):
"""Set *atoms* or specify a selection of atoms to be considered in
calculations and coordinate requests. When a selection is set,
corresponding subset of coordinates will be considered in, for
Expand All @@ -238,7 +238,7 @@ def setAtoms(self, atoms):

n_atoms = self._n_atoms
if n_atoms:
if atoms.numAtoms() > n_atoms:
if atoms.numAtoms() > n_atoms and atoms.ca.numAtoms() > n_atoms:
raise ValueError('atoms must be same size or smaller than '
'the ensemble')

Expand All @@ -261,11 +261,19 @@ def setAtoms(self, atoms):
ag = atoms.getAtomGroup()
except AttributeError:
ag = atoms
if ag.numAtoms() != n_atoms:
try:
self_ag = self._atoms.getAtomGroup()
except AttributeError:
self_ag = self._atoms
try:
self_ag_n_atoms = self_ag.numAtoms()
except AttributeError:
self_ag_n_atoms = 0
if ag.numAtoms() != n_atoms and ag.numAtoms() != self_ag_n_atoms and self_ag_n_atoms != 0:
raise ValueError('size mismatch between this ensemble ({0} atoms) and atoms ({1} atoms)'
.format(n_atoms, ag.numAtoms()))
self._atoms = ag
self._indices, _ = sliceAtoms(self._atoms, atoms)
self._indices, _ = sliceAtoms(self._atoms, atoms, allowSame=allowSame)

else: # if assigning atoms to a new ensemble
self._n_atoms = atoms.numAtoms()
Expand Down Expand Up @@ -294,12 +302,32 @@ def getCoords(self, selected=True):
return None
if self._indices is None or not selected:
return self._coords.copy()
return self._coords[self._indices].copy()

selids = self._indices
if self.hasSelectionIssue():
selids = self.getIndices(calphas=True)
return self._coords[selids].copy()

def getIndices(self):
def getIndices(self, calphas=False):
"""Returns a copy of indices of selected columns"""

if calphas:
return array([list(self._atoms.ca.getIndices()).index(idx)
for idx in self._indices])
return copy(self._indices)

def hasSelectionIssue(self):
if self._atoms is None or self._atoms.ca is None:
return False

selids = self._indices
if selids is None:
return False

if (selids.max() > self._coords.shape[0]
and set(selids).issubset(set(self._atoms.ca.getIndices()))):
return True
return False

def setIndices(self, value):
if not isListLike(value):
Expand Down Expand Up @@ -359,10 +387,17 @@ def getWeights(self, selected=True):
return None
if self._indices is None or not selected:
return self._weights.copy()


if self.hasSelectionIssue():
selids = self.getIndices(calphas=True)
else:
selids = self._indices

if self._weights.ndim == 2:
return self._weights[self._indices].copy()
return self._weights[selids].copy()
else:
return self._weights[:, self._indices].copy()
return self._weights[:, selids].copy()

def _getWeights(self, selected=True):

Expand Down
2 changes: 1 addition & 1 deletion prody/ensemble/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def loadEnsemble(filename, **kwargs):
data[key] = arr
else:
atoms = None
ensemble.setAtoms(atoms)
ensemble.setAtoms(atoms, allowSame=True)

if '_indices' in attr_dict:
indices = attr_dict['_indices']
Expand Down
5 changes: 5 additions & 0 deletions prody/ensemble/pdbensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,9 @@ def getMSA(self, indices=None, selected=True):

atom_indices = self._indices if selected else slice(None, None, None)
indices = indices if indices is not None else slice(None, None, None)

if self.hasSelectionIssue():
atom_indices = self.getIndices(calphas=True)

return self._msa[indices, atom_indices]

Expand Down Expand Up @@ -491,6 +494,8 @@ def getCoordsets(self, indices=None, selected=True):
confs[i, which] = coords[which]
else:
selids = self._indices
if self.hasSelectionIssue():
selids = self.getIndices(calphas=True)
coords = coords[selids]
confs = self._confs[indices, selids].copy()
for i, w in enumerate(self._weights[indices]):
Expand Down
Loading