Skip to content

Commit

Permalink
Fix issue with matcher generators generating collection matchers
Browse files Browse the repository at this point in the history
  • Loading branch information
aranega committed Nov 3, 2024
1 parent 2e9824c commit b6540a6
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 5 deletions.
40 changes: 36 additions & 4 deletions iguala/matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def __iter__(self):
def __len__(self):
return len(self.bindings)

def replace(self, alias, value):
self.bindings[alias] = value

@property
def is_match(self):
return self._is_match is self.truth
Expand All @@ -103,7 +106,11 @@ def as_matcher(self):
return self

@property
def is_collection_matcher(self):
def is_collection_matcher(self) -> bool:
return False

@property
def has_matcher_generator(self) -> bool:
return False

@property
Expand Down Expand Up @@ -151,6 +158,10 @@ def __init__(self, alias, matcher):
self.alias = alias
self.matcher = matcher

@property
def has_matcher_generator(self):
return self.matcher.has_matcher_generator

@property
def is_collection_matcher(self):
return self.matcher.is_collection_matcher
Expand All @@ -160,6 +171,8 @@ def is_list_wildcard(self):
return self.matcher.is_list_wildcard

def match_context(self, obj, context):
if self.has_matcher_generator:
self.matcher.delay_alias_binding(obj, self.alias)
context[self.alias] = obj
return self.matcher.match_context(obj, context)

Expand Down Expand Up @@ -222,7 +235,7 @@ def match_context(self, obj, context):
for path, matcher in self.properties:
results = []
for context in new_contexts:
if matcher.is_collection_matcher:
if matcher.is_collection_matcher or matcher.has_matcher_generator:
cpy = context.copy()
results.extend(matcher.match_context(path.resolve_from(obj), cpy))
else:
Expand Down Expand Up @@ -288,6 +301,10 @@ def __init__(self, fun):
self.vars.remove(self.__self__)
else:
self.has_self = False
self.delayed_alias_bindings = {}

def delay_alias_binding(self, value, alias):
self.delayed_alias_bindings[alias] = value

def match_context(self, obj, context):
try:
Expand All @@ -301,8 +318,23 @@ def match_context(self, obj, context):


class MatcherGenerator(LambdaBasedMatcher):
@property
def has_matcher_generator(self):
return True

def execute(self, obj, context, kwargs):
return as_matcher(self.fun(**kwargs)).match_context(obj, context)
resolved = as_matcher(self.fun(**kwargs))

for alias, o in self.delayed_alias_bindings.items():
if not resolved.is_collection_matcher:
context.replace(alias, o[0])

if resolved.is_collection_matcher:
cpy = context.copy()
results = resolved.match_context(obj, cpy)
else:
results = flat([resolved.match_context(o, context.copy()) for o in obj])
return results


class ConditionalMatcher(LambdaBasedMatcher):
Expand All @@ -321,7 +353,7 @@ def can_execute(self, context):
return all(x in context for x in self.matcher.vars)

def execute(self, context):
return self.matcher.match_context(self.self_object, context.copy())
return self.matcher.match_context(self.self_object, context)


class RegexMatcher(Matcher):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_object_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,4 @@ def test_match_union_type():
assert pattern.match(BTest(0, 0, 'foo', InnerTest('bar', 5), [], 5)).is_match is False

pattern = match(ATest | BTest) % {}
assert pattern.match(BTest(0, 0, 'foo', InnerTest('bar', 5), [], 5)).is_match is True
assert pattern.match(BTest(0, 0, 'foo', InnerTest('bar', 5), [], 5)).is_match is True
84 changes: 84 additions & 0 deletions tests/test_sequence_matchers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
import pytest

from iguala import as_matcher, is_not, match
Expand Down Expand Up @@ -102,3 +103,86 @@ def test_ellipsis_list(pattern, data, expected, variables):
for ctx, bindings in zip(variables, result.bindings):
for var, val in ctx.items():
assert bindings[var] == val


def test_sequence_matcher_generator():
@dataclass
class A:
x: int
y: int | str
tab: list[int | str]

def i_element_after(i, x, var):
return as_matcher([..., x, *["@_"]*i, f"@{var}", ...])

p = match(A)[
"x": "@step",
"y": "@y",
"tab": lambda step, y: i_element_after(step, y, "res")
]

a = A(x=2, y="r", tab=[3, 4, "r", 3, 3, 5, 6, "r", 1, 2, 8])

res = p.match(a)["res"]

assert len(res) == 2
assert res == [5, 8]


def test_matcher_generator_iterate_list():
@dataclass
class A:
x: int
tab: list[int | str]

a = A(x=3, tab=[3, 4, "r", 3, 3, 5, 6, "r", 1, 2, 8])

p = match(A)[
"x": "@x",
"tab": lambda x: x
]

res = p.match(a)["x"]

assert len(res) == 3



def test_matcher_generator_save_nodes():
@dataclass
class A:
x: int
tab: list[int | str]

a = A(x=3, tab=[3, 4, "r", 3, 3, 5, 6, "r", 1, 2, 8])

p = match(A)[
"x": "@x",
"tab": as_matcher(lambda x: x) @ "xx"
]

res = p.match(a)

assert len(res["x"]) == 3
assert len(res["xx"]) == 3
assert res["xx"] == [3, 3, 3]


# def test_matcher_generator_inner_save_nodes():
# @dataclass
# class A:
# x: int
# tab: list[int | str]

# a = A(x=3, tab=[3, 4, "r", 3, 3, 5, 6, "r", 1, 2, 8])

# p = match(A)[
# "x": "@x",
# "tab": as_matcher(lambda x: as_matcher(lambda z: z) @ "z") @ "xx"
# ]

# res = p.match(a)

# assert len(res["x"]) == 3
# assert len(res["xx"]) == 3
# assert res["xx"] == [3, 3, 3]

0 comments on commit b6540a6

Please sign in to comment.