Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(models): add NonNegativeFloat class #27

Merged
merged 2 commits into from
Nov 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 103 additions & 1 deletion birdxplorer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
IncEx: TypeAlias = "set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None"
StrT = TypeVar("StrT", bound="BaseString")
IntT = TypeVar("IntT", bound="BaseInt")
FloatT = TypeVar("FloatT", bound="BaseFloat")


class BaseString(str):
Expand Down Expand Up @@ -258,6 +259,107 @@ def min_value(cls) -> int:
return int(datetime(2006, 7, 15, 0, 0, 0, 0, timezone.utc).timestamp() * 1000)


class BaseFloat(float):
"""
>>> BaseFloat(1.0)
BaseFloat(1.0)
>>> float(BaseFloat(1.0))
1.0
>>> ta = TypeAdapter(BaseFloat)
>>> ta.validate_python(1.0)
BaseFloat(1.0)
>>> ta.validate_python("1.0")
BaseFloat(1.0)
>>> ta.dump_json(BaseFloat(1.0))
b'1.0'
>>> BaseFloat.from_float(1.0)
BaseFloat(1.0)
>>> ta.validate_python(BaseFloat.from_float(1.0))
BaseFloat(1.0)
>>> hash(BaseFloat(1.0)) == hash(1.0)
True
"""

@classmethod
def _proc_float(cls, f: float) -> float:
return f

def __repr__(self) -> str:
return f"{self.__class__.__name__}({super().__repr__()})"

def __float__(self) -> float:
return super(BaseFloat, self).__float__()

@classmethod
def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(
cls.validate,
core_schema.float_schema(**cls.__get_extra_constraint_dict__()),
serialization=core_schema.plain_serializer_function_ser_schema(cls.serialize, when_used="json"),
)

@classmethod
def validate(cls: Type[FloatT], v: Any) -> FloatT:
return cls(cls._proc_float(v))

def serialize(self) -> float:
return float(self)

@classmethod
def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return {}

def __hash__(self) -> int:
return super(BaseFloat, self).__hash__()

@classmethod
def from_float(cls: Type[FloatT], v: float) -> FloatT:
return TypeAdapter(cls).validate_python(v)


class BaseLowerBoundedFloat(BaseFloat, ABC):
"""
>>> class PositiveFloat(BaseLowerBoundedFloat):
... @classmethod
... def min_value(cls) -> float:
... return 0.0
>>> PositiveFloat.from_float(0.0)
PositiveFloat(0.0)
>>> PositiveFloat.from_float(-0.1)
Traceback (most recent call last):
...
pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-float]
Input should be greater than or equal to 0 [type=greater_than_equal, input_value=-0.1, input_type=float]
...
"""

@classmethod
@abstractmethod
def min_value(cls) -> float:
raise NotImplementedError

@classmethod
def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return dict(super().__get_extra_constraint_dict__(), ge=cls.min_value())


class NonNegativeFloat(BaseLowerBoundedFloat):
"""
>>> NonNegativeFloat.from_float(0.0)
NonNegativeFloat(0.0)
>>> NonNegativeFloat.from_float(-0.1)
Traceback (most recent call last):
...
pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-float]
Input should be greater than or equal to 0 [type=greater_than_equal, input_value=-0.1, input_type=float]
...
"""

@classmethod
def min_value(cls) -> float:
return 0.0


class BaseModel(PydanticBaseModel):
"""
>>> from unittest.mock import patch
Expand Down Expand Up @@ -342,7 +444,7 @@ class UserEnrollment(BaseModel):
timestamp_of_last_state_change: UserEnrollmentLastStateChangeTimeStamp
timestamp_of_last_earn_out: UserEnrollmentLastEarnOutTimestamp
modeling_population: ModelingPopulation
modeling_group: str
modeling_group: NonNegativeFloat


class NoteId(BaseString):
Expand Down