Skip to content

Commit

Permalink
Merge pull request #116 from haydn-jones/master
Browse files (browse the repository at this point in the history)
Replace recursion with stack (fix #115)
  • Loading branch information
MarioKrenn6240 authored Jul 15, 2024
2 parents c5e2d78 + 4b70c1e commit 1fe121b
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 29 deletions.
68 changes: 39 additions & 29 deletions selfies/utils/smiles_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,39 +442,49 @@ def _derive_smiles_from_fragment(
root,
ring_log,
attribution_maps, attribution_index=0):
curr_atom, curr = mol.get_atom(root), root
token = atom_to_smiles(curr_atom)
derived.append(token)
attribution_maps.append(AttributionMap(
_strlen(derived) - 1 + attribution_index,
token, mol.get_attribution(curr_atom)))

out_bonds = mol.get_out_dirbonds(curr)
for i, bond in enumerate(out_bonds):
if bond.ring_bond:
token = bond_to_smiles(bond)
derived.append(token)
attribution_maps.append(AttributionMap(
_strlen(derived) - 1 + attribution_index,
token, mol.get_attribution(bond)))
ends = (min(bond.src, bond.dst), max(bond.src, bond.dst))
rnum = ring_log.setdefault(ends, len(ring_log) + 1)
if rnum >= 10:
derived.append("%")
derived.append(str(rnum))
stack = [(root, 0, len(mol.get_out_dirbonds(root)), False)]

else:
if i < len(out_bonds) - 1:
derived.append("(")
while stack:
curr, bond_index, total_bonds, needs_closing = stack[-1]
curr_atom = mol.get_atom(curr)

token = bond_to_smiles(bond)
if bond_index == 0:
token = atom_to_smiles(curr_atom)
derived.append(token)
attribution_maps.append(AttributionMap(
_strlen(derived) - 1 + attribution_index,
token, mol.get_attribution(bond)))
_derive_smiles_from_fragment(
derived, mol, bond.dst, ring_log,
attribution_maps, attribution_index)
if i < len(out_bonds) - 1:
token, mol.get_attribution(curr_atom)))

out_bonds = mol.get_out_dirbonds(curr)

if bond_index < total_bonds:
bond = out_bonds[bond_index]
bond_attribution = mol.get_attribution(bond)
stack[-1] = (curr, bond_index + 1, total_bonds, needs_closing)

if bond.ring_bond:
token = bond_to_smiles(bond)
derived.append(token)
attribution_maps.append(AttributionMap(
_strlen(derived) - 1 + attribution_index,
token, bond_attribution))
ends = (min(bond.src, bond.dst), max(bond.src, bond.dst))
rnum = ring_log.setdefault(ends, len(ring_log) + 1)
if rnum >= 10:
derived.append("%")
derived.append(str(rnum))
else:
if bond_index < total_bonds - 1:
derived.append("(")

token = bond_to_smiles(bond)
derived.append(token)
attribution_maps.append(AttributionMap(
_strlen(derived) - 1 + attribution_index,
token, bond_attribution))
stack.append((bond.dst, 0, len(mol.get_out_dirbonds(bond.dst)), bond_index < total_bonds - 1))
else:
stack.pop()
if needs_closing:
derived.append(")")
return attribution_maps
9 changes: 9 additions & 0 deletions tests/test_specific_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,12 @@ def test_old_symbols():
sf.decoder(long_s, compatible=True)
except Exception:
assert False

def test_large_selfies_decoding():
    """Check that an extremely long SELFIES string can be decoded.

    Decoding a chain of this length used to raise a RecursionError.
    """
    # Both strings repeat the same atom, so they share a single length.
    chain_length = 1024
    selfies_chain = "[C]" * chain_length
    smiles_chain = "C" * chain_length

    assert decode_eq(selfies_chain, smiles_chain)

0 comments on commit 1fe121b

Please sign in to comment.