Skip to content

Commit

Permalink
feat: trigger a calculated property on mutation
Browse files Browse the repository at this point in the history
* fix: handle dot separated expression on subscribe mutation
  • Loading branch information
FabienArcellier committed Aug 17, 2024
1 parent 33f2d8e commit c2c5822
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 3 deletions.
44 changes: 41 additions & 3 deletions src/writer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __post_init__(self):
if len(self.path) == 0:
raise ValueError("path cannot be empty.")

path_parts = self.path.split(".")
path_parts = parse_state_variable_expression(self.path)
for part in path_parts:
if len(part) == 0:
raise ValueError(f"path {self.path} cannot have empty parts.")
Expand All @@ -169,7 +169,7 @@ def local_path(self) -> str:
>>> m.local_path
>>> "c"
"""
path_parts = self.path.split(".")
path_parts = parse_state_variable_expression(self.path)
return path_parts[-1]

class StateRecursionWatcher():
Expand Down Expand Up @@ -800,6 +800,15 @@ def subscribe_mutation(self,
>>> state.subscribe_mutation('a', _increment_counter)
>>> state['a'] = 2 # will trigger _increment_counter
subscribe mutation accepts escaped dot expressions to encode key that contains `dot` separator
>>> def _increment_counter(state, payload, context, session, ui):
>>> state['my_counter'] += 1
>>>
>>> state = WriterState({'a.b': 1, 'my_counter': 0})
>>> state.subscribe_mutation('a\.b', _increment_counter)
>>> state['a.b'] = 2 # will trigger _increment_counter
:param path: path of mutation to monitor
:param func: handler to call when the path is mutated
"""
Expand All @@ -810,7 +819,7 @@ def subscribe_mutation(self,

for p in path_list:
state_proxy = self._state_proxy
path_parts = p.split(".")
path_parts = parse_state_variable_expression(p)
for i, path_part in enumerate(path_parts):
if i == len(path_parts) - 1:
local_mutation = MutationSubscription('subscription', p, handler, self)
Expand Down Expand Up @@ -2355,6 +2364,35 @@ def _clone_mutation_subscriptions(session_state: State, app_state: State, root_s
session_state._state_proxy.local_mutation_subscriptions.append(new_mutation_subscription)


def parse_state_variable_expression(p: str):
"""
Parses a state variable expression into a list of parts.
>>> parse_state_variable_expression("a.b.c")
>>> ["a", "b", "c"]
>>> parse_state_variable_expression("a\.b.c")
>>> ["a.b", "c"]
"""
parts = []
it = 0
last_split = 0
while it < len(p):
if p[it] == '\\':
it += 2
elif p[it] == '.':
new_part = p[last_split: it]
parts.append(new_part.replace('\\.', '.'))

last_split = it + 1
it += 1
else:
it += 1

new_part = p[last_split: len(p)]
parts.append(new_part.replace('\\.', '.'))
return parts


def writer_event_handler_build_arguments(func: Callable, writer_args: dict) -> List[Any]:
"""
Expand Down
37 changes: 37 additions & 0 deletions tests/backend/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
StateSerialiserException,
WriterState,
import_failure,
parse_state_variable_expression,
)
from writer.core_ui import Component
from writer.ss_types import WriterEvent
Expand Down Expand Up @@ -573,6 +574,30 @@ def cumulative_sum(state: MyState):
# Assert
assert initial_state['total'] == 4

def test_subscribe_mutation_should_manage_escaping_in_subscription(self):
"""
Tests that a key that contains a `.` can be used to subscribe to
a mutation using the escape character.
"""
with writer_fixtures.new_app_context():
# Assign
def cumulative_sum(state):
state['total'] += state['a.b']

initial_state = wf.init_state({
"a.b": 0,
"total": 0
})

initial_state.subscribe_mutation('a\.b', cumulative_sum)

# Acts
initial_state['a.b'] = 1
initial_state['a.b'] = 3

# Assert
assert initial_state['total'] == 4

class TestWriterState:

# Initialised manually
Expand Down Expand Up @@ -1667,3 +1692,15 @@ def counter_sum(self) -> int:
mutations = state.user_state.get_mutations_as_dict()
assert '+counter_sum' in mutations
assert mutations['+counter_sum'] == 8


def test_parse_state_variable_expression_should_process_expression():
"""
Test that the parse_state_variable_expression function will process
the expression correctly
"""
# When
assert parse_state_variable_expression('features') == ['features']
assert parse_state_variable_expression('features.eyes') == ['features', 'eyes']
assert parse_state_variable_expression('features\.eyes') == ['features.eyes']
assert parse_state_variable_expression('features\.eyes.color') == ['features.eyes', 'color']

0 comments on commit c2c5822

Please sign in to comment.