Skip to content

Commit

Permalink
fix: error checking type for "subscripted generics" (#77)
Browse files Browse the repository at this point in the history
* fix: error checking type for "subscripted generics"

* fix(type): get_args still needed for union check

* refactor: introduce `is_scalar_type` instead

* refactor: change name for better readability

* refactor: use Datapoints root model to assist in automated conversion

* fix: unused imports

* test: add test for Datapoints parsing

* fix: add dict methods to Datapoints

---------

Co-authored-by: fubuloubu <[email protected]>
  • Loading branch information
mikeshultz and fubuloubu authored May 4, 2024
1 parent 1b49588 commit 29c24f1
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 41 deletions.
46 changes: 8 additions & 38 deletions silverback/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,7 @@
from taskiq import TaskiqResult
from typing_extensions import Self # Introduced 3.11

from .types import (
Datapoint,
ScalarDatapoint,
ScalarType,
SilverbackID,
UTCTimestamp,
iso_format,
utc_now,
)
from .types import Datapoints, SilverbackID, UTCTimestamp, iso_format, utc_now

logger = get_logger(__name__)

Expand All @@ -35,39 +27,17 @@ class TaskResult(BaseModel):
block_number: int | None = None

# Custom user metrics here
metrics: dict[str, Datapoint] = {}
metrics: Datapoints

@classmethod
def _extract_custom_metrics(cls, result: Any, task_name: str) -> dict[str, Datapoint]:
if isinstance(result, Datapoint): # type: ignore[arg-type,misc]
return {"result": result}
def _extract_custom_metrics(cls, return_value: Any, task_name: str) -> Datapoints:
if return_value is None:
return Datapoints(root={})

elif isinstance(result, ScalarType): # type: ignore[arg-type,misc]
return {"result": ScalarDatapoint(data=result)}
elif not isinstance(return_value, dict):
return_value = {"return_value": return_value}

elif result is None:
return {}

elif not isinstance(result, dict):
logger.warning(f"Cannot handle return type of '{task_name}': '{type(result)}'.")
return {}

converted_result = {}

for metric_name, metric_value in result.items():
if isinstance(metric_value, Datapoint): # type: ignore[arg-type,misc]
converted_result[metric_name] = metric_value

elif isinstance(metric_value, ScalarType): # type: ignore[arg-type,misc]
converted_result[metric_name] = ScalarDatapoint(data=metric_value)

else:
logger.warning(
f"Cannot handle type of metric '{task_name}.{metric_name}':"
f" '{type(metric_value)}'."
)

return converted_result
return Datapoints(root=return_value)

@classmethod
def _extract_system_metrics(cls, labels: dict) -> dict:
Expand Down
52 changes: 49 additions & 3 deletions silverback/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from enum import Enum # NOTE: `enum.StrEnum` only in Python 3.11+
from typing import Literal

from pydantic import BaseModel, Field
from ape.logging import get_logger
from pydantic import BaseModel, Field, RootModel, ValidationError, model_validator
from pydantic.functional_serializers import PlainSerializer
from typing_extensions import Annotated

logger = get_logger(__name__)


class TaskType(str, Enum):
STARTUP = "startup"
Expand Down Expand Up @@ -47,14 +50,57 @@ class _BaseDatapoint(BaseModel):
Int96 = Annotated[int, Field(ge=-(2**95), le=2**95 - 1)]
# NOTE: only these types of data are implicitly converted e.g. `{"something": 1, "else": 0.001}`
ScalarType = bool | Int96 | float | Decimal
# NOTE: Interesting side effect is that `int` outside the INT96 range parses as `Decimal`
# This is okay, preferable actually, because it means we can store ints outside that range


class ScalarDatapoint(_BaseDatapoint):
    # Discriminator tag identifying this datapoint variant
    type: Literal["scalar"] = "scalar"
    # The recorded value; raw scalars are implicitly coerced into this field
    data: ScalarType


# NOTE: Other datapoint types must be explicitly defined as subclasses of `_BaseDatapoint`
# Users will have to import and use these directly

# NOTE: Other datapoint types must be added to this union
Datapoint = ScalarDatapoint


class Datapoints(RootModel):
    """Mapping of metric name -> ``Datapoint`` with implicit scalar conversion.

    Raw scalar values (``bool`` / int / ``float`` / ``Decimal``) are converted
    to ``ScalarDatapoint`` automatically; unconvertible entries are dropped
    with a warning rather than failing validation.
    """

    root: dict[str, Datapoint]

    @model_validator(mode="before")
    def parse_datapoints(cls, datapoints: dict) -> dict:
        # Build a new mapping instead of mutating the caller's dict in place
        # (the original implementation popped entries out of the input).
        converted: dict = {}
        for name, value in datapoints.items():
            if isinstance(value, Datapoint):
                converted[name] = value
            else:
                # Automatically convert raw scalar types
                try:
                    converted[name] = ScalarDatapoint(data=value)
                except ValidationError as e:
                    # Prune and raise a warning about unconverted datapoints
                    logger.warning(
                        f"Cannot convert datapoint '{name}' of type '{type(value)}': {e}"
                    )
        return converted

    # Add dict methods
    def get(self, key: str, default: Datapoint | None = None) -> Datapoint | None:
        return self.root.get(key, default)

    def __iter__(self):
        return iter(self.root)

    def __getitem__(self, item):
        return self.root[item]

    def __contains__(self, key) -> bool:
        # Direct O(1) membership (previously fell back to __iter__)
        return key in self.root

    def items(self):
        return self.root.items()
38 changes: 38 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from decimal import Decimal

import pytest

from silverback.types import Datapoints


# (raw_return, expected model_dump) pairs exercising implicit scalar conversion
DATAPOINT_CASES = [
    # Strings are not implicitly convertible -> pruned (empty datapoints)
    ({"a": "b"}, {}),
    # Plain ints convert to scalar datapoints
    ({"a": 1}, {"a": {"type": "scalar", "data": 1}}),
    # Largest value still within the INT96 range
    (
        {"a": 2**96 - 1},
        {"a": {"type": "scalar", "data": 79228162514264337593543950335}},
    ),
    # One past the INT96 max falls through to Decimal
    (
        {"a": 2**96},
        {"a": {"type": "scalar", "data": Decimal("79228162514264337593543950336")}},
    ),
    # Decimal input stays Decimal
    (
        {"a": Decimal("1e12")},
        {"a": {"type": "scalar", "data": Decimal("1000000000000")}},
    ),
    # float input stays float
    (
        {"a": 1e12},
        {"a": {"type": "scalar", "data": 1000000000000.0}},
    ),
]


@pytest.mark.parametrize("raw_return,expected", DATAPOINT_CASES)
def test_datapoint_parsing(raw_return, expected):
    assert Datapoints(root=raw_return).model_dump() == expected

0 comments on commit 29c24f1

Please sign in to comment.