Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non-recursive composite subsets #2140

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 72 additions & 18 deletions glue/core/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,35 +1001,96 @@ class CompositeSubsetState(SubsetState):
"""
The base class for combinations of subset states.
"""
def _traverse_state_tree(self, substate_func, unary_func, binary_func):
"""
A helper function that traverses the entire state tree iteratively.

Use this instead of writing recursive function calls to prevent stack overflow.

Parameters
----------
substate_func : callable
A callable object that takes one argument of a `SubsetState` object.
This function will be applied to all substates that are not CompositeSubsetStates.
This callable should return a result.
unary_func : callable
A callable object that takes two arguments, an `InvertState` and one input argument.
The input argument is the same as the result of the recursive function call on
the state1 field of the InvertState object.
This callable should return a result.
binary_func : callable
A callable object that takes three agruments, an `InvertState` and two arguments.
The first input argument is the same as the result of the recursive function call on
the state1 field, and the second argument is the result of recursion on state2.
This callable should return a result.
"""

visitation_stack = [(self, False)]
results_stack = []

while visitation_stack:
state, visited = visitation_stack.pop()
if isinstance(state, CompositeSubsetState):
if visited:
if state.op is operator.invert:
results_stack[-1] = unary_func(state, results_stack[-1])
else:
rhs = results_stack.pop()
results_stack[-1] = binary_func(state, results_stack[-1], rhs)
else:
visitation_stack.append((state, True))
if state.op is not operator.invert:
visitation_stack.append((state.state2, False))
visitation_stack.append((state.state1, False))

else:
results_stack.append(substate_func(state))

return results_stack[0]

op = None

def __init__(self, state1, state2=None):
super(CompositeSubsetState, self).__init__()
self.state1 = state1.copy()
if state1:
state1 = state1.copy()
self.state1 = state1
if state2:
state2 = state2.copy()
self.state2 = state2

def copy(self):
return type(self)(self.state1, self.state2)
leaf_func = lambda state: state.copy()

def _copy_composite(state, lhs, rhs=None):
copy = type(state)(None, None)
copy.state1 = lhs
copy.state2 = rhs
return copy

return self._traverse_state_tree(leaf_func, _copy_composite, _copy_composite)

@property
def attributes(self):
att = self.state1.attributes
if self.state2 is not None:
att += self.state2.attributes
return tuple(sorted(set(att)))
leaf_func = lambda state: state.attributes
invert_func = lambda _, input: input
binary_func = lambda _, lhs, rhs: lhs + rhs
preclean = self._traverse_state_tree(leaf_func, invert_func, binary_func)
return tuple(sorted(set(preclean)))

@memoize
@contract(data='isinstance(Data)', view='array_view')
def to_mask(self, data, view=None):
return self.op(self.state1.to_mask(data, view),
self.state2.to_mask(data, view))
leaf_func = lambda state, data=data, view=view: state.to_mask(data, view)
invert_func = lambda _, input: ~input
binary_func = lambda state, lhs, rhs: state.op(lhs, rhs)
return self._traverse_state_tree(leaf_func, invert_func, binary_func)

def __str__(self):
sym = OPSYM.get(self.op, self.op)
return "(%s %s %s)" % (self.state1, sym, self.state2)
leaf_func = lambda state: "%s" % state
invert_func = lambda _, input: "(~%s)" % input
binary_func = lambda state, lhs, rhs: "(%s %s %s)" % (lhs, OPSYM.get(state.op, state.op), rhs)
return self._traverse_state_tree(leaf_func, invert_func, binary_func)


class OrState(CompositeSubsetState):
Expand Down Expand Up @@ -1070,14 +1131,7 @@ class InvertState(CompositeSubsetState):
vice-versa. The original subset state can be accessed using the attribute
``state1``.
"""

@memoize
@contract(data='isinstance(Data)', view='array_view')
def to_mask(self, data, view=None):
return ~self.state1.to_mask(data, view)

def __str__(self):
return "(~%s)" % self.state1
op = operator.invert


class MultiOrState(SubsetState):
Expand Down
47 changes: 46 additions & 1 deletion glue/core/tests/test_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,30 @@ def test_xor(self):
assert isinstance(s3, XorState)


class WatchedOrState(OrState):

str_calls = 0
tomask_calls = 0
copy_calls = 0

def __str__(self):
WatchedOrState.str_calls += 1
return super(WatchedOrState, self).__str__()

def copy(self):
WatchedOrState.copy_calls += 1
return super(WatchedOrState, self).copy()

def to_mask(self, data, view=None):
WatchedOrState.tomask_calls += 1
return super(WatchedOrState, self).to_mask(data, view)

def reset_counts():
WatchedOrState.str_calls = 0
WatchedOrState.tomask_calls = 0
WatchedOrState.copy_calls = 0


class TestCompositeSubsetStates(object):

class DummyState(SubsetState):
Expand Down Expand Up @@ -284,6 +308,22 @@ def test_multicomposite(self):
expected = np.array([False, True, False, False])
np.testing.assert_array_equal(answer, expected)

def test_not_recursion(self):
state = WatchedOrState(self.sub1, self.sub2)
for i in range(10):
state = WatchedOrState(state, WatchedOrState(self.sub2, self.sub1))

WatchedOrState.reset_counts()

mask = state.to_mask(self.data)
assert WatchedOrState.tomask_calls == 1

string = str(state)
assert WatchedOrState.str_calls == 1

copy = state.copy()
assert WatchedOrState.copy_calls == 1


class TestElementSubsetState(object):

Expand Down Expand Up @@ -409,7 +449,12 @@ def assert_composite_copy(self, cls):
assert s1.state2.copy() is s2.state2

def test_invert(self):
self.assert_composite_copy(InvertState)
state1 = MagicMock()
s1 = InvertState(state1)
s2 = s1.copy()

assert type(s1) == type(s2)
assert s1.state1.copy() is s2.state1

def test_and(self):
self.assert_composite_copy(AndState)
Expand Down