Skip to content

Commit

Permalink
✨ _RegexPattern
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Oct 14, 2024
1 parent e3bb5d7 commit e2abfa4
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 29 deletions.
21 changes: 10 additions & 11 deletions nepattern/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from tarina import DateParser, lang

from .core import Pattern
from .core import Pattern, _RegexPattern
from .exception import MatchFailed
from .util import TPattern

Expand All @@ -28,15 +28,15 @@ def __init__(self, *args, **kwargs):
old_init(self, *args, **kwargs)

@self.pre_validate
def _(x):
def _(x): # pragma: no cover
try:
self.match(x)
return True
except MatchFailed:
return False

@self.convert
def _(s, x):
def _(s, x): # pragma: no cover
return s.match(x)

cls.__init__ = __init__
Expand All @@ -61,7 +61,7 @@ def match(self, input_: Any):
def __eq__(self, other): # pragma: no cover
return isinstance(other, DirectPattern) and self.target == other.target

def copy(self):
def copy(self): # pragma: no cover
return DirectPattern(self.target, self.alias)


Expand All @@ -85,18 +85,17 @@ def match(self, input_: Any):
def __eq__(self, other): # pragma: no cover
return isinstance(other, DirectTypePattern) and self.origin is other.origin

def copy(self):
def copy(self): # pragma: no cover
return DirectTypePattern(self.origin, self.alias)


@_SpecialPattern
class RegexPattern(Pattern[Match[str]]):
class RegexPattern(_RegexPattern[Match[str]]):
"""针对正则的特化匹配,支持正则组"""

def __init__(self, pattern: str | TPattern, alias: str | None = None):
super().__init__(Match[str], alias=alias or "regex[:group]")
super().__init__(pattern, Match[str], alias=alias or "regex[:group]")
self.regex_pattern = re.compile(pattern)
self.pattern = self.regex_pattern.pattern

def match(self, input_: Any) -> Match[str]:
if not isinstance(input_, str):
Expand All @@ -114,7 +113,7 @@ def match(self, input_: Any) -> Match[str]:
def __eq__(self, other): # pragma: no cover
return isinstance(other, RegexPattern) and self.pattern == other.pattern

def copy(self):
def copy(self): # pragma: no cover
return RegexPattern(self.pattern, self.alias)


Expand Down Expand Up @@ -304,7 +303,7 @@ def __init__(self):
def match(self, input_: Any) -> bytes:
if isinstance(input_, bytes):
return input_
elif isinstance(input_, bytearray):
elif isinstance(input_, bytearray): # pragma: no cover
return bytes(input_)
elif isinstance(input_, str):
return input_.encode()
Expand Down Expand Up @@ -407,7 +406,7 @@ def match(self, input_: Any) -> bool:
return input_
if isinstance(input_, bytes): # pragma: no cover
input_ = input_.decode()
if isinstance(input_, str):
if isinstance(input_, str): # pragma: no cover
input_ = input_.lower()
if input_ == "true":
return True
Expand Down
2 changes: 1 addition & 1 deletion nepattern/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def remove(self, origin_type, alias=None):
else:
del self.data[alias]
elif al_pat := self.data.get(origin_type):
if isinstance(al_pat, UnionPattern):
if isinstance(al_pat, UnionPattern): # pragma: no cover
self.data[origin_type] = UnionPattern(
*filter(lambda x: x.origin != origin_type, al_pat.for_validate)
)
Expand Down
14 changes: 10 additions & 4 deletions nepattern/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def success(self) -> bool:
def failed(self) -> bool:
return self._value is Empty

def __bool__(self):
def __bool__(self): # pragma: no cover
return self._value != Empty

def __repr__(self):
Expand All @@ -59,7 +59,7 @@ def __repr__(self):
class Pattern(Generic[T]):
@staticmethod
def regex_match(pattern: str | TPattern, alias: str | None = None):
pat = Pattern(str, alias or str(pattern))
pat = _RegexPattern(pattern, str, alias or str(pattern))

