From c2c582258ba2a789e024183c57ef65b3993829fb Mon Sep 17 00:00:00 2001 From: Fabien Arcellier Date: Mon, 12 Aug 2024 16:31:43 +0200 Subject: [PATCH] feat: trigger a calculated property on mutation * fix: handle dot separated expression on subscribe mutation --- src/writer/core.py | 44 +++++++++++++++++++++++++++++++++++--- tests/backend/test_core.py | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/src/writer/core.py b/src/writer/core.py index 7c94b19aa..7502582ce 100644 --- a/src/writer/core.py +++ b/src/writer/core.py @@ -155,7 +155,7 @@ def __post_init__(self): if len(self.path) == 0: raise ValueError("path cannot be empty.") - path_parts = self.path.split(".") + path_parts = parse_state_variable_expression(self.path) for part in path_parts: if len(part) == 0: raise ValueError(f"path {self.path} cannot have empty parts.") @@ -169,7 +169,7 @@ def local_path(self) -> str: >>> m.local_path >>> "c" """ - path_parts = self.path.split(".") + path_parts = parse_state_variable_expression(self.path) return path_parts[-1] class StateRecursionWatcher(): @@ -800,6 +800,15 @@ def subscribe_mutation(self, >>> state.subscribe_mutation('a', _increment_counter) >>> state['a'] = 2 # will trigger _increment_counter + subscribe mutation accepts escaped dot expressions to encode key that contains `dot` separator + + >>> def _increment_counter(state, payload, context, session, ui): + >>> state['my_counter'] += 1 + >>> + >>> state = WriterState({'a.b': 1, 'my_counter': 0}) + >>> state.subscribe_mutation('a\.b', _increment_counter) + >>> state['a.b'] = 2 # will trigger _increment_counter + :param path: path of mutation to monitor :param func: handler to call when the path is mutated """ @@ -810,7 +819,7 @@ def subscribe_mutation(self, for p in path_list: state_proxy = self._state_proxy - path_parts = p.split(".") + path_parts = parse_state_variable_expression(p) for i, path_part in enumerate(path_parts): if i == len(path_parts) - 1: local_mutation = MutationSubscription('subscription', p, handler, self) @@ -2355,6 +2364,35 @@ def _clone_mutation_subscriptions(session_state: State, app_state: State, root_s session_state._state_proxy.local_mutation_subscriptions.append(new_mutation_subscription) +def parse_state_variable_expression(p: str): + """ + Parses a state variable expression into a list of parts. + + >>> parse_state_variable_expression("a.b.c") + >>> ["a", "b", "c"] + + >>> parse_state_variable_expression("a\.b.c") + >>> ["a.b", "c"] + """ + parts = [] + it = 0 + last_split = 0 + while it < len(p): + if p[it] == '\\': + it += 2 + elif p[it] == '.': + new_part = p[last_split: it] + parts.append(new_part.replace('\\.', '.')) + + last_split = it + 1 + it += 1 + else: + it += 1 + + new_part = p[last_split: len(p)] + parts.append(new_part.replace('\\.', '.')) + return parts + def writer_event_handler_build_arguments(func: Callable, writer_args: dict) -> List[Any]: """ diff --git a/tests/backend/test_core.py b/tests/backend/test_core.py index 33c7d13ad..b08c68983 100644 --- a/tests/backend/test_core.py +++ b/tests/backend/test_core.py @@ -26,6 +26,7 @@ StateSerialiserException, WriterState, import_failure, + parse_state_variable_expression, ) from writer.core_ui import Component from writer.ss_types import WriterEvent @@ -573,6 +574,30 @@ def cumulative_sum(state: MyState): # Assert assert initial_state['total'] == 4 + def test_subscribe_mutation_should_manage_escaping_in_subscription(self): + """ + Tests that a key that contains a `.` can be used to subscribe to + a mutation using the escape character. + """ + with writer_fixtures.new_app_context(): + # Assign + def cumulative_sum(state): + state['total'] += state['a.b'] + + initial_state = wf.init_state({ + "a.b": 0, + "total": 0 + }) + + initial_state.subscribe_mutation('a\.b', cumulative_sum) + + # Acts + initial_state['a.b'] = 1 + initial_state['a.b'] = 3 + + # Assert + assert initial_state['total'] == 4 + class TestWriterState: # Initialised manually @@ -1667,3 +1692,15 @@ def counter_sum(self) -> int: mutations = state.user_state.get_mutations_as_dict() assert '+counter_sum' in mutations assert mutations['+counter_sum'] == 8 + + +def test_parse_state_variable_expression_should_process_expression(): + """ + Test that the parse_state_variable_expression function will process + the expression correctly + """ + # When + assert parse_state_variable_expression('features') == ['features'] + assert parse_state_variable_expression('features.eyes') == ['features', 'eyes'] + assert parse_state_variable_expression('features\.eyes') == ['features.eyes'] + assert parse_state_variable_expression('features\.eyes.color') == ['features.eyes', 'color']