diff --git a/docs/framework/event-handlers.mdx b/docs/framework/event-handlers.mdx index 8353d2510..6617c8357 100644 --- a/docs/framework/event-handlers.mdx +++ b/docs/framework/event-handlers.mdx @@ -172,6 +172,41 @@ def evaluate(state, payload): Take into account that globals apply to all users. If you need to store data that's only relevant to a particular user, use application state. +## Middlewares + +Middlewares are functions that run before and after every event handler. +They can be used to perform tasks such as logging, error handling, session management, or modifying the state. + +```py +import writer as wf + +@wf.middleware() +def middleware_before(state, payload, context): + print("Middleware before event handler") + state['running'] += 1 + yield + print("Middleware after event handler") + state['running'] -= 1 +``` + +A middleware receives the same parameters as an event handler. + +A middleware can be used to handle exceptions that happens in event handlers. + +```py +import writer as wf + +@wf.middleware() +def middleware_before(state): + try: + yield + except Exception as e: + state['error_counter'] += 1 + state['last_error'] = str() + finally: + pass +``` + ## Standard output The standard output of an app is captured and shown in the code editor's log. You can use the standard `print` function to output results. @@ -232,8 +267,9 @@ You can use any awaitable object within an async event handler. This includes th ## Context -The `context` argument provides additional information about the event. For example, if the event -was triggered by a _Button_, the `context` will include target field that contains the id of the button. +The `context` argument provides additional information about the event. + +The context provide the id of component that trigger the event in `target` field. ```py def handle_click(state, context: dict): @@ -241,6 +277,15 @@ def handle_click(state, context: dict): state["last_source_of_click"] = last_source_of_click ``` +The context provides the event triggered in the `event` field. + +```py +def handle_click(state, context: dict): + event_type = context['event'] + if event_type == 'click': + state["last_event"] = 'Click' +``` + The repeater components have additional fields in the context, such as defined in `keyVariable` and `valueVariable`. ```py diff --git a/src/ui/src/builder/BuilderFieldsPadding.vue b/src/ui/src/builder/BuilderFieldsPadding.vue index 53b2967a5..5c87cf61f 100644 --- a/src/ui/src/builder/BuilderFieldsPadding.vue +++ b/src/ui/src/builder/BuilderFieldsPadding.vue @@ -42,7 +42,7 @@ @select="handleInputSelect" />
- padding + All
- padding + X px
- padding + Y
- padding + Left px
- padding + Right px
- padding + Top px
- padding + Bottom { flex-direction: row; gap: 8px; padding: 8px; - align-items: center; + align-items: baseline; } .row input { width: calc(100% - 32px) !important; + text-align: right; } diff --git a/src/ui/src/builder/BuilderSelect.vue b/src/ui/src/builder/BuilderSelect.vue index 542b941f6..19439813c 100644 --- a/src/ui/src/builder/BuilderSelect.vue +++ b/src/ui/src/builder/BuilderSelect.vue @@ -152,6 +152,7 @@ const select = (event) => { display: block; padding: 8px; font-weight: 400; + font-size: 0.75rem; color: #000000e6; cursor: pointer; transition: all 0.2s; @@ -182,4 +183,8 @@ const select = (event) => { flex-direction: row; gap: 8px; } + +.selectContent { + font-size: 0.75rem; +} diff --git a/src/writer/__init__.py b/src/writer/__init__.py index 3a2ac7863..b93a03f1d 100644 --- a/src/writer/__init__.py +++ b/src/writer/__init__.py @@ -10,6 +10,7 @@ State, WriterState, base_component_tree, + get_app_process, initial_state, new_initial_state, session_manager, @@ -128,3 +129,33 @@ def init_handlers(handler_modules: Union[List[ModuleType], ModuleType]): for module in handler_modules: handler_registry.register_module(module) + + +def middleware(): + """ + A "middleware" is a function that works with every event handler before it is processed and also before returning it. + + >>> import writer as wf + >>> + >>> @wf.middleware() + >>> def my_middleware(state): + >>> state['processing'] += 1 + >>> yield + >>> state['processing'] -= 1 + + Middleware accepts the same arguments as an event handler. + + >>> import writer as wf + >>> + >>> @wf.middleware() + >>> def my_middleware(state, payload, session): + >>> state['processing'] += 1 + >>> yield + >>> state['processing'] -= 1 + """ + def inner(func): + _app_process = get_app_process() + _app_process.middleware_registry.register(func) + + + return inner diff --git a/src/writer/app_runner.py b/src/writer/app_runner.py index e4562d565..a6cb0fa90 100644 --- a/src/writer/app_runner.py +++ b/src/writer/app_runner.py @@ -19,7 +19,7 @@ from watchdog.observers.polling import PollingObserver from writer import VERSION -from writer.core import EventHandlerRegistry, WriterSession +from writer.core import EventHandlerRegistry, MiddlewareRegistry, WriterSession from writer.core_ui import ingest_bmc_component_tree from writer.ss_types import ( AppProcessServerRequest, @@ -33,6 +33,8 @@ InitSessionRequest, InitSessionRequestPayload, InitSessionResponsePayload, + StateContentRequest, + StateContentResponsePayload, StateEnquiryRequest, StateEnquiryResponsePayload, WriterEvent, @@ -96,6 +98,7 @@ def __init__(self, self.is_app_process_server_failed = is_app_process_server_failed self.logger = logging.getLogger("app") self.handler_registry = EventHandlerRegistry() + self.middleware_registry = MiddlewareRegistry() def _load_module(self) -> ModuleType: @@ -204,6 +207,19 @@ def _handle_state_enquiry(self, session: WriterSession) -> StateEnquiryResponseP session.session_state.clear_mail() return res_payload + + def _handle_state_content(self, session: WriterSession) -> StateContentResponsePayload: + serialized_state = {} + try: + serialized_state = session.session_state.user_state.to_raw_state() + except BaseException: + import traceback as tb + session.session_state.add_log_entry("error", + "Serialisation Error", + "An exception was raised during serialisation.", + tb.format_exc()) + + return StateContentResponsePayload(state=serialized_state) def _handle_component_update(self, session: WriterSession, payload: ComponentUpdateRequestPayload) -> None: import writer @@ -255,6 +271,13 @@ def _handle_message(self, session_id: str, request: AppProcessServerRequest) -> payload=self._handle_state_enquiry(session) ) + if type == "stateContent": + return AppProcessServerResponse( + status="ok", + status_message=None, + payload=self._handle_state_content(session) + ) + if type == "setUserinfo": session.userinfo = request.payload return AppProcessServerResponse( @@ -713,6 +736,16 @@ async def handle_state_enquiry(self, session_id: str) -> AppProcessServerRespons type="stateEnquiry" )) + async def handle_state_content(self, session_id: str) -> AppProcessServerResponse: + """ + This method returns the complete status of the application. + + It is only accessible through tests + """ + return await self.dispatch_message(session_id, StateContentRequest( + type="stateContent" + )) + def save_code(self, session_id: str, code: str) -> None: if self.mode != "edit": raise PermissionError("Cannot save code in non-edit mode.") diff --git a/src/writer/core.py b/src/writer/core.py index 1ad78ee44..c72a755f7 100644 --- a/src/writer/core.py +++ b/src/writer/core.py @@ -36,8 +36,10 @@ ) from writer import core_ui +from writer.core_ui import Component from writer.ss_types import ( InstancePath, + InstancePathItem, Readable, WriterEvent, WriterEventResult, @@ -49,14 +51,18 @@ def get_app_process() -> 'AppProcess': + """ + Retrieves the Writer Framework process context. + + >>> _current_process = get_app_process() + >>> _current_process.bmc_components # get the component tree + """ from writer.app_runner import AppProcess # Needed during runtime - raw_process: BaseProcess = \ - multiprocessing.current_process() + raw_process: BaseProcess = multiprocessing.current_process() if isinstance(raw_process, AppProcess): return raw_process - raise RuntimeError( - "Failed to retrieve the AppProcess: running in wrong context" - ) + + raise RuntimeError( "Failed to retrieve the AppProcess: running in wrong context") class Config: @@ -737,6 +743,58 @@ def call_frontend_function(self, module_key: str, function_name: str, args: List "args": args }) +class MiddlewareExecutor(): + """ + A MiddlewareExecutor executes middleware in a controlled context. It allows writer framework + to manage different implementations of middleware. + + Case 1 : A middleware is a generator, then run before and after code + + >>> @wf.middleware() + >>> def my_middleware(): + >>> print("before event handler") + >>> yield() + >>> print("after event handler") + + Case 2 : A middleware is just a function, then run the function before + + >>> @wf.middleware() + >>> def my_middleware(): + >>> print("before event handler") + """ + + def __init__(self, middleware: Callable): + self.middleware = middleware + + @contextlib.contextmanager + def execute(self, args: dict): + middleware_args = build_writer_func_arguments(self.middleware, args) + it = self.middleware(*middleware_args) + try: + yield from it + except StopIteration: + yield + except TypeError: + yield + + +class MiddlewareRegistry: + + def __init__(self): + self.registry: List[MiddlewareExecutor] = [] + + def register(self, middleware: Callable): + me = MiddlewareExecutor(middleware) + self.registry.append(me) + + def executors(self) -> List[MiddlewareExecutor]: + """ + Retrieves middlewares prepared for execution + + >>> executors = middleware_registry.executors() + >>> result = handle_with_middlewares_executor(executors, lambda state: pass, {'state': {}, 'payload': {}}) + """ + return self.registry class EventHandlerRegistry: """ @@ -1088,7 +1146,7 @@ def get_context_data(self, instance_path: InstancePath) -> Dict[str, Any]: if len(instance_path) > 0: context['target'] = instance_path[-1]['componentId'] - + return context def set_state(self, expr: str, instance_path: InstancePath, value: Any) -> None: @@ -1304,23 +1362,13 @@ def _handle_binding(self, event_type, target_component, instance_path, payload) return self.evaluator.set_state(binding["stateRef"], instance_path, payload) - def _async_handler_executor(self, callable_handler, arg_values): - async_callable = self._async_handler_executor_internal(callable_handler, arg_values) - return asyncio.run(async_callable) - - async def _async_handler_executor_internal(self, callable_handler, arg_values): - with contextlib.redirect_stdout(io.StringIO()) as f: - result = await callable_handler(*arg_values) - captured_stdout = f.getvalue() - return result, captured_stdout - - def _sync_handler_executor(self, callable_handler, arg_values): - with contextlib.redirect_stdout(io.StringIO()) as f: - result = callable_handler(*arg_values) - captured_stdout = f.getvalue() - return result, captured_stdout - - def _call_handler_callable(self, event_type, target_component, instance_path, payload) -> Any: + def _call_handler_callable( + self, + event_type: str, + target_component: Component, + instance_path: List[InstancePathItem], + payload: Any + ) -> Any: current_app_process = get_app_process() handler_registry = current_app_process.handler_registry if not target_component.handlers: @@ -1331,44 +1379,35 @@ def _call_handler_callable(self, event_type, target_component, instance_path, pa callable_handler = handler_registry.find_handler(handler) if not callable_handler: - raise ValueError( - f"""Invalid handler. Couldn't find the handler "{ handler }".""") - is_async_handler = inspect.iscoroutinefunction(callable_handler) - - if (not callable(callable_handler) - and not is_async_handler): - raise ValueError( - "Invalid handler. The handler isn't a callable object.") - - args = inspect.getfullargspec(callable_handler).args - arg_values = [] - for arg in args: - if arg == "state": - arg_values.append(self.session_state) - elif arg == "payload": - arg_values.append(payload) - elif arg == "context": - context = self.evaluator.get_context_data(instance_path) - arg_values.append(context) - elif arg == "session": - session_info = { - "id": self.session.session_id, - "cookies": self.session.cookies, - "headers": self.session.headers, - "userinfo": self.session.userinfo or {} - } - arg_values.append(session_info) - elif arg == "ui": - from writer.ui import WriterUIManager - ui_manager = WriterUIManager() - arg_values.append(ui_manager) + raise ValueError(f"""Invalid handler. Couldn't find the handler "{ handler }".""") + + # Preparation of arguments + from writer.ui import WriterUIManager + + context_data = self.evaluator.get_context_data(instance_path) + context_data['event'] = event_type + writer_args = { + 'state': self.session_state, + 'payload': payload, + 'context': context_data, + 'session': { + 'id': self.session.session_id, + 'cookies': self.session.cookies, + 'headers': self.session.headers, + 'userinfo': self.session.userinfo or {} + }, + 'ui': WriterUIManager() + } + # Invocation of handler result = None - with core_ui.use_component_tree(self.session.session_component_tree): - if is_async_handler: - result, captured_stdout = self._async_handler_executor(callable_handler, arg_values) - else: - result, captured_stdout = self._sync_handler_executor(callable_handler, arg_values) + captured_stdout = None + with core_ui.use_component_tree(self.session.session_component_tree), \ + contextlib.redirect_stdout(io.StringIO()) as f: + middlewares_executors = current_app_process.middleware_registry.executors() + + result = handle_with_middlewares_executor(middlewares_executors, callable_handler, writer_args) + captured_stdout = f.getvalue() if captured_stdout: self.session_state.add_log_entry( @@ -1376,6 +1415,7 @@ def _call_handler_callable(self, event_type, target_component, instance_path, pa "Stdout message", captured_stdout ) + return result def handle(self, ev: WriterEvent) -> WriterEventResult: @@ -1394,11 +1434,10 @@ def handle(self, ev: WriterEvent) -> WriterEventResult: try: instance_path = ev.instancePath target_id = instance_path[-1]["componentId"] - target_component = self.session_component_tree.get_component(target_id) + target_component = cast(Component, self.session_component_tree.get_component(target_id)) self._handle_binding(ev.type, target_component, instance_path, ev.payload) - result = self._call_handler_callable( - ev.type, target_component, instance_path, ev.payload) + result = self._call_handler_callable(ev.type, target_component, instance_path, ev.payload) except BaseException: ok = False self.session_state.add_notification("error", "Runtime Error", f"An error occurred when processing event '{ ev.type }'.", @@ -1497,6 +1536,81 @@ def reset_base_component_tree() -> None: base_component_tree = core_ui.build_base_component_tree() +def handler_executor(callable_handler: Callable, writer_args: dict) -> Any: + """ + Runs a handler based on its signature. + + If the handler is asynchronous, it is executed asynchronously. + If the handler only has certain parameters, only these are passed as arguments + + >>> def my_handler(state): + >>> state['a'] = 2 + >>> + >>> handler_executor(my_handler, {'state': {'a': 1}, 'payload': None, 'context': None, 'session': None, 'ui': None}) + """ + is_async_handler = inspect.iscoroutinefunction(callable_handler) + if (not callable(callable_handler) and not is_async_handler): + raise ValueError("Invalid handler. The handler isn't a callable object.") + + handler_args = build_writer_func_arguments(callable_handler, writer_args) + + if is_async_handler: + async_wrapper = _async_wrapper_internal(callable_handler, handler_args) + result = asyncio.run(async_wrapper) + else: + result = callable_handler(*handler_args) + + return result + +def handle_with_middlewares_executor(middlewares_executors: List[MiddlewareExecutor], callable_handler: Callable, writer_args: dict) -> Any: + """ + Runs the middlewares then the handler. This function allows you to manage exceptions that are triggered in middleware + + :param middlewares_executors: The list of middleware to run + :param callable_handler: The target handler + + >>> @wf.middleware() + >>> def my_middleware(state, payload, context, session, ui): + >>> yield + + >>> executor = MiddlewareExecutor(my_middleware, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None}) + >>> handle_with_middlewares_executor([executor], my_handler, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None} + """ + if len(middlewares_executors) == 0: + return handler_executor(callable_handler, writer_args) + else: + executor = middlewares_executors[0] + with executor.execute(writer_args): + return handle_with_middlewares_executor(middlewares_executors[1:], callable_handler, writer_args) + +def build_writer_func_arguments(func: Callable, writer_args: dict) -> List[Any]: + """ + Constructs the list of arguments based on the signature of the function + which can be a handler or middleware. + + >>> def my_event_handler(state, context): + >>> yield + + >>> args = build_writer_func_arguments(my_event_handler, {'state': {}, 'payload': {}, 'context': {"target": '11'}, 'session': None, 'ui': None}) + >>> [{}, {"target": '11'}] + + :param func: the function that will be called + :param writer_args: the possible arguments in writer (state, payload, ...) + """ + handler_args = inspect.getfullargspec(func).args + func_args = [] + for arg in handler_args: + if arg in writer_args: + func_args.append(writer_args[arg]) + + return func_args + + +async def _async_wrapper_internal(callable_handler: Callable, arg_values: List[Any]) -> Any: + result = await callable_handler(*arg_values) + return result + + state_serialiser = StateSerialiser() initial_state = WriterState() base_component_tree = core_ui.build_base_component_tree() diff --git a/src/writer/ss_types.py b/src/writer/ss_types.py index 090336453..d82d2cc17 100644 --- a/src/writer/ss_types.py +++ b/src/writer/ss_types.py @@ -26,7 +26,7 @@ def read(self) -> Any: ServeMode = Literal["run", "edit"] MessageType = Literal["sessionInit", "componentUpdate", "event", "codeUpdate", "codeSave", "checkSession", - "keepAlive", "stateEnquiry", "setUserinfo"] + "keepAlive", "stateEnquiry", "setUserinfo", "stateContent"] # Web server models @@ -107,6 +107,10 @@ class StateEnquiryRequest(AppProcessServerRequest): type: Literal["stateEnquiry"] +class StateContentRequest(AppProcessServerRequest): + type: Literal["stateContent"] + + AppProcessServerRequestPacket = Tuple[int, Optional[str], AppProcessServerRequest] @@ -143,6 +147,9 @@ class StateEnquiryResponsePayload(BaseModel): mutations: Dict[str, Any] mail: List +class StateContentResponsePayload(BaseModel): + state: Dict[str, Any] + class EventResponse(AppProcessServerResponse): type: Literal["event"] diff --git a/tests/backend/fixtures/app_runner_fixtures.py b/tests/backend/fixtures/app_runner_fixtures.py new file mode 100644 index 000000000..45f4a64de --- /dev/null +++ b/tests/backend/fixtures/app_runner_fixtures.py @@ -0,0 +1,34 @@ +from typing import Optional + +from writer.app_runner import AppRunner +from writer.ss_types import InitSessionRequestPayload + +FIXED_SESSION_ID = "0000000000000000000000000000000000000000000000000000000000000000" # Compliant session number + +async def init_app_session(app_runner: AppRunner, + session_id: str = None, + cookies: Optional[dict] = None, + headers: Optional[dict] = None) -> str: + """ + Fixture to initialize a session and be able to use it in tests. + + If the `session_id` is missing from the parameters, the fixture creates a session with a random ID. + + >>> with setup_app_runner(test_app_dir, 'run') as ar: + >>> # When + >>> ar.load() + >>> session_id = await init_app_session(ar) + + If the `session_id` is missing from the parameters, the fixture creates a session for this identifier. + + >>> session_id = await init_app_session(ar, session_id=FIXED_SESSION_ID) + """ + if cookies is None: + cookies = {} + if headers is None: + headers = {} + + init_session_payload = InitSessionRequestPayload(cookies=cookies, headers=headers, proposedSessionId=session_id) + result = await app_runner.init_session(init_session_payload) + + return result.payload.model_dump().get("sessionId") diff --git a/tests/backend/test_app_runner.py b/tests/backend/test_app_runner.py index 1cb7bb304..4f3c771c2 100644 --- a/tests/backend/test_app_runner.py +++ b/tests/backend/test_app_runner.py @@ -1,6 +1,5 @@ import asyncio -import contextlib import threading import pytest @@ -12,21 +11,10 @@ WriterEvent, ) +from backend.fixtures.app_runner_fixtures import init_app_session from tests.backend import test_app_dir -@pytest.fixture -def setup_app_runner(): - @contextlib.contextmanager - def _manage_launch_args(app_dir, app_command): - ar = AppRunner(app_dir, app_command) - try: - yield ar - finally: - ar.shut_down() - return _manage_launch_args - - class TestAppRunner: numberinput_instance_path = [ @@ -60,7 +48,7 @@ def test_init_wrong_mode(self) -> None: @pytest.mark.asyncio @pytest.mark.usefixtures("setup_app_runner") async def test_pre_session(self, setup_app_runner) -> None: - with setup_app_runner(test_app_dir, "run") as ar: + with setup_app_runner(test_app_dir, "run", load = True) as ar: er = EventRequest( type="event", payload=WriterEvent( @@ -71,26 +59,14 @@ async def test_pre_session(self, setup_app_runner) -> None: } ) ) - ar.load() r = await ar.dispatch_message(None, er) assert r.status == "error" @pytest.mark.asyncio @pytest.mark.usefixtures("setup_app_runner") async def test_valid_session_invalid_event(self, setup_app_runner) -> None: - with setup_app_runner(test_app_dir, "run") as ar: - ar.load() - si = InitSessionRequest( - type="sessionInit", - payload=InitSessionRequestPayload( - cookies={}, - headers={}, - proposedSessionId=self.proposed_session_id - ) - ) - sres = await ar.dispatch_message(None, si) - assert sres.status == "ok" - assert sres.payload.model_dump().get("sessionId") == self.proposed_session_id + with setup_app_runner(test_app_dir, "run", load = True) as ar: + await init_app_session(ar, session_id=self.proposed_session_id) er = EventRequest(type="event", payload=WriterEvent( type="virus", instancePath=self.numberinput_instance_path, @@ -104,19 +80,8 @@ async def test_valid_session_invalid_event(self, setup_app_runner) -> None: @pytest.mark.asyncio @pytest.mark.usefixtures("setup_app_runner") async def test_valid_event(self, setup_app_runner) -> None: - with setup_app_runner(test_app_dir, "run") as ar: - ar.load() - si = InitSessionRequest( - type="sessionInit", - payload=InitSessionRequestPayload( - cookies={}, - headers={}, - proposedSessionId=self.proposed_session_id - ) - ) - sres = await ar.dispatch_message(None, si) - assert sres.status == "ok" - assert sres.payload.model_dump().get("sessionId") == self.proposed_session_id + with setup_app_runner(test_app_dir, "run", load = True) as ar: + await init_app_session(ar, session_id=self.proposed_session_id) ev_req = EventRequest(type="event", payload=WriterEvent( type="wf-number-change", instancePath=self.numberinput_instance_path, @@ -133,20 +98,8 @@ async def test_valid_event(self, setup_app_runner) -> None: @pytest.mark.asyncio @pytest.mark.usefixtures("setup_app_runner") async def test_async_handler(self, setup_app_runner) -> None: - with setup_app_runner(test_app_dir, "run") as ar: - ar.load() - si = InitSessionRequest( - type="sessionInit", - payload=InitSessionRequestPayload( - cookies={}, - headers={}, - proposedSessionId=self.proposed_session_id - ) - ) - sres = await ar.dispatch_message(None, si) - assert sres.status == "ok" - assert sres.payload.model_dump().get("sessionId") == self.proposed_session_id - + with setup_app_runner(test_app_dir, "run", load = True) as ar: + await init_app_session(ar, session_id=self.proposed_session_id) # Firing an event to bypass "initial" state mutations ev_req = EventRequest(type="event", payload=WriterEvent( type="wf-number-change", @@ -167,17 +120,8 @@ async def test_async_handler(self, setup_app_runner) -> None: @pytest.mark.asyncio @pytest.mark.usefixtures("setup_app_runner") async def test_bad_event_handler(self, setup_app_runner) -> None: - with setup_app_runner(test_app_dir, "run") as ar: - ar.load() - si = InitSessionRequest( - type="sessionInit", - payload=InitSessionRequestPayload( - cookies={}, - headers={}, - proposedSessionId=self.proposed_session_id - ) - ) - await ar.dispatch_message(None, si) + with setup_app_runner(test_app_dir, "run", load = True) as ar: + await init_app_session(ar, session_id=self.proposed_session_id) bad_button_instance_path = [ {"componentId": "root", "instanceNumber": 0}, {"componentId": "28a2212b-bc58-4398-8a72-2554e5296490", "instanceNumber": 0}, @@ -238,3 +182,25 @@ async def test_code_update(self, setup_app_runner) -> None: mail = list(si_res.payload.model_dump().get("mail")) assert mail[0].get("payload").get("message") == "188542\n" + + @pytest.mark.asyncio + @pytest.mark.usefixtures("setup_app_runner") + async def test_handle_event_should_return_result_of_event_handler_execution(self, setup_app_runner): + """ + Tests that an event handler should result the result of function execution in + payload.result['result']. + """ + # Given + ar: AppRunner + with setup_app_runner(test_app_dir, 'run', load=True) as ar: + session_id = await init_app_session(ar) + + # When + res = await ar.handle_event(session_id, WriterEvent( + type='click', + instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}], + payload={}) + ) + + # Then + assert res.payload.result['result'] is not None diff --git a/tests/backend/test_middleware.py b/tests/backend/test_middleware.py new file mode 100644 index 000000000..771087983 --- /dev/null +++ b/tests/backend/test_middleware.py @@ -0,0 +1,111 @@ +import pytest +from writer.app_runner import AppRunner +from writer.ss_types import WriterEvent + +from backend import test_app_dir +from backend.fixtures.app_runner_fixtures import init_app_session + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_app_runner") +async def test_middleware_should_apply_on_every_event_handler_invocation(setup_app_runner): + """ + Tests that a middleware executes before an event + """ + # Given + ar: AppRunner + with setup_app_runner(test_app_dir, 'run', load=True) as ar: + session_id = await init_app_session(ar) + + # When + res = await ar.handle_event(session_id, WriterEvent( + type='click', + instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}], + payload={}) + ) + res.payload.mutations['+counter_middleware'] = 1 + + # Then + full_state = await ar.handle_state_content(session_id) + assert full_state.payload.state['counter'] == 3 + assert full_state.payload.state['counter_middleware'] == 1 + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_app_runner") +async def test_middleware_should_apply_on_multi_event_handler_invocation(setup_app_runner): + """ + Tests that a middleware executes twice after 2 different events + """ + # Given + ar: AppRunner + with setup_app_runner(test_app_dir, 'run', load=True) as ar: + session_id = await init_app_session(ar) + + # When + res = await ar.handle_event(session_id, WriterEvent( + type='click', + instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}], + payload={}) + ) + res.payload.mutations['+counter_middleware'] = 1 + + res = await ar.handle_event(session_id, WriterEvent( + type='wf-option-change', + instancePath=[{'componentId': '2e46c38b-6405-42ad-ad9c-d237a53a7d30', 'instanceNumber': 0}], + payload='ar') + ) + res.payload.mutations['+counter_middleware'] = 2 + + # Then + full_state = await ar.handle_state_content(session_id) + assert full_state.payload.state['counter_middleware'] == 2 + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_app_runner") +async def test_middleware_should_apply_after_event_handler_invocation(setup_app_runner): + """ + Test that a middleware executes after an event + """ + # Given + ar: AppRunner + with setup_app_runner(test_app_dir, 'run', load=True) as ar: + session_id = await init_app_session(ar) + + # When + res = await ar.handle_event(session_id, WriterEvent( + type='click', + instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}], + payload={}) + ) + res.payload.mutations['+counter_post_middleware'] = 1 + + # Then + full_state = await ar.handle_state_content(session_id) + assert full_state.payload.state['counter'] == 3 + assert full_state.payload.state['counter_post_middleware'] == 1 + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_app_runner") +async def test_middleware_should_apply_on_middleware_without_yield(setup_app_runner): + """ + Test that a middleware executes after an event + """ + # Given + ar: AppRunner + with setup_app_runner(test_app_dir, 'run', load=True) as ar: + session_id = await init_app_session(ar) + + # When + res = await ar.handle_event(session_id, WriterEvent( + type='click', + instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}], + payload={}) + ) + res.payload.mutations['+counter_middleware_without_yield'] = 1 + + # Then + full_state = await ar.handle_state_content(session_id) + assert full_state.payload.state['counter_middleware_without_yield'] == 1 diff --git a/tests/backend/testapp/main.py b/tests/backend/testapp/main.py index 171ac01b0..53df2ac09 100644 --- a/tests/backend/testapp/main.py +++ b/tests/backend/testapp/main.py @@ -10,6 +10,20 @@ import writer.core +@wf.middleware() +def my_middleware(state): + state['counter_middleware'] += 1 + yield + +@wf.middleware() +def no_yield_middleware(state): + state['counter_middleware_without_yield'] += 1 + +@wf.middleware() +def post_middleware(state): + yield + state['counter_post_middleware'] += 1 + @wf.session_verifier def check_headers(headers): if headers.get("x-fail") is not None: @@ -43,6 +57,7 @@ def update_cities(state, payload): def increment(state): state["counter"] += 1*my_var + return 1 # EVENT HANDLERS @@ -220,6 +235,9 @@ def _get_altair_chart(): "min_weight": 300, }, "counter": 0, + "counter_middleware": 0, + "counter_post_middleware": 0, + "counter_middleware_without_yield": 0, "metrics": {}, "b": { "pet_count": 8 diff --git a/tests/conftest.py b/tests/conftest.py index 21036f36b..f75f0b495 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,10 @@ +import contextlib +from typing import Literal + +import pytest +from writer.app_runner import AppRunner + + def pytest_collection_modifyitems(config, items): if not config.getoption("--full-run"): deselected = [] @@ -15,3 +22,32 @@ def pytest_addoption(parser): parser.addoption( "--full-run", action="store_true", default=False, help="Include explicit-marked tests in the run (those are exluded from regular runs)" ) + +@pytest.fixture +def setup_app_runner(): + @contextlib.contextmanager + def _manage_launch_args(app_dir: str, app_command: Literal["run", "edit"], load: bool = False): + """ + Fixture to instantiate a writer application for testing. + + >>> with setup_app_runner("app_dir", "run", load=True) as ar: + >>> pass + + When the load flag is True, the application is loaded. + + >>> with setup_app_runner("app_dir", "run", load=True) as ar: + >>> pass + + :param app_dir: the folder that contains the application + :param app_command: the execution mode of the application, either edit or run + :param load: load the application if True + """ + ar = AppRunner(app_dir, app_command) + try: + if load is True: + ar.load() + + yield ar + finally: + ar.shut_down() + return _manage_launch_args