diff --git a/reflex/state.py b/reflex/state.py index 26bef5d7ec..8210c15009 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -691,6 +691,9 @@ def _init_var_dependency_dicts(cls): parent_state.get_parent_state(), ) + # Reset cached schema value + cls._to_schema.cache_clear() + @classmethod def _check_overridden_methods(cls): """Check for shadow methods and raise error if any. @@ -1945,20 +1948,58 @@ def __getstate__(self): The state dict for serialization. """ state = super().__getstate__() - # Never serialize parent_state or substates state["__dict__"] = state["__dict__"].copy() + if state["__dict__"].get("parent_state") is not None: + # Do not serialize router data in substates (only the root state). + state["__dict__"].pop("router", None) + state["__dict__"].pop("router_data", None) + # Never serialize parent_state or substates. state["__dict__"]["parent_state"] = None state["__dict__"]["substates"] = {} state["__dict__"].pop("_was_touched", None) + # Remove all inherited vars. + for inherited_var_name in self.inherited_vars: + state["__dict__"].pop(inherited_var_name, None) return state + @classmethod + @functools.lru_cache() + def _to_schema(cls) -> str: + """Convert a state to a schema. + + Returns: + The hash of the schema. + """ + + def _field_tuple( + field_name: str, + ) -> Tuple[str, str, Any, Union[bool, None], Any]: + model_field = cls.__fields__[field_name] + return ( + field_name, + model_field.name, + _serialize_type(model_field.type_), + ( + model_field.required + if isinstance(model_field.required, bool) + else None + ), + (model_field.default if is_serializable(model_field.default) else None), + ) + + return md5( + pickle.dumps( + list(sorted(_field_tuple(field_name) for field_name in cls.base_vars)) + ) + ).hexdigest() + def _serialize(self) -> bytes: """Serialize the state for redis. Returns: The serialized state. """ - return pickle.dumps((state_to_schema(self), self)) + return pickle.dumps((self._to_schema(), self)) @classmethod def _deserialize( @@ -1985,7 +2026,7 @@ def _deserialize( (substate_schema, state) = pickle.load(fp) else: raise ValueError("Only one of `data` or `fp` must be provided") - if substate_schema != state_to_schema(state): + if substate_schema != state._to_schema(): raise StateSchemaMismatchError() return state @@ -2620,35 +2661,6 @@ def is_serializable(value: Any) -> bool: return False -def state_to_schema( - state: BaseState, -) -> List[Tuple[str, str, Any, Union[bool, None], Any]]: - """Convert a state to a schema. - - Args: - state: The state to convert to a schema. - - Returns: - The schema. - """ - return list( - sorted( - ( - field_name, - model_field.name, - _serialize_type(model_field.type_), - ( - model_field.required - if isinstance(model_field.required, bool) - else None - ), - (model_field.default if is_serializable(model_field.default) else None), - ) - for field_name, model_field in state.__fields__.items() - ) - ) - - def reset_disk_state_manager(): """Reset the disk state manager.""" states_directory = prerequisites.get_web_dir() / constants.Dirs.STATES diff --git a/tests/integration/test_dynamic_routes.py b/tests/integration/test_dynamic_routes.py index 5ba0b7bda8..35b83790f2 100644 --- a/tests/integration/test_dynamic_routes.py +++ b/tests/integration/test_dynamic_routes.py @@ -41,13 +41,13 @@ def index(): return rx.fragment( rx.input( value=DynamicState.router.session.client_token, - is_read_only=True, + read_only=True, id="token", ), - rx.input(value=rx.State.page_id, is_read_only=True, id="page_id"), # type: ignore + rx.input(value=rx.State.page_id, read_only=True, id="page_id"), # type: ignore rx.input( value=DynamicState.router.page.raw_path, - is_read_only=True, + read_only=True, id="raw_path", ), rx.link("index", href="/", id="link_index"),