diff --git a/reflex/state.py b/reflex/state.py index 5798564fa4..96435dbaf9 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2711,34 +2711,24 @@ def token_path(self, token: str) -> Path: self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl" ).absolute() - async def load_state(self, token: str, root_state: BaseState) -> BaseState: + async def load_state(self, token: str) -> BaseState | None: """Load a state object based on the provided token. Args: token: The token used to identify the state object. - root_state: The root state object. Returns: - The loaded state object. + The loaded state object or None. """ - if token in self.states: - return self.states[token] - - client_token, substate_address = _split_substate_key(token) - token_path = self.token_path(token) if token_path.exists(): try: with token_path.open(mode="rb") as file: - substate = BaseState._deserialize(fp=file) - await self.populate_substates(client_token, substate, root_state) - return substate + return BaseState._deserialize(fp=file) except Exception: pass - return root_state.get_substate(substate_address.split(".")[1:]) - async def populate_substates( self, client_token: str, state: BaseState, root_state: BaseState ): @@ -2752,10 +2742,13 @@ async def populate_substates( for substate in state.get_substates(): substate_token = _substate_key(client_token, substate) - substate = await self.load_state(substate_token, root_state) + instance = await self.load_state(substate_token) + if instance is None: + instance = await root_state.get_state(substate) + state.substates[substate.get_name()] = instance + instance.parent_state = state - state.substates[substate.get_name()] = substate - substate.parent_state = state + await self.populate_substates(client_token, instance, root_state) @override async def get_state( @@ -2770,15 +2763,24 @@ async def get_state( Returns: The state for the token. """ - client_token, substate_address = _split_substate_key(token) - - root_state_token = _substate_key(client_token, substate_address.split(".")[0]) - root_state = self.states.get(root_state_token) + client_token = _split_substate_key(token)[0] + root_state = self.states.get(client_token) + if root_state is not None: + # Retrieved state from memory. + return root_state + + # Deserialize root state from disk. + root_state = await self.load_state(_substate_key(client_token, self.state)) + # Create a new root state tree with all substates instantiated. + fresh_root_state = self.state(_reflex_internal_init=True) if root_state is None: - # Create a new root state which will be persisted in the next set_state call. - root_state = self.state(_reflex_internal_init=True) - - return await self.load_state(root_state_token, root_state) + root_state = fresh_root_state + else: + # Ensure all substates exist, even if they were not serialized previously. + root_state.substates = fresh_root_state.substates + self.states[client_token] = root_state + await self.populate_substates(client_token, root_state, root_state) + return root_state async def set_state_for_substate(self, client_token: str, substate: BaseState): """Set the state for a substate. @@ -2789,12 +2791,12 @@ async def set_state_for_substate(self, client_token: str, substate: BaseState): """ substate_token = _substate_key(client_token, substate) - self.states[substate_token] = substate - - state_dilled = substate._serialize() - if not self.states_directory.exists(): - self.states_directory.mkdir(parents=True, exist_ok=True) - self.token_path(substate_token).write_bytes(state_dilled) + if substate._get_was_touched(): + substate._was_touched = False # Reset the touched flag after serializing. + pickle_state = substate._serialize() + if not self.states_directory.exists(): + self.states_directory.mkdir(parents=True, exist_ok=True) + self.token_path(substate_token).write_bytes(pickle_state) for substate_substate in substate.substates.values(): await self.set_state_for_substate(client_token, substate_substate) diff --git a/reflex/testing.py b/reflex/testing.py index bdbd3dc948..7ea524f1c2 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -292,8 +292,6 @@ def _initialize_app(self): if isinstance(self.app_instance._state_manager, StateManagerRedis): # Create our own redis connection for testing. self.state_manager = StateManagerRedis.create(self.app_instance.state) - elif isinstance(self.app_instance._state_manager, StateManagerDisk): - self.state_manager = StateManagerDisk.create(self.app_instance.state) else: self.state_manager = self.app_instance._state_manager diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 5bfac76282..4e783b532f 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1884,11 +1884,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): async with sp: assert sp._self_actx is not None assert sp._self_mutable # proxy is mutable inside context - if isinstance(mock_app.state_manager, StateManagerMemory): + if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): # For in-process store, only one instance of the state exists assert sp.__wrapped__ is grandchild_state else: - # When redis or disk is used, a new+updated instance is assigned to the proxy + # When redis is used, a new+updated instance is assigned to the proxy assert sp.__wrapped__ is not grandchild_state sp.value2 = "42" assert not sp._self_mutable # proxy is not mutable after exiting context @@ -1899,7 +1899,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): gotten_state = await mock_app.state_manager.get_state( _substate_key(grandchild_state.router.session.client_token, grandchild_state) ) - if isinstance(mock_app.state_manager, StateManagerMemory): + if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): # For in-process store, only one instance of the state exists assert gotten_state is parent_state else: @@ -2922,7 +2922,7 @@ async def test_get_state(mock_app: rx.App, token: str): _substate_key(token, ChildState2) ) assert isinstance(new_test_state, TestState) - if isinstance(mock_app.state_manager, StateManagerMemory): + if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): # In memory, it's the same instance assert new_test_state is test_state test_state._clean() @@ -2932,15 +2932,6 @@ async def test_get_state(mock_app: rx.App, token: str): ChildState2.get_name(), ChildState3.get_name(), ) - elif isinstance(mock_app.state_manager, StateManagerDisk): - # On disk, it's a new instance - assert new_test_state is not test_state - # All substates are available - assert tuple(sorted(new_test_state.substates)) == ( - ChildState.get_name(), - ChildState2.get_name(), - ChildState3.get_name(), - ) else: # With redis, we get a whole new instance assert new_test_state is not test_state