Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only serialize used rx.Base fields #3845

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
80 changes: 78 additions & 2 deletions reflex/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,35 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, List, Type
from typing import TYPE_CHECKING, Any, ClassVar, List, Optional, Type, Union

if TYPE_CHECKING:
from reflex.utils.types import override
else:

def override(fn):
"""Decorator to indicate that a method is meant to override a parent method.

Args:
fn: The method to override.

Returns:
The unmodified method.
"""
return fn


try:
if TYPE_CHECKING:
from pydantic.v1.typing import AbstractSetIntStr, MappingIntStrAny
import pydantic.v1.main as pydantic_main
from pydantic.v1 import BaseModel
from pydantic.v1.fields import ModelField

except ModuleNotFoundError:
if not TYPE_CHECKING:
if TYPE_CHECKING:
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny
else:
import pydantic.main as pydantic_main
from pydantic import BaseModel
from pydantic.fields import ModelField # type: ignore
Expand Down Expand Up @@ -48,6 +69,15 @@ def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None
pydantic_main.validate_field_name = validate_field_name # type: ignore


class UsedSerialization:
"""A mixin which allows tracking of fields used in the frontend.
You can subclass this and add a @rx.serializer which uses the __used_fields__ attribute to only serialize used fields.
Take a look at SlimBase for an example implementation.
"""

__used_fields__: ClassVar[set[str]] = set()


class Base(BaseModel): # pyright: ignore [reportUnboundVariable]
"""The base class subclassed by all Reflex classes.

Expand Down Expand Up @@ -143,3 +173,49 @@ def get_value(self, key: str) -> Any:
exclude_defaults=False,
exclude_none=False,
)


class SlimBase(Base, UsedSerialization):
"""A slimmed down version of the Base class.
Only used fields will be included in the dict for serialization.
"""

@override
def dict(
self,
*,
include: Optional[Union[AbstractSetIntStr, MappingIntStrAny]] = None,
exclude: Optional[Union[AbstractSetIntStr, MappingIntStrAny]] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> dict[str, Any]:
"""Convert the object to a dict.

We override the default dict method to only include fields that are used.

Args:
include: The fields to include.
exclude: The fields to exclude.
by_alias: Whether to use the alias names.
skip_defaults: Whether to skip default values.
exclude_unset: Whether to exclude unset values.
exclude_defaults: Whether to exclude default values.
exclude_none: Whether to exclude None values.

Returns:
The object as a dict.
"""
if not include and isinstance(self, UsedSerialization):
include = self.__used_fields__
return super().dict(
include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
5 changes: 5 additions & 0 deletions reflex/ivars/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
overload,
)

from reflex.base import UsedSerialization
from reflex.utils import types
from reflex.utils.exceptions import VarAttributeError
from reflex.utils.types import GenericType, get_attribute_access_type, get_origin
Expand Down Expand Up @@ -267,6 +268,10 @@ def __getattr__(self, name) -> ImmutableVar:
f"The State var `{str(self)}` has no attribute '{name}' or may have been annotated "
f"wrongly."
)

if issubclass(fixed_type, UsedSerialization):
fixed_type.__used_fields__.add(name)

return ObjectItemOperation.create(self, name, attribute_type).guess_type()
else:
return ObjectItemOperation.create(self, name).guess_type()
Expand Down
34 changes: 33 additions & 1 deletion tests/utils/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

import pytest

from reflex.base import Base
from reflex.base import Base, SlimBase
from reflex.components.core.colors import Color
from reflex.ivars.base import ImmutableVar, LiteralVar
from reflex.state import State
from reflex.utils import serializers


Expand Down Expand Up @@ -119,6 +120,33 @@ class BaseSubclass(Base):
ts: datetime.timedelta = datetime.timedelta(1, 1, 1)


class SlimBaseSubclass(SlimBase):
"""A class inheriting from SlimBase for testing."""

ts: datetime.timedelta = datetime.timedelta(1, 1, 1)


def test_slim_base_subclass():
"""Test that a SlimBase subclass is serialized correctly."""

class SlimState(State):
"""A test state."""

slim: SlimBaseSubclass = SlimBaseSubclass(ts=datetime.timedelta(1, 1, 1))

state = SlimState()
assert SlimBaseSubclass.__used_fields__ == set()
assert serializers.serialize(state.slim) == "({ })"

# Access the field
SlimState.slim.ts
assert SlimBaseSubclass.__used_fields__ == {"ts"}
assert serializers.serialize(state.slim) == '({ ["ts"] : "1 day, 0:00:01.000001" })'

# Reset the used fields tracking set.
SlimBaseSubclass.__used_fields__ = set()


@pytest.mark.parametrize(
"value,expected",
[
Expand Down Expand Up @@ -153,6 +181,10 @@ class BaseSubclass(Base):
BaseSubclass(ts=datetime.timedelta(1, 1, 1)),
'({ ["ts"] : "1 day, 0:00:01.000001" })',
),
(
SlimBaseSubclass(ts=datetime.timedelta(1, 1, 1)),
"({ })",
),
(
[1, LiteralVar.create("hi"), ImmutableVar.create("bye")],
'[1, "hi", bye]',
Expand Down
Loading