Skip to content

Commit

Permalink
Merge pull request #52 from jakob-keller/typing-and-more
Browse files Browse the repository at this point in the history
`models.Measurement`: improved typing and more
  • Loading branch information
jakob-keller authored Mar 10, 2023
2 parents 2646e8a + f7b0cf2 commit 7445682
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 82 deletions.
121 changes: 61 additions & 60 deletions peprock/models/measurement.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# noinspection PyTypeChecker
"""Generic measurement model.
See https://en.wikipedia.org/wiki/Measurement
Expand All @@ -9,6 +10,9 @@
>>> str(abs(2 * Measurement(decimal.Decimal("-12.3"), MetricPrefix.mega, Unit.watt)))
'24.6 MW'
>>> int(Measurement(0.123456, MetricPrefix.kilo))
123
"""

from __future__ import annotations
Expand All @@ -22,6 +26,15 @@
from .metric_prefix import MetricPrefix
from .unit import Unit

if typing.TYPE_CHECKING:
import sys

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


_MagnitudeT = typing.TypeVar(
"_MagnitudeT",
int,
Expand All @@ -39,18 +52,15 @@


@dataclasses.dataclass(frozen=True)
class Measurement(
typing.SupportsAbs["Measurement[_MagnitudeT]"],
typing.Generic[_MagnitudeT],
):
class Measurement(typing.Generic[_MagnitudeT]):
"""Measurement model supporting conversion and arithmetic operations."""

magnitude: _MagnitudeT
prefix: MetricPrefix = MetricPrefix.NONE
unit: Unit | str | None = None

@functools.cached_property
def _unit_symbol(self: Measurement) -> str:
def _unit_symbol(self: Self) -> str:
match self.unit:
case None | Unit.one:
return ""
Expand All @@ -59,31 +69,31 @@ def _unit_symbol(self: Measurement) -> str:
case _:
return str(self.unit)

def __format__(self: Measurement, format_spec: str) -> str:
def __format__(self: Self, format_spec: str) -> str:
"""Format measurement and return str."""
formatted: str = format(self.magnitude, format_spec)
if suffix := self.prefix.symbol + self._unit_symbol:
return f"{formatted} {suffix}"
return formatted

@functools.cached_property
def _str(self: Measurement) -> str:
def _str(self: Self) -> str:
return format(self)

def __str__(self: Measurement) -> str:
def __str__(self: Self) -> str:
"""Return str(self)."""
return self._str

def _normalize_magnitudes(
self: Measurement[_MagnitudeT],
self: Self,
other: Measurement[_MagnitudeS],
/,
) -> tuple[
Measurement[_MagnitudeT] | Measurement[_MagnitudeS],
Self | Measurement[_MagnitudeS],
_MagnitudeT | float,
_MagnitudeS | float,
]:
target: Measurement[_MagnitudeT] | Measurement[_MagnitudeS] = (
target: Self | Measurement[_MagnitudeS] = (
self if self.prefix <= other.prefix else other
)
return (
Expand All @@ -92,47 +102,47 @@ def _normalize_magnitudes(
other.prefix.convert(other.magnitude, to=target.prefix),
)

def __lt__(self: Measurement, other: Measurement) -> bool:
def __lt__(self: Self, other: Measurement) -> bool:
"""Return self < other."""
if isinstance(other, Measurement) and self.unit == other.unit:
_, magnitude_self, magnitude_other = self._normalize_magnitudes(other)
return magnitude_self < magnitude_other

return NotImplemented

def __le__(self: Measurement, other: Measurement) -> bool:
def __le__(self: Self, other: Measurement) -> bool:
"""Return self <= other."""
if isinstance(other, Measurement) and self.unit == other.unit:
_, magnitude_self, magnitude_other = self._normalize_magnitudes(other)
return magnitude_self <= magnitude_other

return NotImplemented

def __eq__(self: Measurement, other: object) -> bool:
def __eq__(self: Self, other: object) -> bool:
"""Return self == other."""
if isinstance(other, Measurement) and self.unit == other.unit:
_, magnitude_self, magnitude_other = self._normalize_magnitudes(other)
return magnitude_self == magnitude_other

return NotImplemented

def __ne__(self: Measurement, other: object) -> bool:
def __ne__(self: Self, other: object) -> bool:
"""Return self != other."""
if isinstance(other, Measurement) and self.unit == other.unit:
_, magnitude_self, magnitude_other = self._normalize_magnitudes(other)
return magnitude_self != magnitude_other

return NotImplemented

def __gt__(self: Measurement, other: Measurement) -> bool:
def __gt__(self: Self, other: Measurement) -> bool:
"""Return self > other."""
if isinstance(other, Measurement) and self.unit == other.unit:
_, magnitude_self, magnitude_other = self._normalize_magnitudes(other)
return magnitude_self > magnitude_other

return NotImplemented

def __ge__(self: Measurement, other: Measurement) -> bool:
def __ge__(self: Self, other: Measurement) -> bool:
"""Return self >= other."""
if isinstance(other, Measurement) and self.unit == other.unit:
_, magnitude_self, magnitude_other = self._normalize_magnitudes(other)
Expand All @@ -141,20 +151,16 @@ def __ge__(self: Measurement, other: Measurement) -> bool:
return NotImplemented

@functools.cached_property
def _hash(self: Measurement) -> int:
return hash(
(
self.prefix.convert(self.magnitude, to=MetricPrefix.NONE),
self.unit,
),
)
def _hash(self: Self) -> int:
return hash((self.prefix.convert(self.magnitude), self.unit))

def __hash__(self: Measurement) -> int:
def __hash__(self: Self) -> int:
"""Return hash(self)."""
return self._hash

def __abs__(self: _MeasurementT_co) -> _MeasurementT_co:
def __abs__(self: Self) -> Self:
"""Return abs(self)."""
# noinspection PyDataclass
return dataclasses.replace(
self,
magnitude=abs(self.magnitude),
Expand Down Expand Up @@ -209,7 +215,7 @@ def __add__(
) -> Measurement[fractions.Fraction]:
...

def __add__(self: Measurement, other: Measurement) -> Measurement:
def __add__(self, other):
"""Return self + other."""
if isinstance(other, Measurement) and self.unit == other.unit:
target, magnitude_self, magnitude_other = self._normalize_magnitudes(other)
Expand Down Expand Up @@ -290,10 +296,7 @@ def __floordiv__(
) -> decimal.Decimal:
...

def __floordiv__(
self: Measurement,
other: int | float | decimal.Decimal | fractions.Fraction | Measurement,
) -> Measurement | int | float | decimal.Decimal:
def __floordiv__(self, other):
"""Return self // other."""
if isinstance(other, int | float | decimal.Decimal | fractions.Fraction):
return dataclasses.replace(
Expand Down Expand Up @@ -356,10 +359,7 @@ def __mod__(
) -> Measurement[fractions.Fraction]:
...

def __mod__(
self: Measurement,
other: Measurement,
) -> Measurement:
def __mod__(self, other):
"""Return self % other."""
if isinstance(other, Measurement) and self.unit == other.unit:
target, magnitude_self, magnitude_other = self._normalize_magnitudes(other)
Expand Down Expand Up @@ -419,10 +419,7 @@ def __mul__(
) -> Measurement[fractions.Fraction]:
...

def __mul__(
self: Measurement,
other: int | float | decimal.Decimal | fractions.Fraction,
) -> Measurement:
def __mul__(self, other):
"""Return self * other."""
if isinstance(other, int | float | decimal.Decimal | fractions.Fraction):
return dataclasses.replace(
Expand Down Expand Up @@ -481,22 +478,21 @@ def __rmul__(
) -> Measurement[fractions.Fraction]:
...

def __rmul__(
self: Measurement,
other: int | float | decimal.Decimal | fractions.Fraction,
) -> Measurement:
def __rmul__(self, other):
"""Return other * self."""
return self.__mul__(other)

def __neg__(self: Measurement[_MagnitudeT]) -> Measurement[_MagnitudeT]:
def __neg__(self: Self) -> Self:
"""Return -self."""
# noinspection PyDataclass
return dataclasses.replace(
self,
magnitude=-self.magnitude,
)

def __pos__(self: Measurement[_MagnitudeT]) -> Measurement[_MagnitudeT]:
def __pos__(self: Self) -> Self:
"""Return +self."""
# noinspection PyDataclass
return dataclasses.replace(
self,
magnitude=+self.magnitude,
Expand Down Expand Up @@ -551,7 +547,7 @@ def __sub__(
) -> Measurement[fractions.Fraction]:
...

def __sub__(self: Measurement, other: Measurement) -> Measurement:
def __sub__(self, other):
"""Return self - other."""
if isinstance(other, Measurement) and self.unit == other.unit:
target, magnitude_self, magnitude_other = self._normalize_magnitudes(other)
Expand Down Expand Up @@ -660,10 +656,7 @@ def __truediv__(
) -> fractions.Fraction:
...

def __truediv__(
self: Measurement,
other: int | float | decimal.Decimal | fractions.Fraction | Measurement,
) -> Measurement | float | decimal.Decimal | fractions.Fraction:
def __truediv__(self, other):
"""Return self / other."""
if isinstance(other, int | float | decimal.Decimal | fractions.Fraction):
return dataclasses.replace(
Expand All @@ -677,34 +670,42 @@ def __truediv__(

return NotImplemented

def __bool__(self: Measurement) -> bool:
def __bool__(self: Self) -> bool:
"""Return True if magnitude is nonzero; otherwise return False."""
return bool(self.magnitude)

def __int__(self: Self) -> int:
"""Return int(self)."""
return int(self.prefix.convert(self.magnitude))

def __float__(self: Self) -> float:
"""Return float(self)."""
return float(self.prefix.convert(self.magnitude))

@typing.overload
def __round__(self: Measurement) -> Measurement[int]:
def __round__(
self: Self,
) -> Measurement[int]:
...

@typing.overload
def __round__(
self: Measurement[_MagnitudeT],
__ndigits: int,
) -> Measurement[_MagnitudeT]:
self: Self,
__ndigits: typing.SupportsIndex,
) -> Self:
...

def __round__(
self: Measurement,
__ndigits: int | None = None,
) -> Measurement:
self,
__ndigits=None,
):
"""Return round(self)."""
return dataclasses.replace(
self,
magnitude=round(self.magnitude, __ndigits),
)


_MeasurementT_co = typing.TypeVar("_MeasurementT_co", bound=Measurement, covariant=True)

__all__ = [
"Measurement",
]
6 changes: 3 additions & 3 deletions peprock/models/metric_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def convert(
self: MetricPrefix,
__value: int,
/,
to: MetricPrefix,
to: MetricPrefix = NONE, # type: ignore[assignment]
) -> int | float:
...

Expand All @@ -120,15 +120,15 @@ def convert(
self: MetricPrefix,
__value: ComplexT,
/,
to: MetricPrefix,
to: MetricPrefix = NONE, # type: ignore[assignment]
) -> ComplexT:
...

def convert(
self,
__value,
/,
to,
to=NONE,
):
"""Convert value from metric prefix self to to."""
if self is to:
Expand Down
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ classifiers = [

[tool.poetry.dependencies]
python = "^3.10"
typing-extensions = { version = "^4.0.1", python = "<3.11" }

[tool.poetry.group.dev.dependencies]
black = "^23.1.0"
Expand Down
10 changes: 10 additions & 0 deletions tests/models/test_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,16 @@ def test_bool(measurement):
assert bool(measurement) is bool(measurement.magnitude)


def test_int(measurement):
assert int(measurement) == int(measurement.prefix.convert(measurement.magnitude))


def test_float(measurement):
assert float(measurement) == float(
measurement.prefix.convert(measurement.magnitude),
)


@pytest.mark.parametrize(
"ndigits",
[
Expand Down
Loading

0 comments on commit 7445682

Please sign in to comment.