From 29c24f17c06b0b8d5363ace3804e009b4ddb476e Mon Sep 17 00:00:00 2001 From: Mike Shultz Date: Sat, 4 May 2024 11:36:41 -0600 Subject: [PATCH] fix: error checking type for "subscripted generics" (#77) * fix: error checking type for "subscripted generics" * fix(type): get_args still needed for union check * refactor: introduce `is_scalar_type` instead * refactor: change name for better readability * refactor: use Datapoints root model to assist in automated conversion * fix: unused imports * test: add test for Datapoints parsing * fix: add dict methods to Datapoints --------- Co-authored-by: fubuloubu <3859395+fubuloubu@users.noreply.github.com> --- silverback/recorder.py | 46 +++++++------------------------------ silverback/types.py | 52 +++++++++++++++++++++++++++++++++++++++--- tests/test_types.py | 38 ++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 41 deletions(-) create mode 100644 tests/test_types.py diff --git a/silverback/recorder.py b/silverback/recorder.py index 0b8d2018..70e521bc 100644 --- a/silverback/recorder.py +++ b/silverback/recorder.py @@ -7,15 +7,7 @@ from taskiq import TaskiqResult from typing_extensions import Self # Introduced 3.11 -from .types import ( - Datapoint, - ScalarDatapoint, - ScalarType, - SilverbackID, - UTCTimestamp, - iso_format, - utc_now, -) +from .types import Datapoints, SilverbackID, UTCTimestamp, iso_format, utc_now logger = get_logger(__name__) @@ -35,39 +27,17 @@ class TaskResult(BaseModel): block_number: int | None = None # Custom user metrics here - metrics: dict[str, Datapoint] = {} + metrics: Datapoints @classmethod - def _extract_custom_metrics(cls, result: Any, task_name: str) -> dict[str, Datapoint]: - if isinstance(result, Datapoint): # type: ignore[arg-type,misc] - return {"result": result} + def _extract_custom_metrics(cls, return_value: Any, task_name: str) -> Datapoints: + if return_value is None: + return Datapoints(root={}) - elif isinstance(result, ScalarType): # type: ignore[arg-type,misc] - return {"result": ScalarDatapoint(data=result)} + elif not isinstance(return_value, dict): + return_value = {"return_value": return_value} - elif result is None: - return {} - - elif not isinstance(result, dict): - logger.warning(f"Cannot handle return type of '{task_name}': '{type(result)}'.") - return {} - - converted_result = {} - - for metric_name, metric_value in result.items(): - if isinstance(metric_value, Datapoint): # type: ignore[arg-type,misc] - converted_result[metric_name] = metric_value - - elif isinstance(metric_value, ScalarType): # type: ignore[arg-type,misc] - converted_result[metric_name] = ScalarDatapoint(data=metric_value) - - else: - logger.warning( - f"Cannot handle type of metric '{task_name}.{metric_name}':" - f" '{type(metric_value)}'." - ) - - return converted_result + return Datapoints(root=return_value) @classmethod def _extract_system_metrics(cls, labels: dict) -> dict: diff --git a/silverback/types.py b/silverback/types.py index 6448b72c..efda3dfa 100644 --- a/silverback/types.py +++ b/silverback/types.py @@ -3,10 +3,13 @@ from enum import Enum # NOTE: `enum.StrEnum` only in Python 3.11+ from typing import Literal -from pydantic import BaseModel, Field +from ape.logging import get_logger +from pydantic import BaseModel, Field, RootModel, ValidationError, model_validator from pydantic.functional_serializers import PlainSerializer from typing_extensions import Annotated +logger = get_logger(__name__) + class TaskType(str, Enum): STARTUP = "startup" @@ -47,6 +50,8 @@ class _BaseDatapoint(BaseModel): Int96 = Annotated[int, Field(ge=-(2**95), le=2**95 - 1)] # NOTE: only these types of data are implicitly converted e.g. `{"something": 1, "else": 0.001}` ScalarType = bool | Int96 | float | Decimal +# NOTE: Interesting side effect is that `int` outside the INT96 range parse as `Decimal` +# This is okay, preferable actually, because it means we can store ints outside that range class ScalarDatapoint(_BaseDatapoint): @@ -54,7 +59,48 @@ class ScalarDatapoint(_BaseDatapoint): data: ScalarType -# NOTE: Other datapoint types must be explicitly used +# NOTE: Other datapoint types must be explicitly defined as subclasses of `_BaseDatapoint` +# Users will have to import and use these directly -# TODO: Other datapoint types added to union here... +# NOTE: Other datapoint types must be added to this union Datapoint = ScalarDatapoint + + +class Datapoints(RootModel): + root: dict[str, Datapoint] + + @model_validator(mode="before") + def parse_datapoints(cls, datapoints: dict) -> dict: + names_to_remove: dict[str, ValidationError] = {} + # Automatically convert raw scalar types + for name in datapoints: + if not isinstance(datapoints[name], Datapoint): + try: + datapoints[name] = ScalarDatapoint(data=datapoints[name]) + except ValidationError as e: + names_to_remove[name] = e + + # Prune and raise a warning about unconverted datapoints + for name in names_to_remove: + data = datapoints.pop(name) + logger.warning( + f"Cannot convert datapoint '{name}' of type '{type(data)}': {names_to_remove[name]}" + ) + + return datapoints + + # Add dict methods + def get(self, key: str, default: Datapoint | None = None) -> Datapoint | None: + if key in self: + return self[key] + + return default + + def __iter__(self): + return iter(self.root) + + def __getitem__(self, item): + return self.root[item] + + def items(self): + return self.root.items() diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 00000000..c87e3120 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,38 @@ +from decimal import Decimal + +import pytest + +from silverback.types import Datapoints + + +@pytest.mark.parametrize( + "raw_return,expected", + [ + # String datapoints don't parse (empty datapoints) + ({"a": "b"}, {}), + # ints parse + ({"a": 1}, {"a": {"type": "scalar", "data": 1}}), + # max INT96 value + ( + {"a": 2**96 - 1}, + {"a": {"type": "scalar", "data": 79228162514264337593543950335}}, + ), + # int over INT96 max parses as Decimal + ( + {"a": 2**96}, + {"a": {"type": "scalar", "data": Decimal("79228162514264337593543950336")}}, + ), + # Decimal parses as Decimal + ( + {"a": Decimal("1e12")}, + {"a": {"type": "scalar", "data": Decimal("1000000000000")}}, + ), + # float parses as float + ( + {"a": 1e12}, + {"a": {"type": "scalar", "data": 1000000000000.0}}, + ), + ], +) +def test_datapoint_parsing(raw_return, expected): + assert Datapoints(root=raw_return).model_dump() == expected