@pat.convert
def _(self, x: str):
Expand All @@ -80,7 +80,7 @@ def regex_convert(
alias: str | None = None,
allow_origin: bool = False,
):
pat = Pattern(origin, alias or str(pattern))
pat = _RegexPattern(pattern, origin, alias or str(pattern))
if allow_origin:
pat.accept(Union[str, origin])

Expand Down Expand Up @@ -190,7 +190,7 @@ def __repr__(self):
def copy(self) -> Self:
return deepcopy(self)

def __rrshift__(self, other):
def __rrshift__(self, other): # pragma: no cover
return self.execute(other)

def __rmatmul__(self, other) -> Self: # pragma: no cover
Expand All @@ -208,3 +208,9 @@ def __hash__(self):

def __eq__(self, other):
return isinstance(other, Pattern) and self.__hash__() == other.__hash__()


class _RegexPattern(Pattern[T]):
def __init__(self, pattern: str | TPattern, origin: type[T], alias: str | None = None):
super().__init__(origin, alias)
self.pattern = pattern
53 changes: 40 additions & 13 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ def test_basic():
from datetime import datetime

res = STRING.execute("123")
if res:
assert res.success
assert res.value() == "123"
assert res.success
assert res.value() == "123"
assert STRING.execute(b"123").value() == "123"
assert STRING.execute(123).failed

Expand Down Expand Up @@ -75,6 +74,10 @@ def test_basic():
assert PATH.execute(Path("a/b/c")).value() == Path("a/b/c")
assert PATH.execute([]).failed

assert DelimiterInt.execute("1,000").value() == 1000
assert DelimiterInt.execute("1,000,000").value() == 1000000
assert DelimiterInt.execute("1,000,000.0").failed


def test_result():
res = NUMBER.execute(123)
Expand All @@ -96,7 +99,7 @@ def test_result():


def test_pattern_of():
"""测试 BasePattern 的快速创建方法之一, 对类有效"""
"""测试 Pattern 的快速创建方法之一, 对类有效"""
pat = Pattern(int)
assert pat.origin == int
assert pat.execute(123).value() == 123
Expand All @@ -107,7 +110,7 @@ def test_pattern_of():


def test_pattern_on():
"""测试 BasePattern 的快速创建方法之一, 对对象有效"""
"""测试 Pattern 的快速创建方法之一, 对对象有效"""
pat1 = Pattern.on(123)
assert pat1.origin == int
assert pat1.execute(123).value() == 123
Expand All @@ -116,23 +119,23 @@ def test_pattern_on():


def test_pattern_keep():
"""测试 BasePattern 的保持模式, 不会进行匹配或者类型转换"""
"""测试 Pattern 的保持模式, 不会进行匹配或者类型转换"""
pat2 = Pattern()
assert pat2.execute(123).value() == 123
assert pat2.execute("abc").value() == "abc"
print(pat2)


def test_pattern_regex():
"""测试 BasePattern 的正则匹配模式, 仅正则匹配"""
"""测试 Pattern 的正则匹配模式, 仅正则匹配"""
pat3 = Pattern.regex_match("abc[A-Z]+123")
assert pat3.execute("abcABC123").value() == "abcABC123"
assert pat3.execute("abcAbc123").failed
print(pat3)


def test_pattern_regex_convert():
"""测试 BasePattern 的正则转换模式, 正则匹配成功后再进行类型转换"""
"""测试 Pattern 的正则转换模式, 正则匹配成功后再进行类型转换"""
pat4 = Pattern.regex_convert(
r"\[at:(\d+)\]", int, lambda m: res if (res := int(m[1])) < 1000000 else None, allow_origin=True
)
Expand All @@ -142,9 +145,15 @@ def test_pattern_regex_convert():
assert pat4.execute("[at:1234567]").failed
print(pat4)

