diff --git a/reflex/base.py b/reflex/base.py index 5a2a6d2a98..8509e9d345 100644 --- a/reflex/base.py +++ b/reflex/base.py @@ -3,14 +3,35 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, List, Type +from typing import TYPE_CHECKING, Any, ClassVar, List, Optional, Type, Union + +if TYPE_CHECKING: + from reflex.utils.types import override +else: + + def override(fn): + """Decorator to indicate that a method is meant to override a parent method. + + Args: + fn: The method to override. + + Returns: + The unmodified method. + """ + return fn + try: + if TYPE_CHECKING: + from pydantic.v1.typing import AbstractSetIntStr, MappingIntStrAny import pydantic.v1.main as pydantic_main from pydantic.v1 import BaseModel from pydantic.v1.fields import ModelField + except ModuleNotFoundError: - if not TYPE_CHECKING: + if TYPE_CHECKING: + from pydantic.typing import AbstractSetIntStr, MappingIntStrAny + else: import pydantic.main as pydantic_main from pydantic import BaseModel from pydantic.fields import ModelField # type: ignore @@ -48,6 +69,15 @@ def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None pydantic_main.validate_field_name = validate_field_name # type: ignore +class UsedSerialization: + """A mixin which allows tracking of fields used in the frontend. + You can subclass this and add a @rx.serializer which uses the __used_fields__ attribute to only serialize used fields. + Take a look at SlimBase for an example implementation. + """ + + __used_fields__: ClassVar[set[str]] = set() + + class Base(BaseModel): # pyright: ignore [reportUnboundVariable] """The base class subclassed by all Reflex classes. @@ -143,3 +173,49 @@ def get_value(self, key: str) -> Any: exclude_defaults=False, exclude_none=False, ) + + +class SlimBase(Base, UsedSerialization): + """A slimmed down version of the Base class. + Only used fields will be included in the dict for serialization. + """ + + @override + def dict( + self, + *, + include: Optional[Union[AbstractSetIntStr, MappingIntStrAny]] = None, + exclude: Optional[Union[AbstractSetIntStr, MappingIntStrAny]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> dict[str, Any]: + """Convert the object to a dict. + + We override the default dict method to only include fields that are used. + + Args: + include: The fields to include. + exclude: The fields to exclude. + by_alias: Whether to use the alias names. + skip_defaults: Whether to skip default values. + exclude_unset: Whether to exclude unset values. + exclude_defaults: Whether to exclude default values. + exclude_none: Whether to exclude None values. + + Returns: + The object as a dict. + """ + if not include and isinstance(self, UsedSerialization): + include = self.__used_fields__ + return super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) diff --git a/reflex/ivars/object.py b/reflex/ivars/object.py index 687083fcb9..b551f7ca0e 100644 --- a/reflex/ivars/object.py +++ b/reflex/ivars/object.py @@ -20,6 +20,7 @@ overload, ) +from reflex.base import UsedSerialization from reflex.utils import types from reflex.utils.exceptions import VarAttributeError from reflex.utils.types import GenericType, get_attribute_access_type, get_origin @@ -267,6 +268,10 @@ def __getattr__(self, name) -> ImmutableVar: f"The State var `{str(self)}` has no attribute '{name}' or may have been annotated " f"wrongly." ) + + if issubclass(fixed_type, UsedSerialization): + fixed_type.__used_fields__.add(name) + return ObjectItemOperation.create(self, name, attribute_type).guess_type() else: return ObjectItemOperation.create(self, name).guess_type() diff --git a/tests/utils/test_serializers.py b/tests/utils/test_serializers.py index 7f5c2bc66e..40b95f4cc6 100644 --- a/tests/utils/test_serializers.py +++ b/tests/utils/test_serializers.py @@ -5,9 +5,10 @@ import pytest -from reflex.base import Base +from reflex.base import Base, SlimBase from reflex.components.core.colors import Color from reflex.ivars.base import ImmutableVar, LiteralVar +from reflex.state import State from reflex.utils import serializers @@ -119,6 +120,33 @@ class BaseSubclass(Base): ts: datetime.timedelta = datetime.timedelta(1, 1, 1) +class SlimBaseSubclass(SlimBase): + """A class inheriting from SlimBase for testing.""" + + ts: datetime.timedelta = datetime.timedelta(1, 1, 1) + + +def test_slim_base_subclass(): + """Test that a SlimBase subclass is serialized correctly.""" + + class SlimState(State): + """A test state.""" + + slim: SlimBaseSubclass = SlimBaseSubclass(ts=datetime.timedelta(1, 1, 1)) + + state = SlimState() + assert SlimBaseSubclass.__used_fields__ == set() + assert serializers.serialize(state.slim) == "({ })" + + # Access the field + SlimState.slim.ts + assert SlimBaseSubclass.__used_fields__ == {"ts"} + assert serializers.serialize(state.slim) == '({ ["ts"] : "1 day, 0:00:01.000001" })' + + # Reset the used fields tracking set. + SlimBaseSubclass.__used_fields__ = set() + + @pytest.mark.parametrize( "value,expected", [ @@ -153,6 +181,10 @@ class BaseSubclass(Base): BaseSubclass(ts=datetime.timedelta(1, 1, 1)), '({ ["ts"] : "1 day, 0:00:01.000001" })', ), + ( + SlimBaseSubclass(ts=datetime.timedelta(1, 1, 1)), + "({ })", + ), ( [1, LiteralVar.create("hi"), ImmutableVar.create("bye")], '[1, "hi", bye]',