Skip to content

Commit

Permalink
use serializer for state update and rework serializers (#3934)
Browse files Browse the repository at this point in the history
* use serializer for state update and rework serializers

* format
  • Loading branch information
adhami3310 authored Sep 16, 2024
1 parent 37920d6 commit a57095f
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 122 deletions.
24 changes: 1 addition & 23 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import dataclasses
import functools
import inspect
import json
import os
import uuid
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -206,27 +205,6 @@ def __init__(self, router_data: Optional[dict] = None):
object.__setattr__(self, "headers", HeaderData(router_data))
object.__setattr__(self, "page", PageData(router_data))

def toJson(self) -> str:
"""Convert the object to a JSON string.
Returns:
The JSON string.
"""
return json.dumps(dataclasses.asdict(self))


@serializer
def serialize_routerdata(value: RouterData) -> str:
"""Serialize a RouterData instance.
Args:
value: The RouterData to serialize.
Returns:
The serialized RouterData.
"""
return value.toJson()


def _no_chain_background_task(
state_cls: Type["BaseState"], name: str, fn: Callable
Expand Down Expand Up @@ -2415,7 +2393,7 @@ def json(self) -> str:
Returns:
The state update as a JSON string.
"""
return json.dumps(dataclasses.asdict(self))
return format.json_dumps(dataclasses.asdict(self))


class StateManager(Base, ABC):
Expand Down
24 changes: 2 additions & 22 deletions reflex/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import dataclasses
import inspect
import json
import os
Expand Down Expand Up @@ -410,22 +409,11 @@ def format_props(*single_props, **key_value_props) -> list[str]:

return [
(
f"{name}={format_prop(prop)}"
if isinstance(prop, Var) and not isinstance(prop, Var)
else (
f"{name}={{{format_prop(prop if isinstance(prop, Var) else LiteralVar.create(prop))}}}"
)
f"{name}={{{format_prop(prop if isinstance(prop, Var) else LiteralVar.create(prop))}}}"
)
for name, prop in sorted(key_value_props.items())
if prop is not None
] + [
(
str(prop)
if isinstance(prop, Var) and not isinstance(prop, Var)
else f"{str(LiteralVar.create(prop))}"
)
for prop in single_props
]
] + [(f"{str(LiteralVar.create(prop))}") for prop in single_props]


def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
Expand Down Expand Up @@ -623,14 +611,6 @@ def format_state(value: Any, key: Optional[str] = None) -> Any:
if isinstance(value, dict):
return {k: format_state(v, k) for k, v in value.items()}

# Hand dataclasses.
if dataclasses.is_dataclass(value):
if isinstance(value, type):
raise TypeError(
f"Cannot format state of type {type(value)}. Please provide an instance of the dataclass."
)
return {k: format_state(v, k) for k, v in dataclasses.asdict(value).items()}

# Handle lists, sets, typles.
if isinstance(value, types.StateIterBases):
return [format_state(v) for v in value]
Expand Down
29 changes: 12 additions & 17 deletions reflex/utils/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import dataclasses
import functools
import json
import warnings
Expand Down Expand Up @@ -29,7 +30,7 @@

# Mapping from type to a serializer.
# The serializer should convert the type to a JSON object.
SerializedType = Union[str, bool, int, float, list, dict]
SerializedType = Union[str, bool, int, float, list, dict, None]


Serializer = Callable[[Type], SerializedType]
Expand Down Expand Up @@ -124,6 +125,8 @@ def serialize(

# If there is no serializer, return None.
if serializer is None:
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return serialize(dataclasses.asdict(value))
if get_type:
return None, None
return None
Expand Down Expand Up @@ -225,7 +228,7 @@ def serialize_str(value: str) -> str:


@serializer
def serialize_primitive(value: Union[bool, int, float, None]) -> str:
def serialize_primitive(value: Union[bool, int, float, None]):
"""Serialize a primitive type.
Args:
Expand All @@ -234,13 +237,11 @@ def serialize_primitive(value: Union[bool, int, float, None]) -> str:
Returns:
The serialized number/bool/None.
"""
from reflex.utils import format

return format.json_dumps(value)
return value


@serializer
def serialize_base(value: Base) -> str:
def serialize_base(value: Base) -> dict:
"""Serialize a Base instance.
Args:
Expand All @@ -249,13 +250,11 @@ def serialize_base(value: Base) -> str:
Returns:
The serialized Base.
"""
from reflex.vars import LiteralVar

return str(LiteralVar.create(value))
return {k: serialize(v) for k, v in value.dict().items() if not callable(v)}


@serializer
def serialize_list(value: Union[List, Tuple, Set]) -> str:
def serialize_list(value: Union[List, Tuple, Set]) -> list:
"""Serialize a list to a JSON string.
Args:
Expand All @@ -264,13 +263,11 @@ def serialize_list(value: Union[List, Tuple, Set]) -> str:
Returns:
The serialized list.
"""
from reflex.vars import LiteralArrayVar

return str(LiteralArrayVar.create(value))
return [serialize(item) for item in value]


@serializer
def serialize_dict(prop: Dict[str, Any]) -> str:
def serialize_dict(prop: Dict[str, Any]) -> dict:
"""Serialize a dictionary to a JSON string.
Args:
Expand All @@ -279,9 +276,7 @@ def serialize_dict(prop: Dict[str, Any]) -> str:
Returns:
The serialized dictionary.
"""
from reflex.vars import LiteralObjectVar

return str(LiteralObjectVar.create(prop))
return {k: serialize(v) for k, v in prop.items()}


@serializer(to=str)
Expand Down
28 changes: 13 additions & 15 deletions reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,21 +936,6 @@ def json(self) -> str:
OUTPUT = TypeVar("OUTPUT", bound=Var)


def _encode_var(value: Var) -> str:
"""Encode the state name into a formatted var.
Args:
value: The value to encode the state name into.
Returns:
The encoded var.
"""
return f"{value}"


serializers.serializer(_encode_var)


class LiteralVar(Var):
"""Base class for immutable literal vars."""

Expand Down Expand Up @@ -1101,6 +1086,19 @@ def json(self) -> str:
)


@serializers.serializer
def serialize_literal(value: LiteralVar):
"""Serialize a Literal type.
Args:
value: The Literal to serialize.
Returns:
The serialized Literal.
"""
return serializers.serialize(value._var_value)


P = ParamSpec("P")
T = TypeVar("T")

Expand Down
42 changes: 21 additions & 21 deletions tests/utils/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,28 +352,28 @@ def test_format_match(
"prop,formatted",
[
("string", '"string"'),
("{wrapped_string}", "{wrapped_string}"),
(True, "{true}"),
(False, "{false}"),
(123, "{123}"),
(3.14, "{3.14}"),
([1, 2, 3], "{[1, 2, 3]}"),
(["a", "b", "c"], '{["a", "b", "c"]}'),
({"a": 1, "b": 2, "c": 3}, '{({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })}'),
({"a": 'foo "bar" baz'}, r'{({ ["a"] : "foo \"bar\" baz" })}'),
("{wrapped_string}", '"{wrapped_string}"'),
(True, "true"),
(False, "false"),
(123, "123"),
(3.14, "3.14"),
([1, 2, 3], "[1, 2, 3]"),
(["a", "b", "c"], '["a", "b", "c"]'),
({"a": 1, "b": 2, "c": 3}, '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })'),
({"a": 'foo "bar" baz'}, r'({ ["a"] : "foo \"bar\" baz" })'),
(
{
"a": 'foo "{ "bar" }" baz',
"b": Var(_js_expr="val", _var_type=str).guess_type(),
},
r'{({ ["a"] : "foo \"{ \"bar\" }\" baz", ["b"] : val })}',
r'({ ["a"] : "foo \"{ \"bar\" }\" baz", ["b"] : val })',
),
(
EventChain(
events=[EventSpec(handler=EventHandler(fn=mock_event))],
args_spec=lambda: [],
),
'{(...args) => addEvents([Event("mock_event", {})], args, {})}',
'((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ })))))',
),
(
EventChain(
Expand All @@ -382,7 +382,7 @@ def test_format_match(
handler=EventHandler(fn=mock_event),
args=(
(
LiteralVar.create("arg"),
Var(_js_expr="arg"),
Var(
_js_expr="_e",
)
Expand All @@ -394,25 +394,25 @@ def test_format_match(
],
args_spec=lambda e: [e.target.value],
),
'{(_e) => addEvents([Event("mock_event", {"arg":_e["target"]["value"]})], [_e], {})}',
'((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] })))], [_e], ({ })))))',
),
(
EventChain(
events=[EventSpec(handler=EventHandler(fn=mock_event))],
args_spec=lambda: [],
event_actions={"stopPropagation": True},
),
'{(...args) => addEvents([Event("mock_event", {})], args, {"stopPropagation": true})}',
'((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["stopPropagation"] : true })))))',
),
(
EventChain(
events=[EventSpec(handler=EventHandler(fn=mock_event))],
args_spec=lambda: [],
event_actions={"preventDefault": True},
),
'{(...args) => addEvents([Event("mock_event", {})], args, {"preventDefault": true})}',
'((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["preventDefault"] : true })))))',
),
({"a": "red", "b": "blue"}, '{({ ["a"] : "red", ["b"] : "blue" })}'),
({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'),
(Var(_js_expr="var", _var_type=int).guess_type(), "var"),
(
Var(
Expand All @@ -427,15 +427,15 @@ def test_format_match(
),
(
{"a": Var(_js_expr="val", _var_type=str).guess_type()},
'{({ ["a"] : val })}',
'({ ["a"] : val })',
),
(
{"a": Var(_js_expr='"val"', _var_type=str).guess_type()},
'{({ ["a"] : "val" })}',
'({ ["a"] : "val" })',
),
(
{"a": Var(_js_expr='state.colors["val"]', _var_type=str).guess_type()},
'{({ ["a"] : state.colors["val"] })}',
'({ ["a"] : state.colors["val"] })',
),
# tricky real-world case from markdown component
(
Expand All @@ -444,7 +444,7 @@ def test_format_match(
_js_expr=f"(({{node, ...props}}) => <Heading {{...props}} {''.join(Tag(name='', props=Style({'as_': 'h1'})).format_props())} />)"
),
},
'{({ ["h1"] : (({node, ...props}) => <Heading {...props} as={"h1"} />) })}',
'({ ["h1"] : (({node, ...props}) => <Heading {...props} as={"h1"} />) })',
),
],
)
Expand All @@ -455,7 +455,7 @@ def test_format_prop(prop: Var, formatted: str):
prop: The prop to test.
formatted: The expected formatted value.
"""
assert format.format_prop(prop) == formatted
assert format.format_prop(LiteralVar.create(prop)) == formatted


@pytest.mark.parametrize(
Expand Down
Loading

0 comments on commit a57095f

Please sign in to comment.