Skip to content

Commit

Permalink
Optimize StateManagerDisk (#4056)
Browse files Browse the repository at this point in the history
* Simplify StateManagerDisk implementation

* Act more like the memory state manager and only track the root state in self.states
* .load_state always loads a single state or returns None
* .populate_states is the new entry point in loading from disk and it only occurs
  when the root state is not known
* much fast

* StateManagerDisk now acts much more like StateManagerMemory

Treat StateManagerDisk like StateManagerMemory for AppHarness

* Handle root_state deserialized from disk

In this case, we need to initialize the whole state tree, so any non-persistent
states will still get default values, whereas on-disk states will overwrite the
defaults.

* Cache root_state under client_token for StateManagerMemory compatibility

Mainly this just makes it easier for us to write tests that work against either
Disk or Memory state managers.
  • Loading branch information
masenf authored Oct 7, 2024
1 parent 1f3be63 commit aa69234
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 45 deletions.
62 changes: 32 additions & 30 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2711,34 +2711,24 @@ def token_path(self, token: str) -> Path:
self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
).absolute()

async def load_state(self, token: str, root_state: BaseState) -> BaseState:
async def load_state(self, token: str) -> BaseState | None:
"""Load a state object based on the provided token.
Args:
token: The token used to identify the state object.
root_state: The root state object.
Returns:
The loaded state object.
The loaded state object or None.
"""
if token in self.states:
return self.states[token]

client_token, substate_address = _split_substate_key(token)

token_path = self.token_path(token)

if token_path.exists():
try:
with token_path.open(mode="rb") as file:
substate = BaseState._deserialize(fp=file)
await self.populate_substates(client_token, substate, root_state)
return substate
return BaseState._deserialize(fp=file)
except Exception:
pass

return root_state.get_substate(substate_address.split(".")[1:])

async def populate_substates(
self, client_token: str, state: BaseState, root_state: BaseState
):
Expand All @@ -2752,10 +2742,13 @@ async def populate_substates(
for substate in state.get_substates():
substate_token = _substate_key(client_token, substate)

substate = await self.load_state(substate_token, root_state)
instance = await self.load_state(substate_token)
if instance is None:
instance = await root_state.get_state(substate)
state.substates[substate.get_name()] = instance
instance.parent_state = state

state.substates[substate.get_name()] = substate
substate.parent_state = state
await self.populate_substates(client_token, instance, root_state)

@override
async def get_state(
Expand All @@ -2770,15 +2763,24 @@ async def get_state(
Returns:
The state for the token.
"""
client_token, substate_address = _split_substate_key(token)

root_state_token = _substate_key(client_token, substate_address.split(".")[0])
root_state = self.states.get(root_state_token)
client_token = _split_substate_key(token)[0]
root_state = self.states.get(client_token)
if root_state is not None:
# Retrieved state from memory.
return root_state

# Deserialize root state from disk.
root_state = await self.load_state(_substate_key(client_token, self.state))
# Create a new root state tree with all substates instantiated.
fresh_root_state = self.state(_reflex_internal_init=True)
if root_state is None:
# Create a new root state which will be persisted in the next set_state call.
root_state = self.state(_reflex_internal_init=True)

return await self.load_state(root_state_token, root_state)
root_state = fresh_root_state
else:
# Ensure all substates exist, even if they were not serialized previously.
root_state.substates = fresh_root_state.substates
self.states[client_token] = root_state
await self.populate_substates(client_token, root_state, root_state)
return root_state

async def set_state_for_substate(self, client_token: str, substate: BaseState):
"""Set the state for a substate.
Expand All @@ -2789,12 +2791,12 @@ async def set_state_for_substate(self, client_token: str, substate: BaseState):
"""
substate_token = _substate_key(client_token, substate)

self.states[substate_token] = substate

state_dilled = substate._serialize()
if not self.states_directory.exists():
self.states_directory.mkdir(parents=True, exist_ok=True)
self.token_path(substate_token).write_bytes(state_dilled)
if substate._get_was_touched():
substate._was_touched = False # Reset the touched flag after serializing.
pickle_state = substate._serialize()
if not self.states_directory.exists():
self.states_directory.mkdir(parents=True, exist_ok=True)
self.token_path(substate_token).write_bytes(pickle_state)

for substate_substate in substate.substates.values():
await self.set_state_for_substate(client_token, substate_substate)
Expand Down
2 changes: 0 additions & 2 deletions reflex/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,6 @@ def _initialize_app(self):
if isinstance(self.app_instance._state_manager, StateManagerRedis):
# Create our own redis connection for testing.
self.state_manager = StateManagerRedis.create(self.app_instance.state)
elif isinstance(self.app_instance._state_manager, StateManagerDisk):
self.state_manager = StateManagerDisk.create(self.app_instance.state)
else:
self.state_manager = self.app_instance._state_manager

Expand Down
17 changes: 4 additions & 13 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1884,11 +1884,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
async with sp:
assert sp._self_actx is not None
assert sp._self_mutable # proxy is mutable inside context
if isinstance(mock_app.state_manager, StateManagerMemory):
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
# For in-process store, only one instance of the state exists
assert sp.__wrapped__ is grandchild_state
else:
# When redis or disk is used, a new+updated instance is assigned to the proxy
# When redis is used, a new+updated instance is assigned to the proxy
assert sp.__wrapped__ is not grandchild_state
sp.value2 = "42"
assert not sp._self_mutable # proxy is not mutable after exiting context
Expand All @@ -1899,7 +1899,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
gotten_state = await mock_app.state_manager.get_state(
_substate_key(grandchild_state.router.session.client_token, grandchild_state)
)
if isinstance(mock_app.state_manager, StateManagerMemory):
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
# For in-process store, only one instance of the state exists
assert gotten_state is parent_state
else:
Expand Down Expand Up @@ -2922,7 +2922,7 @@ async def test_get_state(mock_app: rx.App, token: str):
_substate_key(token, ChildState2)
)
assert isinstance(new_test_state, TestState)
if isinstance(mock_app.state_manager, StateManagerMemory):
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
# In memory, it's the same instance
assert new_test_state is test_state
test_state._clean()
Expand All @@ -2932,15 +2932,6 @@ async def test_get_state(mock_app: rx.App, token: str):
ChildState2.get_name(),
ChildState3.get_name(),
)
elif isinstance(mock_app.state_manager, StateManagerDisk):
# On disk, it's a new instance
assert new_test_state is not test_state
# All substates are available
assert tuple(sorted(new_test_state.substates)) == (
ChildState.get_name(),
ChildState2.get_name(),
ChildState3.get_name(),
)
else:
# With redis, we get a whole new instance
assert new_test_state is not test_state
Expand Down

0 comments on commit aa69234

Please sign in to comment.