diff --git a/docs/framework/event-handlers.mdx b/docs/framework/event-handlers.mdx index c91cfadc4..cbb41b5dd 100644 --- a/docs/framework/event-handlers.mdx +++ b/docs/framework/event-handlers.mdx @@ -141,7 +141,8 @@ def hande_click_cleaner(state): You can subscribe to mutations on a specific key in the state. This is useful when you want to trigger a function every time a specific key is mutated. -```python + +```python simple subscription import writer as wf def _increment_counter(state): @@ -153,10 +154,44 @@ state.subscribe_mutation('a', _increment_counter) state['a'] = 2 # trigger _increment_counter mutation ``` -```python -state.subscribe_mutation('a.b', _increment_counter) # subscribe to nested key +```python multiple subscriptions +import writer as wf + +def _increment_counter(state): + state['my_counter'] += 1 + +state = wf.init_state({ + 'title': 'Hello', + 'app': {'title', 'Writer Framework'}, + 'my_counter': 0} +) + state.subscribe_mutation(['title', 'app.title'], _increment_counter) # subscribe to multiple keys + +state['title'] = "Hello Pigeon" # trigger _increment_counter mutation +``` + +```python trigger event handler +import writer as wf + +def _increment_counter(state, context: dict, payload: dict, session: dict, ui: WriterUIManager): + if context['event'] == 'mutation' and context['mutation'] == 'a': + if payload['previous_value'] > payload['new_value']: + state['my_counter'] += 1 + +state = wf.init_state({"a": 1, "my_counter": 0}) +state.subscribe_mutation('a', _increment_counter) + +state['a'] = 2 # increment my_counter +state['a'] = 3 # increment my_counter +state['a'] = 2 # do nothing ``` + + + +`subscribe_mutation` is compatible with event handler signature. It will accept all the arguments +of the event handler (`context`, `payload`, ...). + ## Receiving a payload diff --git a/src/writer/app_runner.py b/src/writer/app_runner.py index a6cb0fa90..3f0bc2aa4 100644 --- a/src/writer/app_runner.py +++ b/src/writer/app_runner.py @@ -19,7 +19,7 @@ from watchdog.observers.polling import PollingObserver from writer import VERSION -from writer.core import EventHandlerRegistry, MiddlewareRegistry, WriterSession +from writer.core import EventHandlerRegistry, MiddlewareRegistry, WriterSession, use_request_context from writer.core_ui import ingest_bmc_component_tree from writer.ss_types import ( AppProcessServerRequest, @@ -232,71 +232,72 @@ def _handle_message(self, session_id: str, request: AppProcessServerRequest) -> """ import writer - session = None - type = request.type - - if type == "sessionInit": - si_req_payload = InitSessionRequestPayload.parse_obj( - request.payload) - return AppProcessServerResponse( - status="ok", - status_message=None, - payload=self._handle_session_init(si_req_payload) - ) - - session = writer.session_manager.get_session(session_id) - if not session: - raise MessageHandlingException("Session not found.") - session.update_last_active_timestamp() - - if type == "checkSession": - return AppProcessServerResponse( - status="ok", - status_message=None, - payload=None - ) - - if type == "event": - ev_req_payload = WriterEvent.parse_obj(request.payload) - return AppProcessServerResponse( - status="ok", - status_message=None, - payload=self._handle_event(session, ev_req_payload) - ) - - if type == "stateEnquiry": - return AppProcessServerResponse( - status="ok", - status_message=None, - payload=self._handle_state_enquiry(session) - ) - - if type == "stateContent": - return AppProcessServerResponse( - status="ok", - status_message=None, - payload=self._handle_state_content(session) - ) - - if type == "setUserinfo": - session.userinfo = request.payload - return AppProcessServerResponse( - status="ok", - status_message=None, - payload=None - ) - - if self.mode == "edit" and type == "componentUpdate": - cu_req_payload = ComponentUpdateRequestPayload.parse_obj( - request.payload) - self._handle_component_update(session, cu_req_payload) - return AppProcessServerResponse( - status="ok", - status_message=None, - payload=None - ) - - raise MessageHandlingException("Invalid event.") + with use_request_context(session_id, request): + session = None + type = request.type + + if type == "sessionInit": + si_req_payload = InitSessionRequestPayload.parse_obj( + request.payload) + return AppProcessServerResponse( + status="ok", + status_message=None, + payload=self._handle_session_init(si_req_payload) + ) + + session = writer.session_manager.get_session(session_id) + if not session: + raise MessageHandlingException("Session not found.") + session.update_last_active_timestamp() + + if type == "checkSession": + return AppProcessServerResponse( + status="ok", + status_message=None, + payload=None + ) + + if type == "event": + ev_req_payload = WriterEvent.parse_obj(request.payload) + return AppProcessServerResponse( + status="ok", + status_message=None, + payload=self._handle_event(session, ev_req_payload) + ) + + if type == "stateEnquiry": + return AppProcessServerResponse( + status="ok", + status_message=None, + payload=self._handle_state_enquiry(session) + ) + + if type == "stateContent": + return AppProcessServerResponse( + status="ok", + status_message=None, + payload=self._handle_state_content(session) + ) + + if type == "setUserinfo": + session.userinfo = request.payload + return AppProcessServerResponse( + status="ok", + status_message=None, + payload=None + ) + + if self.mode == "edit" and type == "componentUpdate": + cu_req_payload = ComponentUpdateRequestPayload.parse_obj( + request.payload) + self._handle_component_update(session, cu_req_payload) + return AppProcessServerResponse( + status="ok", + status_message=None, + payload=None + ) + + raise MessageHandlingException("Invalid event.") def _execute_user_code(self) -> None: """ diff --git a/src/writer/core.py b/src/writer/core.py index c8e4bcb7e..7c94b19aa 100644 --- a/src/writer/core.py +++ b/src/writer/core.py @@ -24,6 +24,7 @@ from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Dict, Generator, @@ -61,7 +62,30 @@ import polars from writer.app_runner import AppProcess + from writer.ss_types import AppProcessServerRequest +@dataclasses.dataclass +class CurrentRequest: + session_id: str + request: 'AppProcessServerRequest' + +_current_request: ContextVar[Optional[CurrentRequest]] = ContextVar("current_request", default=None) + +@contextlib.contextmanager +def use_request_context(session_id: str, request: 'AppProcessServerRequest'): + """ + Context manager to set the current request context. + + >>> session_id = "xxxxxxxxxxxxxxxxxxxxxxxxx" + >>> request = AppProcessServerRequest(type='event', payload=EventPayload(event='my_event')) + >>> with use_request_context(session_id, request): + >>> pass + """ + try: + _current_request.set(CurrentRequest(session_id, request)) + yield + finally: + _current_request.set(None) def get_app_process() -> 'AppProcess': """ @@ -121,8 +145,11 @@ class MutationSubscription: >>> m = MutationSubscription(path="a.c", handler=myhandler) """ + type: Literal['subscription', 'property'] path: str handler: Callable # Handler to execute when mutation happens + state: 'State' + property_name: Optional[str] = None def __post_init__(self): if len(self.path) == 0: @@ -417,21 +444,31 @@ def __setitem__(self, key: str, raw_value: Any) -> None: if not isinstance(key, str): raise ValueError( f"State keys must be strings. Received {str(key)} ({type(key)}).") - previous_value = self.state.get(key) + old_value = self.state.get(key) self.state[key] = raw_value for local_mutation in self.local_mutation_subscriptions: if local_mutation.local_path == key: - from writer.ui import WriterUIManager - - context = {"mutation": local_mutation.path} - payload = { - "mutation_previous_value": previous_value, - "mutation_value": raw_value - } - ui = WriterUIManager() - args = build_writer_func_arguments(local_mutation.handler, {"context": context, "payload": payload, "ui": ui}) - local_mutation.handler(*args) + if local_mutation.type == 'subscription': + context_data = { + "event": "mutation", + "mutation": local_mutation.path + } + payload = { + "previous_value": old_value, + "new_value": raw_value + } + + writer_event_handler_invoke(local_mutation.handler, { + "state": local_mutation.state, + "context": context_data, + "payload": payload, + "session": _event_handler_session_info(), + "ui": _event_handler_ui_manager() + }) + elif local_mutation.type == 'property': + assert local_mutation.property_name is not None + self[local_mutation.property_name] = local_mutation.handler(local_mutation.state) self._apply_raw(f"+{key}") @@ -619,7 +656,7 @@ def bind_annotations_to_state_proxy(cls, klass): class State(metaclass=StateMeta): - def __init__(self, raw_state: Dict[str, Any] | None = None): + def __init__(self, raw_state: Optional[Dict[str, Any]] = None): final_raw_state = raw_state if raw_state is not None else {} self._state_proxy: StateProxy = StateProxy(final_raw_state) @@ -737,12 +774,15 @@ def _set_state_item(self, key: str, value: Any): self._state_proxy[key] = value - def subscribe_mutation(self, path: Union[str, List[str]], handler: Callable[..., None], initial_triggered: bool = False) -> None: + def subscribe_mutation(self, + path: Union[str, List[str]], + handler: Callable[..., Union[None, Awaitable[None]]], + initial_triggered: bool = False) -> None: """ Automatically triggers a handler when a mutation occurs in the state. >>> def _increment_counter(state): - >>> state_proxy['my_counter'] += 1 + >>> state['my_counter'] += 1 >>> >>> state = WriterState({'a': 1, 'c': {'a': 1, 'b': 3}, 'my_counter': 0}) >>> state.subscribe_mutation('a', _increment_counter) @@ -751,6 +791,15 @@ def subscribe_mutation(self, path: Union[str, List[str]], handler: Callable[..., >>> state['a'] = 3 # will trigger _increment_counter >>> state['c']['a'] = 2 # will trigger _increment_counter + subscribe mutation accepts the signature of an event handler. + + >>> def _increment_counter(state, payload, context, session, ui): + >>> state['my_counter'] += 1 + >>> + >>> state = WriterState({'a': 1, 'my_counter': 0}) + >>> state.subscribe_mutation('a', _increment_counter) + >>> state['a'] = 2 # will trigger _increment_counter + :param path: path of mutation to monitor :param func: handler to call when the path is mutated """ @@ -762,21 +811,68 @@ def subscribe_mutation(self, path: Union[str, List[str]], handler: Callable[..., for p in path_list: state_proxy = self._state_proxy path_parts = p.split(".") - final_handler = functools.partial(handler, self) for i, path_part in enumerate(path_parts): if i == len(path_parts) - 1: - local_mutation = MutationSubscription(p, final_handler) + local_mutation = MutationSubscription('subscription', p, handler, self) state_proxy.local_mutation_subscriptions.append(local_mutation) # At startup, the application must be informed of the # existing states. To cause this, we trigger manually # the handler. if initial_triggered is True: - final_handler() + writer_event_handler_invoke(handler, { + "state": self, + "context": {"event": "init"}, + "payload": {}, + "session": {}, + "ui": _event_handler_ui_manager() + }) + elif path_part in state_proxy: state_proxy = state_proxy[path_part] else: - raise ValueError("Mutation subscription failed - {p} not found in state") + raise ValueError(f"Mutation subscription failed - {p} not found in state") + + def calculated_property(self, + property_name: str, + path: Union[str, List[str]], + handler: Callable[..., Union[None, Awaitable[None]]]) -> None: + """ + Update a calculated property when a mutation triggers + + This method is dedicated to be used through a calculated property. It is not + recommended to invoke it directly. + + >>> class MyState(State): + >>> title: str + >>> + >>> wf.property('title') + >>> def title_upper(self): + >>> return self.title.upper() + + Usage + ===== + + >>> state = wf.init_state({'title': 'hello world'}) + >>> state.calculated_property('title_upper', 'title', lambda state: state.title.upper()) + """ + if isinstance(path, str): + path_list = [path] + else: + path_list = path + + for p in path_list: + state_proxy = self._state_proxy + path_parts = p.split(".") + for i, path_part in enumerate(path_parts): + if i == len(path_parts) - 1: + local_mutation = MutationSubscription('property', p, handler, self, property_name) + state_proxy.local_mutation_subscriptions.append(local_mutation) + state_proxy[property_name] = handler(self) + elif path_part in state_proxy: + state_proxy = state_proxy[path_part] + else: + raise ValueError(f"Property subscription failed - {p} not found in state") class WriterState(State): @@ -824,7 +920,10 @@ def get_clone(self) -> 'WriterState': "The state may contain unpickable objects, such as modules.", traceback.format_exc()) return substitute_state - return self.__class__(cloned_user_state, cloned_mail) + + cloned_state = self.__class__(cloned_user_state, cloned_mail) + _clone_mutation_subscriptions(cloned_state, self) + return cloned_state def add_mail(self, type: str, payload: Any) -> None: mail_item = { @@ -969,7 +1068,7 @@ def __init__(self, middleware: Callable): @contextlib.contextmanager def execute(self, args: dict): - middleware_args = build_writer_func_arguments(self.middleware, args) + middleware_args = writer_event_handler_build_arguments(self.middleware, args) it = self.middleware(*middleware_args) try: yield from it @@ -993,7 +1092,7 @@ def executors(self) -> List[MiddlewareExecutor]: Retrieves middlewares prepared for execution >>> executors = middleware_registry.executors() - >>> result = handle_with_middlewares_executor(executors, lambda state: pass, {'state': {}, 'payload': {}}) + >>> result = writer_event_handler_invoke_with_middlewares(executors, lambda state: pass, {'state': {}, 'payload': {}}) """ return self.registry @@ -1596,21 +1695,14 @@ def _call_handler_callable( raise ValueError(f"""Invalid handler. Couldn't find the handler "{ handler }".""") # Preparation of arguments - from writer.ui import WriterUIManager - context_data = self.evaluator.get_context_data(instance_path) context_data['event'] = event_type writer_args = { 'state': self.session_state, 'payload': payload, 'context': context_data, - 'session': { - 'id': self.session.session_id, - 'cookies': self.session.cookies, - 'headers': self.session.headers, - 'userinfo': self.session.userinfo or {} - }, - 'ui': WriterUIManager() + 'session':_event_handler_session_info(), + 'ui': _event_handler_ui_manager() } # Invocation of handler @@ -1620,7 +1712,7 @@ def _call_handler_callable( contextlib.redirect_stdout(io.StringIO()) as f: middlewares_executors = current_app_process.middleware_registry.executors() - result = handle_with_middlewares_executor(middlewares_executors, callable_handler, writer_args) + result = writer_event_handler_invoke_with_middlewares(middlewares_executors, callable_handler, writer_args) captured_stdout = f.getvalue() if captured_stdout: @@ -2166,7 +2258,7 @@ class Property(): def __init__(self, func): self.func = func - self.instances = set() + self.initialized = False self.property_name = None def __call__(self, *args, **kwargs): @@ -2174,7 +2266,7 @@ def __call__(self, *args, **kwargs): def __set_name__(self, owner: Type[State], name: str): """ - Saves the calculated properties when loading the class. + Saves the calculated properties when loading a State class. """ if owner not in calculated_properties_per_state_type: calculated_properties_per_state_type[owner] = [] @@ -2186,13 +2278,14 @@ def __get__(self, instance: State, cls): """ This mechanism retrieves the property instance. """ - property_name = self.property_name - if instance not in self.instances: - def calculated_property_handler(state): - instance._state_proxy[property_name] = self.func(state) + args = inspect.getfullargspec(self.func) + if len(args.args) > 1: + logging.warning(f"Wrong signature for calculated property '{instance.__class__.__name__}:{self.property_name}'. It must declare only self argument.") + return None - instance.subscribe_mutation(path, calculated_property_handler, initial_triggered=True) - self.instances.add(instance) + if self.initialized is False: + instance.calculated_property(property_name=self.property_name, path=path, handler=self.func) + self.initialized = True return self.func(instance) @@ -2209,6 +2302,25 @@ def wrapped(*args, **kwargs): session_manager.add_verifier(func) return wrapped + +def get_session() -> Optional[WriterSession]: + """ + Retrieves the current session. + + This function works exclusively in the context of a request. + """ + req = _current_request.get() + if req is None: + return None + + session_id = req.session_id + session = session_manager.get_session(session_id) + if not session: + return None + + return session + + def reset_base_component_tree() -> None: """ Reset the base component tree to zero @@ -2218,8 +2330,56 @@ def reset_base_component_tree() -> None: global base_component_tree base_component_tree = core_ui.build_base_component_tree() +def _clone_mutation_subscriptions(session_state: State, app_state: State, root_state: Optional['State'] = None) -> None: + """ + clone subscriptions on mutations between the initial state of the application and the state created for the session + + >>> state = wf.init_state({"counter": 0}) + >>> state.subscribe_mutation("counter", lambda state: print(state["counter"])) -def handler_executor(callable_handler: Callable, writer_args: dict) -> Any: + >>> session_state = state.get_clone() + + :param session_state: + :param app_state: + :param root_state: + """ + state_proxy_app = app_state._state_proxy + state_proxy_session = session_state._state_proxy + + state_proxy_session.local_mutation_subscriptions = [] + + _root_state = root_state if root_state is not None else session_state + for mutation_subscription in state_proxy_app.local_mutation_subscriptions: + new_mutation_subscription = copy.copy(mutation_subscription) + new_mutation_subscription.state = _root_state if new_mutation_subscription.type == "subscription" else session_state + session_state._state_proxy.local_mutation_subscriptions.append(new_mutation_subscription) + + + +def writer_event_handler_build_arguments(func: Callable, writer_args: dict) -> List[Any]: + """ + Constructs the list of arguments based on the signature of the function + which can be a handler or middleware. + + >>> def my_event_handler(state, context): + >>> yield + + >>> args = writer_event_handler_build_arguments(my_event_handler, {'state': {}, 'payload': {}, 'context': {"target": '11'}, 'session': None, 'ui': None}) + >>> [{}, {"target": '11'}] + + :param func: the function that will be called + :param writer_args: the possible arguments in writer (state, payload, ...) + """ + handler_args = inspect.getfullargspec(func).args + func_args = [] + for arg in handler_args: + if arg in writer_args: + func_args.append(writer_args[arg]) + + return func_args + + +def writer_event_handler_invoke(callable_handler: Callable, writer_args: dict) -> Any: """ Runs a handler based on its signature. @@ -2229,13 +2389,13 @@ def handler_executor(callable_handler: Callable, writer_args: dict) -> Any: >>> def my_handler(state): >>> state['a'] = 2 >>> - >>> handler_executor(my_handler, {'state': {'a': 1}, 'payload': None, 'context': None, 'session': None, 'ui': None}) + >>> writer_event_handler_invoke(my_handler, {'state': {'a': 1}, 'payload': None, 'context': None, 'session': None, 'ui': None}) """ is_async_handler = inspect.iscoroutinefunction(callable_handler) if (not callable(callable_handler) and not is_async_handler): raise ValueError("Invalid handler. The handler isn't a callable object.") - handler_args = build_writer_func_arguments(callable_handler, writer_args) + handler_args = writer_event_handler_build_arguments(callable_handler, writer_args) if is_async_handler: async_wrapper = _async_wrapper_internal(callable_handler, handler_args) @@ -2245,7 +2405,7 @@ def handler_executor(callable_handler: Callable, writer_args: dict) -> Any: return result -def handle_with_middlewares_executor(middlewares_executors: List[MiddlewareExecutor], callable_handler: Callable, writer_args: dict) -> Any: +def writer_event_handler_invoke_with_middlewares(middlewares_executors: List[MiddlewareExecutor], callable_handler: Callable, writer_args: dict) -> Any: """ Runs the middlewares then the handler. This function allows you to manage exceptions that are triggered in middleware @@ -2257,36 +2417,14 @@ def handle_with_middlewares_executor(middlewares_executors: List[MiddlewareExecu >>> yield >>> executor = MiddlewareExecutor(my_middleware, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None}) - >>> handle_with_middlewares_executor([executor], my_handler, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None} + >>> writer_event_handler_invoke_with_middlewares([executor], my_handler, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None} """ if len(middlewares_executors) == 0: - return handler_executor(callable_handler, writer_args) + return writer_event_handler_invoke(callable_handler, writer_args) else: executor = middlewares_executors[0] with executor.execute(writer_args): - return handle_with_middlewares_executor(middlewares_executors[1:], callable_handler, writer_args) - -def build_writer_func_arguments(func: Callable, writer_args: dict) -> List[Any]: - """ - Constructs the list of arguments based on the signature of the function - which can be a handler or middleware. - - >>> def my_event_handler(state, context): - >>> yield - - >>> args = build_writer_func_arguments(my_event_handler, {'state': {}, 'payload': {}, 'context': {"target": '11'}, 'session': None, 'ui': None}) - >>> [{}, {"target": '11'}] - - :param func: the function that will be called - :param writer_args: the possible arguments in writer (state, payload, ...) - """ - handler_args = inspect.getfullargspec(func).args - func_args = [] - for arg in handler_args: - if arg in writer_args: - func_args.append(writer_args[arg]) - - return func_args + return writer_event_handler_invoke_with_middlewares(middlewares_executors[1:], callable_handler, writer_args) async def _async_wrapper_internal(callable_handler: Callable, arg_values: List[Any]) -> Any: @@ -2331,6 +2469,27 @@ def _assert_record_match_list_of_records(df: List[Dict[str, Any]], record: Dict[ if columns != columns_record: raise ValueError(f"Columns mismatch. Expected {columns}, got {columns_record}") +def _event_handler_session_info() -> Dict[str, Any]: + """ + Returns the session information for the current event handler. + + This information is exposed in the session parameter of a handler + + """ + current_session = get_session() + session_info: Dict[str, Any] = {} + if current_session is not None: + session_info['id'] = current_session.session_id + session_info['cookies'] = current_session.cookies + session_info['headers'] = current_session.headers + session_info['userinfo'] = current_session.userinfo or {} + + return session_info + +def _event_handler_ui_manager(): + from writer.ui import WriterUIManager + return WriterUIManager() + def _split_record_as_pandas_record_and_index(param: dict, index_columns: list) -> Tuple[dict, tuple]: """ diff --git a/tests/backend/test_core.py b/tests/backend/test_core.py index 18231b383..33c7d13ad 100644 --- a/tests/backend/test_core.py +++ b/tests/backend/test_core.py @@ -477,6 +477,28 @@ def _increment_counter2(state): assert mutations['+my_counter'] == 1 assert mutations['+my_counter2'] == 1 + def test_subscribe_mutation_should_work_with_async_event_handler(self): + """ + Tests that multiple handlers can be triggered in cascade if one of them modifies a value + that is listened to by another handler during a mutation. + """ + # Assign + async def _increment_counter(state): + state['my_counter'] += 1 + + _state = WriterState({"a": 1, "my_counter": 0}) + _state.user_state.get_mutations_as_dict() + + # Acts + _state.subscribe_mutation('a', _increment_counter) + _state['a'] = 2 + + # Assert + assert _state['my_counter'] == 1 + + mutations = _state.user_state.get_mutations_as_dict() + assert mutations['+my_counter'] == 1 + def test_subscribe_mutation_should_raise_error_on_infinite_cascading(self): """ Tests that an infinite recursive loop is detected and an error is raised if mutations cascade @@ -512,9 +534,9 @@ def _increment_counter(state, payload, context: dict, ui): state['my_counter'] += 1 # Assert - assert payload['mutation_previous_value'] == 1 - assert payload['mutation_value'] == 2 assert context['mutation'] == 'a' + assert payload['previous_value'] == 1 + assert payload['new_value'] == 2 _state = WriterState({"a": 1, "my_counter": 0}) _state.user_state.get_mutations_as_dict()