diff --git a/marshmallow_recipe/__init__.py b/marshmallow_recipe/__init__.py index 3bbc053..cd17708 100644 --- a/marshmallow_recipe/__init__.py +++ b/marshmallow_recipe/__init__.py @@ -30,7 +30,7 @@ from .naming_case import CAMEL_CASE, CAPITAL_CAMEL_CASE, CamelCase, CapitalCamelCase, NamingCase from .options import NoneValueHandling, options from .serialization import EmptySchema, dump, dump_many, load, load_many, schema -from .validator import ValidationError, ValidationFunc +from .validation import ValidationFunc, regexp_validate, validate __all__: tuple[str, ...] = ( # bake.py @@ -85,9 +85,10 @@ "dump_many", "schema", "EmptySchema", - # validator.py - "ValidationError", + # validation.py "ValidationFunc", + "regexp_validate", + "validate", ) __version__ = "0.0.31" diff --git a/marshmallow_recipe/fields.py b/marshmallow_recipe/fields.py index ec0ae05..5730ac0 100644 --- a/marshmallow_recipe/fields.py +++ b/marshmallow_recipe/fields.py @@ -8,7 +8,7 @@ import marshmallow as m import marshmallow.validate -from .validator import ValidationFunc +from .validation import ValidationFunc _MARSHMALLOW_VERSION_MAJOR = int(m.__version__.split(".")[0]) diff --git a/marshmallow_recipe/metadata.py b/marshmallow_recipe/metadata.py index 4d26e9e..386ee56 100644 --- a/marshmallow_recipe/metadata.py +++ b/marshmallow_recipe/metadata.py @@ -2,7 +2,7 @@ from typing import Any, TypeGuard, final from .missing import MISSING -from .validator import ValidationFunc +from .validation import ValidationFunc @final diff --git a/marshmallow_recipe/validation.py b/marshmallow_recipe/validation.py new file mode 100644 index 0000000..19f3cbd --- /dev/null +++ b/marshmallow_recipe/validation.py @@ -0,0 +1,24 @@ +import collections.abc +import re +from typing import Any + +import marshmallow.validate + +ValidationFunc = collections.abc.Callable[[Any], Any] + + +def regexp_validate(regexp: re.Pattern | str, *, error: str | None = None) -> ValidationFunc: + return marshmallow.validate.Regexp(regexp, error=error) + + +def validate(validator: ValidationFunc, *, error: str | None = None) -> ValidationFunc: + if error is None: + return validator + + def _validator_with_custom_error(value: Any) -> Any: + result = validator(value) + if result is False: + raise marshmallow.ValidationError(error) + return result + + return _validator_with_custom_error diff --git a/marshmallow_recipe/validator.py b/marshmallow_recipe/validator.py deleted file mode 100644 index 969794f..0000000 --- a/marshmallow_recipe/validator.py +++ /dev/null @@ -1,11 +0,0 @@ -import collections.abc -from typing import Any - -import marshmallow as m - -ValidationFunc = collections.abc.Callable[[Any], Any] - - -class ValidationError(m.ValidationError): - def __init__(self, message: str): - super().__init__(message) diff --git a/tests/test_validation.py b/tests/test_validation.py index b0b2dcd..8885bb9 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -2,7 +2,7 @@ import datetime import decimal import uuid -from typing import cast +from typing import Annotated, cast import marshmallow as m import pytest @@ -245,3 +245,36 @@ class IntContainer: with pytest.raises(m.ValidationError): mr.dump(IntContainer(int_field=cast(int, "invalid"))) + + +def test_regexp_validate() -> None: + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class StrContainer: + value1: Annotated[str, mr.str_meta(validate=mr.regexp_validate(r"^[a-z]+$"))] + value2: Annotated[ + str, mr.str_meta(validate=mr.regexp_validate(r"^[a-z]+$", error="String does not match ^[a-z]+$.")) + ] + + with pytest.raises(m.ValidationError) as exc_info: + mr.dump(StrContainer(value1="42", value2="100500")) + + assert exc_info.value.messages == { + "value1": ["String does not match expected pattern."], + "value2": ["String does not match ^[a-z]+$."], + } + + +def test_validate() -> None: + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class IntContainer: + value: Annotated[ + int, + mr.str_meta( + validate=mr.validate(lambda x: x < 0, error="Should be negative."), + ), + ] + + with pytest.raises(m.ValidationError) as exc_info: + mr.dump(IntContainer(value=42)) + + assert exc_info.value.messages == {"value": ["Should be negative."]}