pat4_1 = Pattern.regex_convert(r"\[at:(\d+)\]", int, lambda m: int(m[1]), allow_origin=False)
assert pat4_1.execute("[at:123456]").value() == 123456
assert pat4_1.execute("[at:abcdef]").failed
assert pat4_1.execute(123456).failed
print(pat4_1)


def test_pattern_type_convert():
"""测试 BasePattern 的类型转换模式, 仅将传入对象变为另一类型的新对象"""
"""测试 Pattern 的类型转换模式, 仅将传入对象变为另一类型的新对象"""
pat5 = Pattern(origin=str).convert(lambda _, x: str(x))
assert pat5.execute(123).value() == "123"
assert pat5.execute([4, 5, 6]).value() == "[4, 5, 6]"
Expand All @@ -167,7 +176,7 @@ def convert(self, content):


def test_pattern_accepts():
"""测试 BasePattern 的输入类型筛选, 不在范围内的类型视为非法"""
"""测试 Pattern 的输入类型筛选, 不在范围内的类型视为非法"""

pat6 = Pattern(str).accept(bytes).convert(lambda _, x: x.decode())
assert pat6.execute(b"123").value() == "123"
Expand All @@ -178,8 +187,16 @@ def test_pattern_accepts():
print(pat6, pat6_1)


def test_pattern_pre_validator():
"""测试 Pattern 的匹配前验证器, 会在匹配前对输入进行验证"""
pat7 = Pattern(float).pre_validate(lambda x: x != 0).convert(lambda _, x: 1 / x)
assert pat7.execute(123).value() == 1 / 123
assert pat7.execute(0).failed
print(pat7)


def test_pattern_anti():
"""测试 BasePattern 的反向验证功能"""
"""测试 Pattern 的反向验证功能"""
pat8 = Pattern(int)
pat8_1 = AntiPattern(pat8)
assert pat8.execute(123).value() == 123
Expand All @@ -189,7 +206,7 @@ def test_pattern_anti():


def test_pattern_validator():
"""测试 BasePattern 的匹配后验证器, 会对匹配结果进行验证"""
"""测试 Pattern 的匹配后验证器, 会对匹配结果进行验证"""
pat9 = Pattern(int).accept(int).post_validate(lambda x: x > 0)
assert pat9.execute(23).value() == 23
assert pat9.execute(-23).failed
Expand Down Expand Up @@ -268,6 +285,16 @@ def test_union_pattern():
print(pat12, pat12_1, pat12_2)
pat12_3 = UnionPattern.of(List[bool], int)
print(pat12_3)
pat12_4 = UnionPattern(INTEGER, WIDE_BOOLEAN)
assert pat12_4.execute(123).success
assert pat12_4.execute("123").success
assert pat12_4.execute("123").value() == 123
assert pat12_4.execute("true").success
assert pat12_4.execute("true").value() is True
assert pat12_4.execute("false").success
assert pat12_4.execute("false").value() is False
assert pat12_4.execute("yes").success
assert pat12_4.execute("yes").value() is True


def test_converters():
Expand Down Expand Up @@ -329,7 +356,7 @@ def test_regex_pattern():
assert res.groupdict() == {"owner": "ArcletProject", "repo": "NEPattern"}
assert pat18.execute(123).failed
assert pat18.execute("www.bilibili.com").failed
pat18_1 = parser(r"re:(\d+)") # str starts with "re:" will convert to BasePattern instead of RegexPattern
pat18_1 = parser(r"re:(\d+)") # str starts with "re:" will convert to Pattern instead of RegexPattern
assert pat18_1.execute("1234").value() == "1234"
pat18_2 = parser(r"rep:(\d+)") # str starts with "rep:" will convert to RegexPattern
assert pat18_2.execute("1234").value().groups() == ("1234",) # type: ignore
Expand Down

0 comments on commit e2abfa4

Please sign in to comment.