From b6540a619841913c527aec7beb4c9b74df65ba61 Mon Sep 17 00:00:00 2001 From: aranega Date: Sat, 2 Nov 2024 20:03:51 -0600 Subject: [PATCH] Fix issue with matcher generators generating collection matchers --- iguala/matchers.py | 40 ++++++++++++++-- tests/test_object_matchers.py | 2 +- tests/test_sequence_matchers.py | 84 +++++++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 5 deletions(-) diff --git a/iguala/matchers.py b/iguala/matchers.py index 10dd8ae..1ba7dad 100644 --- a/iguala/matchers.py +++ b/iguala/matchers.py @@ -83,6 +83,9 @@ def __iter__(self): def __len__(self): return len(self.bindings) + def replace(self, alias, value): + self.bindings[alias] = value + @property def is_match(self): return self._is_match is self.truth @@ -103,7 +106,11 @@ def as_matcher(self): return self @property - def is_collection_matcher(self): + def is_collection_matcher(self) -> bool: + return False + + @property + def has_matcher_generator(self) -> bool: return False @property @@ -151,6 +158,10 @@ def __init__(self, alias, matcher): self.alias = alias self.matcher = matcher + @property + def has_matcher_generator(self): + return self.matcher.has_matcher_generator + @property def is_collection_matcher(self): return self.matcher.is_collection_matcher @@ -160,6 +171,8 @@ def is_list_wildcard(self): return self.matcher.is_list_wildcard def match_context(self, obj, context): + if self.has_matcher_generator: + self.matcher.delay_alias_binding(obj, self.alias) context[self.alias] = obj return self.matcher.match_context(obj, context) @@ -222,7 +235,7 @@ def match_context(self, obj, context): for path, matcher in self.properties: results = [] for context in new_contexts: - if matcher.is_collection_matcher: + if matcher.is_collection_matcher or matcher.has_matcher_generator: cpy = context.copy() results.extend(matcher.match_context(path.resolve_from(obj), cpy)) else: @@ -288,6 +301,10 @@ def __init__(self, fun): self.vars.remove(self.__self__) else: self.has_self = False + self.delayed_alias_bindings = {} + + def delay_alias_binding(self, value, alias): + self.delayed_alias_bindings[alias] = value def match_context(self, obj, context): try: @@ -301,8 +318,23 @@ def match_context(self, obj, context): class MatcherGenerator(LambdaBasedMatcher): + @property + def has_matcher_generator(self): + return True + def execute(self, obj, context, kwargs): - return as_matcher(self.fun(**kwargs)).match_context(obj, context) + resolved = as_matcher(self.fun(**kwargs)) + + for alias, o in self.delayed_alias_bindings.items(): + if not resolved.is_collection_matcher: + context.replace(alias, o[0]) + + if resolved.is_collection_matcher: + cpy = context.copy() + results = resolved.match_context(obj, cpy) + else: + results = flat([resolved.match_context(o, context.copy()) for o in obj]) + return results class ConditionalMatcher(LambdaBasedMatcher): @@ -321,7 +353,7 @@ def can_execute(self, context): return all(x in context for x in self.matcher.vars) def execute(self, context): - return self.matcher.match_context(self.self_object, context.copy()) + return self.matcher.match_context(self.self_object, context) class RegexMatcher(Matcher): diff --git a/tests/test_object_matchers.py b/tests/test_object_matchers.py index 2e37bbb..cff46b4 100644 --- a/tests/test_object_matchers.py +++ b/tests/test_object_matchers.py @@ -153,4 +153,4 @@ def test_match_union_type(): assert pattern.match(BTest(0, 0, 'foo', InnerTest('bar', 5), [], 5)).is_match is False pattern = match(ATest | BTest) % {} - assert pattern.match(BTest(0, 0, 'foo', InnerTest('bar', 5), [], 5)).is_match is True \ No newline at end of file + assert pattern.match(BTest(0, 0, 'foo', InnerTest('bar', 5), [], 5)).is_match is True diff --git a/tests/test_sequence_matchers.py b/tests/test_sequence_matchers.py index 8f1edb4..22cebd3 100644 --- a/tests/test_sequence_matchers.py +++ b/tests/test_sequence_matchers.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import pytest from iguala import as_matcher, is_not, match @@ -102,3 +103,86 @@ def test_ellipsis_list(pattern, data, expected, variables): for ctx, bindings in zip(variables, result.bindings): for var, val in ctx.items(): assert bindings[var] == val + + +def test_sequence_matcher_generator(): + @dataclass + class A: + x: int + y: int | str + tab: list[int | str] + + def i_element_after(i, x, var): + return as_matcher([..., x, *["@_"]*i, f"@{var}", ...]) + + p = match(A)[ + "x": "@step", + "y": "@y", + "tab": lambda step, y: i_element_after(step, y, "res") + ] + + a = A(x=2, y="r", tab=[3, 4, "r", 3, 3, 5, 6, "r", 1, 2, 8]) + + res = p.match(a)["res"] + + assert len(res) == 2 + assert res == [5, 8] + + +def test_matcher_generator_iterate_list(): + @dataclass + class A: + x: int + tab: list[int | str] + + a = A(x=3, tab=[3, 4, "r", 3, 3, 5, 6, "r", 1, 2, 8]) + + p = match(A)[ + "x": "@x", + "tab": lambda x: x + ] + + res = p.match(a)["x"] + + assert len(res) == 3 + + + +def test_matcher_generator_save_nodes(): + @dataclass + class A: + x: int + tab: list[int | str] + + a = A(x=3, tab=[3, 4, "r", 3, 3, 5, 6, "r", 1, 2, 8]) + + p = match(A)[ + "x": "@x", + "tab": as_matcher(lambda x: x) @ "xx" + ] + + res = p.match(a) + + assert len(res["x"]) == 3 + assert len(res["xx"]) == 3 + assert res["xx"] == [3, 3, 3] + + +# def test_matcher_generator_inner_save_nodes(): +# @dataclass +# class A: +# x: int +# tab: list[int | str] + +# a = A(x=3, tab=[3, 4, "r", 3, 3, 5, 6, "r", 1, 2, 8]) + +# p = match(A)[ +# "x": "@x", +# "tab": as_matcher(lambda x: as_matcher(lambda z: z) @ "z") @ "xx" +# ] + +# res = p.match(a) + +# assert len(res["x"]) == 3 +# assert len(res["xx"]) == 3 +# assert res["xx"] == [3, 3, 3] \ No newline at end of file