From 79e8bca7eaee5db534a0be6222172f852444ebcb Mon Sep 17 00:00:00 2001 From: osoken Date: Tue, 31 Oct 2023 23:22:21 +0900 Subject: [PATCH] feat(models): add NonNegativeFloat class --- birdxplorer/models.py | 104 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 103 insertions(+), 1 deletion(-) diff --git a/birdxplorer/models.py b/birdxplorer/models.py index eb8c094..b24bbc6 100644 --- a/birdxplorer/models.py +++ b/birdxplorer/models.py @@ -11,6 +11,7 @@ IncEx: TypeAlias = "set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None" StrT = TypeVar("StrT", bound="BaseString") IntT = TypeVar("IntT", bound="BaseInt") +FloatT = TypeVar("FloatT", bound="BaseFloat") class BaseString(str): @@ -218,6 +219,107 @@ def min_value(cls) -> int: return int(datetime(2006, 7, 15, 0, 0, 0, 0, timezone.utc).timestamp() * 1000) +class BaseFloat(float): + """ + >>> BaseFloat(1.0) + BaseFloat(1.0) + >>> float(BaseFloat(1.0)) + 1.0 + >>> ta = TypeAdapter(BaseFloat) + >>> ta.validate_python(1.0) + BaseFloat(1.0) + >>> ta.validate_python("1.0") + BaseFloat(1.0) + >>> ta.dump_json(BaseFloat(1.0)) + b'1.0' + >>> BaseFloat.from_float(1.0) + BaseFloat(1.0) + >>> ta.validate_python(BaseFloat.from_float(1.0)) + BaseFloat(1.0) + >>> hash(BaseFloat(1.0)) == hash(1.0) + True + """ + + @classmethod + def _proc_float(cls, f: float) -> float: + return f + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + def __float__(self) -> float: + return super(BaseFloat, self).__float__() + + @classmethod + def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + return core_schema.no_info_after_validator_function( + cls.validate, + core_schema.float_schema(**cls.__get_extra_constraint_dict__()), + serialization=core_schema.plain_serializer_function_ser_schema(cls.serialize, when_used="json"), + ) + + @classmethod + def validate(cls: Type[FloatT], v: Any) -> FloatT: + return cls(cls._proc_float(v)) + + def serialize(self) -> float: + return float(self) + + @classmethod + def __get_extra_constraint_dict__(cls) -> dict[str, Any]: + return {} + + def __hash__(self) -> int: + return super(BaseFloat, self).__hash__() + + @classmethod + def from_float(cls: Type[FloatT], v: float) -> FloatT: + return TypeAdapter(cls).validate_python(v) + + +class BaseLowerBoundedFloat(BaseFloat, ABC): + """ + >>> class PositiveFloat(BaseLowerBoundedFloat): + ... @classmethod + ... def min_value(cls) -> float: + ... return 0.0 + >>> PositiveFloat.from_float(0.0) + PositiveFloat(0.0) + >>> PositiveFloat.from_float(-0.1) + Traceback (most recent call last): + ... + pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-float] + Input should be greater than or equal to 0 [type=greater_than_equal, input_value=-0.1, input_type=float] + ... + """ + + @classmethod + @abstractmethod + def min_value(cls) -> float: + raise NotImplementedError + + @classmethod + def __get_extra_constraint_dict__(cls) -> dict[str, Any]: + return dict(super().__get_extra_constraint_dict__(), ge=cls.min_value()) + + +class NonNegativeFloat(BaseLowerBoundedFloat): + """ + >>> NonNegativeFloat.from_float(0.0) + NonNegativeFloat(0.0) + >>> NonNegativeFloat.from_float(-0.1) + Traceback (most recent call last): + ... + pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-float] + Input should be greater than or equal to 0 [type=greater_than_equal, input_value=-0.1, input_type=float] + ... + """ + + @classmethod + def min_value(cls) -> float: + return 0.0 + + class BaseModel(PydanticBaseModel): """ >>> from unittest.mock import patch @@ -296,7 +398,7 @@ class UserEnrollment(BaseModel): timestamp_of_last_state_change: UserEnrollmentLastStateChangeTimeStamp timestamp_of_last_earn_out: str modeling_population: str - modeling_group: str + modeling_group: NonNegativeFloat class NoteId(BaseString):