From fa68282771689ceeedf30abdaabb979e9b2cdf3f Mon Sep 17 00:00:00 2001 From: Lily Wang Date: Sat, 26 Oct 2024 13:19:11 +1100 Subject: [PATCH] remove "adding" fixes -- keep only "deletion" fixes --- package/MDAnalysis/core/topologyattrs.py | 20 +++++++++++++------ .../MDAnalysisTests/topology/test_pdb.py | 4 ++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/package/MDAnalysis/core/topologyattrs.py b/package/MDAnalysis/core/topologyattrs.py index 57e5cc84f5..e5cf003b5b 100644 --- a/package/MDAnalysis/core/topologyattrs.py +++ b/package/MDAnalysis/core/topologyattrs.py @@ -3042,12 +3042,19 @@ class _Connection(AtomAttr, metaclass=_ConnectionTopologyAttrMeta): @_check_connection_values def __init__(self, values, types=None, guessed=False, order=None): - self.values = [] - self.types = [] - self._guessed = [] - self.order = [] + self.values = values + if types is None: + types = [None] * len(values) + self.types = types + if guessed in (True, False): + # if single value passed, multiply this across + # all bonds + guessed = [guessed] * len(values) + self._guessed = guessed + if order is None: + order = [None] * len(values) + self.order = order self._cache = dict() - self._add_bonds(values, types, guessed, order) def copy(self): """Return a deepcopy of this attribute""" @@ -3111,8 +3118,9 @@ def _add_bonds(self, values, types=None, guessed=True, order=None): if order is None: order = itertools.cycle((None,)) + existing = set(self.values) for v, t, g, o in zip(values, types, guessed, order): - if v not in self.values: + if v not in existing: self.values.append(v) self.types.append(t) self._guessed.append(g) diff --git a/testsuite/MDAnalysisTests/topology/test_pdb.py b/testsuite/MDAnalysisTests/topology/test_pdb.py index 93115a92d3..c176d50be1 100644 --- a/testsuite/MDAnalysisTests/topology/test_pdb.py +++ b/testsuite/MDAnalysisTests/topology/test_pdb.py @@ -146,7 +146,7 @@ def parse(): struc = parse() assert hasattr(struc, 'bonds') - assert len(struc.bonds.values) == 2 + assert len(struc.bonds.values) == 4 def test_single_conect(): @@ -158,7 +158,7 @@ def parse(): with pytest.warns(UserWarning): struc = parse() assert hasattr(struc, 'bonds') - assert len(struc.bonds.values) == 1 + assert len(struc.bonds.values) == 2 def test_new_chainid_new_res():