diff --git a/iguala/helpers.py b/iguala/helpers.py index 785ea11..c3fad82 100644 --- a/iguala/helpers.py +++ b/iguala/helpers.py @@ -3,10 +3,10 @@ class match(object): - def __init__(self, cls): + def __init__(self, cls, name=None): from .matchers import ObjectMatcher - self.matcher = ObjectMatcher(cls) + self.matcher = ObjectMatcher(cls, name=name) self.matcher.properties = {} def __mod__(self, properties): diff --git a/iguala/matchers.py b/iguala/matchers.py index 96338c2..fa404d6 100644 --- a/iguala/matchers.py +++ b/iguala/matchers.py @@ -50,6 +50,7 @@ def __init__(self, truth=True): self._is_match = truth self.truth = truth self.delayed_matchers = [] + self.known_subpatterns = {} def __getitem__(self, key): return self.bindings[key] @@ -85,6 +86,7 @@ def copy(self): instance = self.__class__(self.truth) instance.bindings.update(self.bindings) instance.delayed_matchers.extend(self.delayed_matchers) + instance.known_subpatterns.update(self.known_subpatterns) return instance @@ -154,7 +156,15 @@ def __rmatmul__(self, other): return self def match_context(self, obj, context): - context[self.alias] = obj + if self.alias in context: + res = context[self.alias] + if isinstance(res, list): + context[self.alias].extend(res) + else: + context[self.alias] = [res] + context[self.alias].append(obj) + else: + context[self.alias] = obj return self.matcher.match_context(obj, context) @@ -234,12 +244,15 @@ def match_context(self, obj, context): class ObjectMatcher(KeyValueMatcher, Matcher): - def __init__(self, cls, properties=None, subclassmatch=False): + def __init__(self, cls, properties=None, subclassmatch=False, name=None): self.properties = properties self.cls = cls self.subclassmatch = subclassmatch + self.name = name def match_context(self, obj, context): + if self.name: + context.known_subpatterns[self.name] = self sametype = ( isinstance(obj, self.cls) if self.subclassmatch @@ -379,6 +392,15 @@ def is_anonymous(self): return not self.alias or super().is_anonymous +class RecursiveMatcherReference(Matcher): + def __init__(self, name): + self.name = name + + def match_context(self, obj, context): + subpattern = context.known_subpatterns[self.name] + return subpattern.match_context(obj, context) + + class SequenceMatcher(Matcher): def __init__(self, sequence): self.sequence = [as_matcher(m) for m in sequence] @@ -600,4 +622,5 @@ def as_matcher(obj): cond = ConditionalMatcher regex = RegexMatcher is_ = IdentityMatcher +rec = RecursiveMatcherReference save_as = lambda alias: SaveNodeMatcher(alias, None) \ No newline at end of file