Skip to content

Commit

Permalink
✨ _SpecialPattern
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Oct 14, 2024
1 parent 09c45d1 commit e3bb5d7
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
46 changes: 46 additions & 0 deletions nepattern/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,32 @@
_T = TypeVar("_T")
_T1 = TypeVar("_T1")

_TP = TypeVar("_TP", bound=Pattern)


def _SpecialPattern(cls: type[_TP]) -> type[_TP]:
old_init = cls.__init__

def __init__(self, *args, **kwargs):
old_init(self, *args, **kwargs)

@self.pre_validate
def _(x):
try:
self.match(x)
return True
except MatchFailed:
return False

@self.convert
def _(s, x):
return s.match(x)

cls.__init__ = __init__
return cls


@_SpecialPattern
class DirectPattern(Pattern[TOrigin]):
"""直接判断"""

Expand All @@ -40,9 +65,14 @@ def copy(self):
return DirectPattern(self.target, self.alias)


@_SpecialPattern
class DirectTypePattern(Pattern[TOrigin]):
"""直接类型判断"""

def __init__(self, origin: type[TOrigin], alias: str | None = None):
self.origin = origin
super().__init__(origin, alias)

def match(self, input_: Any):
if not isinstance(input_, self.origin):
raise MatchFailed(
Expand All @@ -59,6 +89,7 @@ def copy(self):
return DirectTypePattern(self.origin, self.alias)


@_SpecialPattern
class RegexPattern(Pattern[Match[str]]):
"""针对正则的特化匹配,支持正则组"""

Expand Down Expand Up @@ -87,6 +118,7 @@ def copy(self):
return RegexPattern(self.pattern, self.alias)


@_SpecialPattern
class UnionPattern(Pattern[_T]):
"""多类型参数的匹配"""

Expand Down Expand Up @@ -143,6 +175,7 @@ def __eq__(self, other): # pragma: no cover
_TSwtich = TypeVar("_TSwtich")


@_SpecialPattern
class SwitchPattern(Pattern[_TCase], Generic[_TCase, _TSwtich]):
"""匹配多种情况的表达式"""

Expand Down Expand Up @@ -171,6 +204,7 @@ def __eq__(self, other): # pragma: no cover
return isinstance(other, SwitchPattern) and self.switch == other.switch


@_SpecialPattern
class ForwardRefPattern(Pattern[Any]):
def __init__(self, ref: ForwardRef):
self.ref = ref
Expand All @@ -196,6 +230,7 @@ def __eq__(self, other): # pragma: no cover
return isinstance(other, ForwardRefPattern) and self.ref == other.ref


@_SpecialPattern
class AntiPattern(Pattern[TOrigin]):
def __init__(self, pattern: Pattern[TOrigin]):
self.base: Pattern[TOrigin] = pattern
Expand All @@ -220,6 +255,7 @@ def __eq__(self, other): # pragma: no cover


@final
@_SpecialPattern
class AnyStrPattern(Pattern[str]):
def __init__(self):
super().__init__(origin=str, alias="any_str")
Expand All @@ -236,6 +272,7 @@ def __eq__(self, other): # pragma: no cover


@final
@_SpecialPattern
class StrPattern(Pattern[str]):
def __init__(self):
super().__init__(origin=str, alias="str")
Expand All @@ -259,6 +296,7 @@ def __eq__(self, other): # pragma: no cover


@final
@_SpecialPattern
class BytesPattern(Pattern[bytes]):
def __init__(self):
super().__init__(origin=bytes, alias="bytes")
Expand All @@ -284,6 +322,7 @@ def __eq__(self, other): # pragma: no cover


@final
@_SpecialPattern
class IntPattern(Pattern[int]):
def __init__(self):
super().__init__(origin=int, alias="int")
Expand All @@ -309,6 +348,7 @@ def __eq__(self, other): # pragma: no cover


@final
@_SpecialPattern
class FloatPattern(Pattern[float]):
def __init__(self):
super().__init__(origin=float, alias="float")
Expand All @@ -332,6 +372,7 @@ def __eq__(self, other): # pragma: no cover


@final
@_SpecialPattern
class NumberPattern(Pattern[Union[int, float]]):
def __init__(self):
super().__init__(origin=Union[int, float], alias="number") # type: ignore
Expand All @@ -356,6 +397,7 @@ def __eq__(self, other): # pragma: no cover


@final
@_SpecialPattern
class BoolPattern(Pattern[bool]):
def __init__(self):
super().__init__(origin=bool, alias="bool")
Expand All @@ -382,6 +424,7 @@ def __eq__(self, other): # pragma: no cover


@final
@_SpecialPattern
class WideBoolPattern(Pattern[bool]):
def __init__(self):
super().__init__(origin=bool, alias="bool")
Expand Down Expand Up @@ -440,6 +483,7 @@ def __eq__(self, other): # pragma: no cover


@final
@_SpecialPattern
class HexPattern(Pattern[int]):
def __init__(self):
super().__init__(origin=int, alias="hex")
Expand Down Expand Up @@ -470,6 +514,7 @@ def __eq__(self, other): # pragma: no cover


@final
@_SpecialPattern
class DateTimePattern(Pattern[datetime]):
def __init__(self):
super().__init__(origin=datetime, alias="datetime")
Expand All @@ -494,6 +539,7 @@ def __eq__(self, other): # pragma: no cover


@final
@_SpecialPattern
class PathPattern(Pattern[Path]):
def __init__(self):
super().__init__(origin=Path, alias="path")
Expand Down
13 changes: 9 additions & 4 deletions nepattern/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tarina.lang import lang

from .exception import MatchFailed
from .util import TPattern

T = TypeVar("T")

Expand Down Expand Up @@ -57,8 +58,8 @@ def __repr__(self):

class Pattern(Generic[T]):
@staticmethod
def regex_match(pattern: str, alias: str | None = None):
pat = Pattern(str, alias)
def regex_match(pattern: str | TPattern, alias: str | None = None):
pat = Pattern(str, alias or str(pattern))

@pat.convert
def _(self, x: str):
Expand All @@ -79,7 +80,7 @@ def regex_convert(
alias: str | None = None,
allow_origin: bool = False,
):
pat = Pattern(origin, alias)
pat = Pattern(origin, alias or str(pattern))
if allow_origin:
pat.accept(Union[str, origin])

Expand Down Expand Up @@ -145,7 +146,7 @@ def post_validate(self, func: Callable[[T], bool]):
self._post_validator = func
return self

def convert(self, func: Callable[[Self, Any], T]):
def convert(self, func: Callable[[Self, Any], T | None]):
self._converter = func
return self

Expand All @@ -160,6 +161,10 @@ def match(self, input_: Any) -> T:
)
if self._converter:
input_ = self._converter(self, input_)
if input_ is None:
raise MatchFailed(
lang.require("nepattern", "error.content").format(target=input_, expected=self.origin)
)
if self._post_validator and not self._post_validator(input_):
raise MatchFailed(
lang.require("nepattern", "error.content").format(target=input_, expected=self.origin)
Expand Down

0 comments on commit e3bb5d7

Please sign in to comment.