diff --git a/docs/framework/event-handlers.mdx b/docs/framework/event-handlers.mdx
index 8353d2510..6617c8357 100644
--- a/docs/framework/event-handlers.mdx
+++ b/docs/framework/event-handlers.mdx
@@ -172,6 +172,41 @@ def evaluate(state, payload):
Take into account that globals apply to all users. If you need to store data that's only relevant to a particular user, use application state.
+## Middlewares
+
+Middlewares are functions that run before and after every event handler.
+They can be used to perform tasks such as logging, error handling, session management, or modifying the state.
+
+```py
+import writer as wf
+
+@wf.middleware()
+def middleware_before(state, payload, context):
+ print("Middleware before event handler")
+ state['running'] += 1
+ yield
+ print("Middleware after event handler")
+ state['running'] -= 1
+```
+
+A middleware receives the same parameters as an event handler.
+
+A middleware can be used to handle exceptions that happens in event handlers.
+
+```py
+import writer as wf
+
+@wf.middleware()
+def middleware_before(state):
+ try:
+ yield
+ except Exception as e:
+ state['error_counter'] += 1
+ state['last_error'] = str()
+ finally:
+ pass
+```
+
## Standard output
The standard output of an app is captured and shown in the code editor's log. You can use the standard `print` function to output results.
@@ -232,8 +267,9 @@ You can use any awaitable object within an async event handler. This includes th
## Context
-The `context` argument provides additional information about the event. For example, if the event
-was triggered by a _Button_, the `context` will include target field that contains the id of the button.
+The `context` argument provides additional information about the event.
+
+The context provide the id of component that trigger the event in `target` field.
```py
def handle_click(state, context: dict):
@@ -241,6 +277,15 @@ def handle_click(state, context: dict):
state["last_source_of_click"] = last_source_of_click
```
+The context provides the event triggered in the `event` field.
+
+```py
+def handle_click(state, context: dict):
+ event_type = context['event']
+ if event_type == 'click':
+ state["last_event"] = 'Click'
+```
+
The repeater components have additional fields in the context, such as defined in `keyVariable` and `valueVariable`.
```py
diff --git a/src/ui/src/builder/BuilderFieldsPadding.vue b/src/ui/src/builder/BuilderFieldsPadding.vue
index 53b2967a5..5c87cf61f 100644
--- a/src/ui/src/builder/BuilderFieldsPadding.vue
+++ b/src/ui/src/builder/BuilderFieldsPadding.vue
@@ -42,7 +42,7 @@
@select="handleInputSelect"
/>
- padding
+ All
- padding
+ X
px
- padding
+ Y
- padding
+ Left
px
- padding
+ Right
px
- padding
+ Top
px
- padding
+ Bottom
{
flex-direction: row;
gap: 8px;
padding: 8px;
- align-items: center;
+ align-items: baseline;
}
.row input {
width: calc(100% - 32px) !important;
+ text-align: right;
}
diff --git a/src/ui/src/builder/BuilderSelect.vue b/src/ui/src/builder/BuilderSelect.vue
index 542b941f6..19439813c 100644
--- a/src/ui/src/builder/BuilderSelect.vue
+++ b/src/ui/src/builder/BuilderSelect.vue
@@ -152,6 +152,7 @@ const select = (event) => {
display: block;
padding: 8px;
font-weight: 400;
+ font-size: 0.75rem;
color: #000000e6;
cursor: pointer;
transition: all 0.2s;
@@ -182,4 +183,8 @@ const select = (event) => {
flex-direction: row;
gap: 8px;
}
+
+.selectContent {
+ font-size: 0.75rem;
+}
diff --git a/src/writer/__init__.py b/src/writer/__init__.py
index 3a2ac7863..b93a03f1d 100644
--- a/src/writer/__init__.py
+++ b/src/writer/__init__.py
@@ -10,6 +10,7 @@
State,
WriterState,
base_component_tree,
+ get_app_process,
initial_state,
new_initial_state,
session_manager,
@@ -128,3 +129,33 @@ def init_handlers(handler_modules: Union[List[ModuleType], ModuleType]):
for module in handler_modules:
handler_registry.register_module(module)
+
+
+def middleware():
+ """
+ A "middleware" is a function that works with every event handler before it is processed and also before returning it.
+
+ >>> import writer as wf
+ >>>
+ >>> @wf.middleware()
+ >>> def my_middleware(state):
+ >>> state['processing'] += 1
+ >>> yield
+ >>> state['processing'] -= 1
+
+ Middleware accepts the same arguments as an event handler.
+
+ >>> import writer as wf
+ >>>
+ >>> @wf.middleware()
+ >>> def my_middleware(state, payload, session):
+ >>> state['processing'] += 1
+ >>> yield
+ >>> state['processing'] -= 1
+ """
+ def inner(func):
+ _app_process = get_app_process()
+ _app_process.middleware_registry.register(func)
+
+
+ return inner
diff --git a/src/writer/app_runner.py b/src/writer/app_runner.py
index e4562d565..a6cb0fa90 100644
--- a/src/writer/app_runner.py
+++ b/src/writer/app_runner.py
@@ -19,7 +19,7 @@
from watchdog.observers.polling import PollingObserver
from writer import VERSION
-from writer.core import EventHandlerRegistry, WriterSession
+from writer.core import EventHandlerRegistry, MiddlewareRegistry, WriterSession
from writer.core_ui import ingest_bmc_component_tree
from writer.ss_types import (
AppProcessServerRequest,
@@ -33,6 +33,8 @@
InitSessionRequest,
InitSessionRequestPayload,
InitSessionResponsePayload,
+ StateContentRequest,
+ StateContentResponsePayload,
StateEnquiryRequest,
StateEnquiryResponsePayload,
WriterEvent,
@@ -96,6 +98,7 @@ def __init__(self,
self.is_app_process_server_failed = is_app_process_server_failed
self.logger = logging.getLogger("app")
self.handler_registry = EventHandlerRegistry()
+ self.middleware_registry = MiddlewareRegistry()
def _load_module(self) -> ModuleType:
@@ -204,6 +207,19 @@ def _handle_state_enquiry(self, session: WriterSession) -> StateEnquiryResponseP
session.session_state.clear_mail()
return res_payload
+
+ def _handle_state_content(self, session: WriterSession) -> StateContentResponsePayload:
+ serialized_state = {}
+ try:
+ serialized_state = session.session_state.user_state.to_raw_state()
+ except BaseException:
+ import traceback as tb
+ session.session_state.add_log_entry("error",
+ "Serialisation Error",
+ "An exception was raised during serialisation.",
+ tb.format_exc())
+
+ return StateContentResponsePayload(state=serialized_state)
def _handle_component_update(self, session: WriterSession, payload: ComponentUpdateRequestPayload) -> None:
import writer
@@ -255,6 +271,13 @@ def _handle_message(self, session_id: str, request: AppProcessServerRequest) ->
payload=self._handle_state_enquiry(session)
)
+ if type == "stateContent":
+ return AppProcessServerResponse(
+ status="ok",
+ status_message=None,
+ payload=self._handle_state_content(session)
+ )
+
if type == "setUserinfo":
session.userinfo = request.payload
return AppProcessServerResponse(
@@ -713,6 +736,16 @@ async def handle_state_enquiry(self, session_id: str) -> AppProcessServerRespons
type="stateEnquiry"
))
+ async def handle_state_content(self, session_id: str) -> AppProcessServerResponse:
+ """
+ This method returns the complete status of the application.
+
+ It is only accessible through tests
+ """
+ return await self.dispatch_message(session_id, StateContentRequest(
+ type="stateContent"
+ ))
+
def save_code(self, session_id: str, code: str) -> None:
if self.mode != "edit":
raise PermissionError("Cannot save code in non-edit mode.")
diff --git a/src/writer/core.py b/src/writer/core.py
index 1ad78ee44..c72a755f7 100644
--- a/src/writer/core.py
+++ b/src/writer/core.py
@@ -36,8 +36,10 @@
)
from writer import core_ui
+from writer.core_ui import Component
from writer.ss_types import (
InstancePath,
+ InstancePathItem,
Readable,
WriterEvent,
WriterEventResult,
@@ -49,14 +51,18 @@
def get_app_process() -> 'AppProcess':
+ """
+ Retrieves the Writer Framework process context.
+
+ >>> _current_process = get_app_process()
+ >>> _current_process.bmc_components # get the component tree
+ """
from writer.app_runner import AppProcess # Needed during runtime
- raw_process: BaseProcess = \
- multiprocessing.current_process()
+ raw_process: BaseProcess = multiprocessing.current_process()
if isinstance(raw_process, AppProcess):
return raw_process
- raise RuntimeError(
- "Failed to retrieve the AppProcess: running in wrong context"
- )
+
+ raise RuntimeError( "Failed to retrieve the AppProcess: running in wrong context")
class Config:
@@ -737,6 +743,58 @@ def call_frontend_function(self, module_key: str, function_name: str, args: List
"args": args
})
+class MiddlewareExecutor():
+ """
+ A MiddlewareExecutor executes middleware in a controlled context. It allows writer framework
+ to manage different implementations of middleware.
+
+ Case 1 : A middleware is a generator, then run before and after code
+
+ >>> @wf.middleware()
+ >>> def my_middleware():
+ >>> print("before event handler")
+ >>> yield()
+ >>> print("after event handler")
+
+ Case 2 : A middleware is just a function, then run the function before
+
+ >>> @wf.middleware()
+ >>> def my_middleware():
+ >>> print("before event handler")
+ """
+
+ def __init__(self, middleware: Callable):
+ self.middleware = middleware
+
+ @contextlib.contextmanager
+ def execute(self, args: dict):
+ middleware_args = build_writer_func_arguments(self.middleware, args)
+ it = self.middleware(*middleware_args)
+ try:
+ yield from it
+ except StopIteration:
+ yield
+ except TypeError:
+ yield
+
+
+class MiddlewareRegistry:
+
+ def __init__(self):
+ self.registry: List[MiddlewareExecutor] = []
+
+ def register(self, middleware: Callable):
+ me = MiddlewareExecutor(middleware)
+ self.registry.append(me)
+
+ def executors(self) -> List[MiddlewareExecutor]:
+ """
+ Retrieves middlewares prepared for execution
+
+ >>> executors = middleware_registry.executors()
+ >>> result = handle_with_middlewares_executor(executors, lambda state: pass, {'state': {}, 'payload': {}})
+ """
+ return self.registry
class EventHandlerRegistry:
"""
@@ -1088,7 +1146,7 @@ def get_context_data(self, instance_path: InstancePath) -> Dict[str, Any]:
if len(instance_path) > 0:
context['target'] = instance_path[-1]['componentId']
-
+
return context
def set_state(self, expr: str, instance_path: InstancePath, value: Any) -> None:
@@ -1304,23 +1362,13 @@ def _handle_binding(self, event_type, target_component, instance_path, payload)
return
self.evaluator.set_state(binding["stateRef"], instance_path, payload)
- def _async_handler_executor(self, callable_handler, arg_values):
- async_callable = self._async_handler_executor_internal(callable_handler, arg_values)
- return asyncio.run(async_callable)
-
- async def _async_handler_executor_internal(self, callable_handler, arg_values):
- with contextlib.redirect_stdout(io.StringIO()) as f:
- result = await callable_handler(*arg_values)
- captured_stdout = f.getvalue()
- return result, captured_stdout
-
- def _sync_handler_executor(self, callable_handler, arg_values):
- with contextlib.redirect_stdout(io.StringIO()) as f:
- result = callable_handler(*arg_values)
- captured_stdout = f.getvalue()
- return result, captured_stdout
-
- def _call_handler_callable(self, event_type, target_component, instance_path, payload) -> Any:
+ def _call_handler_callable(
+ self,
+ event_type: str,
+ target_component: Component,
+ instance_path: List[InstancePathItem],
+ payload: Any
+ ) -> Any:
current_app_process = get_app_process()
handler_registry = current_app_process.handler_registry
if not target_component.handlers:
@@ -1331,44 +1379,35 @@ def _call_handler_callable(self, event_type, target_component, instance_path, pa
callable_handler = handler_registry.find_handler(handler)
if not callable_handler:
- raise ValueError(
- f"""Invalid handler. Couldn't find the handler "{ handler }".""")
- is_async_handler = inspect.iscoroutinefunction(callable_handler)
-
- if (not callable(callable_handler)
- and not is_async_handler):
- raise ValueError(
- "Invalid handler. The handler isn't a callable object.")
-
- args = inspect.getfullargspec(callable_handler).args
- arg_values = []
- for arg in args:
- if arg == "state":
- arg_values.append(self.session_state)
- elif arg == "payload":
- arg_values.append(payload)
- elif arg == "context":
- context = self.evaluator.get_context_data(instance_path)
- arg_values.append(context)
- elif arg == "session":
- session_info = {
- "id": self.session.session_id,
- "cookies": self.session.cookies,
- "headers": self.session.headers,
- "userinfo": self.session.userinfo or {}
- }
- arg_values.append(session_info)
- elif arg == "ui":
- from writer.ui import WriterUIManager
- ui_manager = WriterUIManager()
- arg_values.append(ui_manager)
+ raise ValueError(f"""Invalid handler. Couldn't find the handler "{ handler }".""")
+
+ # Preparation of arguments
+ from writer.ui import WriterUIManager
+
+ context_data = self.evaluator.get_context_data(instance_path)
+ context_data['event'] = event_type
+ writer_args = {
+ 'state': self.session_state,
+ 'payload': payload,
+ 'context': context_data,
+ 'session': {
+ 'id': self.session.session_id,
+ 'cookies': self.session.cookies,
+ 'headers': self.session.headers,
+ 'userinfo': self.session.userinfo or {}
+ },
+ 'ui': WriterUIManager()
+ }
+ # Invocation of handler
result = None
- with core_ui.use_component_tree(self.session.session_component_tree):
- if is_async_handler:
- result, captured_stdout = self._async_handler_executor(callable_handler, arg_values)
- else:
- result, captured_stdout = self._sync_handler_executor(callable_handler, arg_values)
+ captured_stdout = None
+ with core_ui.use_component_tree(self.session.session_component_tree), \
+ contextlib.redirect_stdout(io.StringIO()) as f:
+ middlewares_executors = current_app_process.middleware_registry.executors()
+
+ result = handle_with_middlewares_executor(middlewares_executors, callable_handler, writer_args)
+ captured_stdout = f.getvalue()
if captured_stdout:
self.session_state.add_log_entry(
@@ -1376,6 +1415,7 @@ def _call_handler_callable(self, event_type, target_component, instance_path, pa
"Stdout message",
captured_stdout
)
+
return result
def handle(self, ev: WriterEvent) -> WriterEventResult:
@@ -1394,11 +1434,10 @@ def handle(self, ev: WriterEvent) -> WriterEventResult:
try:
instance_path = ev.instancePath
target_id = instance_path[-1]["componentId"]
- target_component = self.session_component_tree.get_component(target_id)
+ target_component = cast(Component, self.session_component_tree.get_component(target_id))
self._handle_binding(ev.type, target_component, instance_path, ev.payload)
- result = self._call_handler_callable(
- ev.type, target_component, instance_path, ev.payload)
+ result = self._call_handler_callable(ev.type, target_component, instance_path, ev.payload)
except BaseException:
ok = False
self.session_state.add_notification("error", "Runtime Error", f"An error occurred when processing event '{ ev.type }'.",
@@ -1497,6 +1536,81 @@ def reset_base_component_tree() -> None:
base_component_tree = core_ui.build_base_component_tree()
+def handler_executor(callable_handler: Callable, writer_args: dict) -> Any:
+ """
+ Runs a handler based on its signature.
+
+ If the handler is asynchronous, it is executed asynchronously.
+ If the handler only has certain parameters, only these are passed as arguments
+
+ >>> def my_handler(state):
+ >>> state['a'] = 2
+ >>>
+ >>> handler_executor(my_handler, {'state': {'a': 1}, 'payload': None, 'context': None, 'session': None, 'ui': None})
+ """
+ is_async_handler = inspect.iscoroutinefunction(callable_handler)
+ if (not callable(callable_handler) and not is_async_handler):
+ raise ValueError("Invalid handler. The handler isn't a callable object.")
+
+ handler_args = build_writer_func_arguments(callable_handler, writer_args)
+
+ if is_async_handler:
+ async_wrapper = _async_wrapper_internal(callable_handler, handler_args)
+ result = asyncio.run(async_wrapper)
+ else:
+ result = callable_handler(*handler_args)
+
+ return result
+
+def handle_with_middlewares_executor(middlewares_executors: List[MiddlewareExecutor], callable_handler: Callable, writer_args: dict) -> Any:
+ """
+ Runs the middlewares then the handler. This function allows you to manage exceptions that are triggered in middleware
+
+ :param middlewares_executors: The list of middleware to run
+ :param callable_handler: The target handler
+
+ >>> @wf.middleware()
+ >>> def my_middleware(state, payload, context, session, ui):
+ >>> yield
+
+ >>> executor = MiddlewareExecutor(my_middleware, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None})
+ >>> handle_with_middlewares_executor([executor], my_handler, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None}
+ """
+ if len(middlewares_executors) == 0:
+ return handler_executor(callable_handler, writer_args)
+ else:
+ executor = middlewares_executors[0]
+ with executor.execute(writer_args):
+ return handle_with_middlewares_executor(middlewares_executors[1:], callable_handler, writer_args)
+
+def build_writer_func_arguments(func: Callable, writer_args: dict) -> List[Any]:
+ """
+ Constructs the list of arguments based on the signature of the function
+ which can be a handler or middleware.
+
+ >>> def my_event_handler(state, context):
+ >>> yield
+
+ >>> args = build_writer_func_arguments(my_event_handler, {'state': {}, 'payload': {}, 'context': {"target": '11'}, 'session': None, 'ui': None})
+ >>> [{}, {"target": '11'}]
+
+ :param func: the function that will be called
+ :param writer_args: the possible arguments in writer (state, payload, ...)
+ """
+ handler_args = inspect.getfullargspec(func).args
+ func_args = []
+ for arg in handler_args:
+ if arg in writer_args:
+ func_args.append(writer_args[arg])
+
+ return func_args
+
+
+async def _async_wrapper_internal(callable_handler: Callable, arg_values: List[Any]) -> Any:
+ result = await callable_handler(*arg_values)
+ return result
+
+
state_serialiser = StateSerialiser()
initial_state = WriterState()
base_component_tree = core_ui.build_base_component_tree()
diff --git a/src/writer/ss_types.py b/src/writer/ss_types.py
index 090336453..d82d2cc17 100644
--- a/src/writer/ss_types.py
+++ b/src/writer/ss_types.py
@@ -26,7 +26,7 @@ def read(self) -> Any:
ServeMode = Literal["run", "edit"]
MessageType = Literal["sessionInit", "componentUpdate",
"event", "codeUpdate", "codeSave", "checkSession",
- "keepAlive", "stateEnquiry", "setUserinfo"]
+ "keepAlive", "stateEnquiry", "setUserinfo", "stateContent"]
# Web server models
@@ -107,6 +107,10 @@ class StateEnquiryRequest(AppProcessServerRequest):
type: Literal["stateEnquiry"]
+class StateContentRequest(AppProcessServerRequest):
+ type: Literal["stateContent"]
+
+
AppProcessServerRequestPacket = Tuple[int,
Optional[str], AppProcessServerRequest]
@@ -143,6 +147,9 @@ class StateEnquiryResponsePayload(BaseModel):
mutations: Dict[str, Any]
mail: List
+class StateContentResponsePayload(BaseModel):
+ state: Dict[str, Any]
+
class EventResponse(AppProcessServerResponse):
type: Literal["event"]
diff --git a/tests/backend/fixtures/app_runner_fixtures.py b/tests/backend/fixtures/app_runner_fixtures.py
new file mode 100644
index 000000000..45f4a64de
--- /dev/null
+++ b/tests/backend/fixtures/app_runner_fixtures.py
@@ -0,0 +1,34 @@
+from typing import Optional
+
+from writer.app_runner import AppRunner
+from writer.ss_types import InitSessionRequestPayload
+
+FIXED_SESSION_ID = "0000000000000000000000000000000000000000000000000000000000000000" # Compliant session number
+
+async def init_app_session(app_runner: AppRunner,
+ session_id: str = None,
+ cookies: Optional[dict] = None,
+ headers: Optional[dict] = None) -> str:
+ """
+ Fixture to initialize a session and be able to use it in tests.
+
+ If the `session_id` is missing from the parameters, the fixture creates a session with a random ID.
+
+ >>> with setup_app_runner(test_app_dir, 'run') as ar:
+ >>> # When
+ >>> ar.load()
+ >>> session_id = await init_app_session(ar)
+
+ If the `session_id` is missing from the parameters, the fixture creates a session for this identifier.
+
+ >>> session_id = await init_app_session(ar, session_id=FIXED_SESSION_ID)
+ """
+ if cookies is None:
+ cookies = {}
+ if headers is None:
+ headers = {}
+
+ init_session_payload = InitSessionRequestPayload(cookies=cookies, headers=headers, proposedSessionId=session_id)
+ result = await app_runner.init_session(init_session_payload)
+
+ return result.payload.model_dump().get("sessionId")
diff --git a/tests/backend/test_app_runner.py b/tests/backend/test_app_runner.py
index 1cb7bb304..4f3c771c2 100644
--- a/tests/backend/test_app_runner.py
+++ b/tests/backend/test_app_runner.py
@@ -1,6 +1,5 @@
import asyncio
-import contextlib
import threading
import pytest
@@ -12,21 +11,10 @@
WriterEvent,
)
+from backend.fixtures.app_runner_fixtures import init_app_session
from tests.backend import test_app_dir
-@pytest.fixture
-def setup_app_runner():
- @contextlib.contextmanager
- def _manage_launch_args(app_dir, app_command):
- ar = AppRunner(app_dir, app_command)
- try:
- yield ar
- finally:
- ar.shut_down()
- return _manage_launch_args
-
-
class TestAppRunner:
numberinput_instance_path = [
@@ -60,7 +48,7 @@ def test_init_wrong_mode(self) -> None:
@pytest.mark.asyncio
@pytest.mark.usefixtures("setup_app_runner")
async def test_pre_session(self, setup_app_runner) -> None:
- with setup_app_runner(test_app_dir, "run") as ar:
+ with setup_app_runner(test_app_dir, "run", load = True) as ar:
er = EventRequest(
type="event",
payload=WriterEvent(
@@ -71,26 +59,14 @@ async def test_pre_session(self, setup_app_runner) -> None:
}
)
)
- ar.load()
r = await ar.dispatch_message(None, er)
assert r.status == "error"
@pytest.mark.asyncio
@pytest.mark.usefixtures("setup_app_runner")
async def test_valid_session_invalid_event(self, setup_app_runner) -> None:
- with setup_app_runner(test_app_dir, "run") as ar:
- ar.load()
- si = InitSessionRequest(
- type="sessionInit",
- payload=InitSessionRequestPayload(
- cookies={},
- headers={},
- proposedSessionId=self.proposed_session_id
- )
- )
- sres = await ar.dispatch_message(None, si)
- assert sres.status == "ok"
- assert sres.payload.model_dump().get("sessionId") == self.proposed_session_id
+ with setup_app_runner(test_app_dir, "run", load = True) as ar:
+ await init_app_session(ar, session_id=self.proposed_session_id)
er = EventRequest(type="event", payload=WriterEvent(
type="virus",
instancePath=self.numberinput_instance_path,
@@ -104,19 +80,8 @@ async def test_valid_session_invalid_event(self, setup_app_runner) -> None:
@pytest.mark.asyncio
@pytest.mark.usefixtures("setup_app_runner")
async def test_valid_event(self, setup_app_runner) -> None:
- with setup_app_runner(test_app_dir, "run") as ar:
- ar.load()
- si = InitSessionRequest(
- type="sessionInit",
- payload=InitSessionRequestPayload(
- cookies={},
- headers={},
- proposedSessionId=self.proposed_session_id
- )
- )
- sres = await ar.dispatch_message(None, si)
- assert sres.status == "ok"
- assert sres.payload.model_dump().get("sessionId") == self.proposed_session_id
+ with setup_app_runner(test_app_dir, "run", load = True) as ar:
+ await init_app_session(ar, session_id=self.proposed_session_id)
ev_req = EventRequest(type="event", payload=WriterEvent(
type="wf-number-change",
instancePath=self.numberinput_instance_path,
@@ -133,20 +98,8 @@ async def test_valid_event(self, setup_app_runner) -> None:
@pytest.mark.asyncio
@pytest.mark.usefixtures("setup_app_runner")
async def test_async_handler(self, setup_app_runner) -> None:
- with setup_app_runner(test_app_dir, "run") as ar:
- ar.load()
- si = InitSessionRequest(
- type="sessionInit",
- payload=InitSessionRequestPayload(
- cookies={},
- headers={},
- proposedSessionId=self.proposed_session_id
- )
- )
- sres = await ar.dispatch_message(None, si)
- assert sres.status == "ok"
- assert sres.payload.model_dump().get("sessionId") == self.proposed_session_id
-
+ with setup_app_runner(test_app_dir, "run", load = True) as ar:
+ await init_app_session(ar, session_id=self.proposed_session_id)
# Firing an event to bypass "initial" state mutations
ev_req = EventRequest(type="event", payload=WriterEvent(
type="wf-number-change",
@@ -167,17 +120,8 @@ async def test_async_handler(self, setup_app_runner) -> None:
@pytest.mark.asyncio
@pytest.mark.usefixtures("setup_app_runner")
async def test_bad_event_handler(self, setup_app_runner) -> None:
- with setup_app_runner(test_app_dir, "run") as ar:
- ar.load()
- si = InitSessionRequest(
- type="sessionInit",
- payload=InitSessionRequestPayload(
- cookies={},
- headers={},
- proposedSessionId=self.proposed_session_id
- )
- )
- await ar.dispatch_message(None, si)
+ with setup_app_runner(test_app_dir, "run", load = True) as ar:
+ await init_app_session(ar, session_id=self.proposed_session_id)
bad_button_instance_path = [
{"componentId": "root", "instanceNumber": 0},
{"componentId": "28a2212b-bc58-4398-8a72-2554e5296490", "instanceNumber": 0},
@@ -238,3 +182,25 @@ async def test_code_update(self, setup_app_runner) -> None:
mail = list(si_res.payload.model_dump().get("mail"))
assert mail[0].get("payload").get("message") == "188542\n"
+
+ @pytest.mark.asyncio
+ @pytest.mark.usefixtures("setup_app_runner")
+ async def test_handle_event_should_return_result_of_event_handler_execution(self, setup_app_runner):
+ """
+ Tests that an event handler should result the result of function execution in
+ payload.result['result'].
+ """
+ # Given
+ ar: AppRunner
+ with setup_app_runner(test_app_dir, 'run', load=True) as ar:
+ session_id = await init_app_session(ar)
+
+ # When
+ res = await ar.handle_event(session_id, WriterEvent(
+ type='click',
+ instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}],
+ payload={})
+ )
+
+ # Then
+ assert res.payload.result['result'] is not None
diff --git a/tests/backend/test_middleware.py b/tests/backend/test_middleware.py
new file mode 100644
index 000000000..771087983
--- /dev/null
+++ b/tests/backend/test_middleware.py
@@ -0,0 +1,111 @@
+import pytest
+from writer.app_runner import AppRunner
+from writer.ss_types import WriterEvent
+
+from backend import test_app_dir
+from backend.fixtures.app_runner_fixtures import init_app_session
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("setup_app_runner")
+async def test_middleware_should_apply_on_every_event_handler_invocation(setup_app_runner):
+ """
+ Tests that a middleware executes before an event
+ """
+ # Given
+ ar: AppRunner
+ with setup_app_runner(test_app_dir, 'run', load=True) as ar:
+ session_id = await init_app_session(ar)
+
+ # When
+ res = await ar.handle_event(session_id, WriterEvent(
+ type='click',
+ instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}],
+ payload={})
+ )
+ res.payload.mutations['+counter_middleware'] = 1
+
+ # Then
+ full_state = await ar.handle_state_content(session_id)
+ assert full_state.payload.state['counter'] == 3
+ assert full_state.payload.state['counter_middleware'] == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("setup_app_runner")
+async def test_middleware_should_apply_on_multi_event_handler_invocation(setup_app_runner):
+ """
+ Tests that a middleware executes twice after 2 different events
+ """
+ # Given
+ ar: AppRunner
+ with setup_app_runner(test_app_dir, 'run', load=True) as ar:
+ session_id = await init_app_session(ar)
+
+ # When
+ res = await ar.handle_event(session_id, WriterEvent(
+ type='click',
+ instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}],
+ payload={})
+ )
+ res.payload.mutations['+counter_middleware'] = 1
+
+ res = await ar.handle_event(session_id, WriterEvent(
+ type='wf-option-change',
+ instancePath=[{'componentId': '2e46c38b-6405-42ad-ad9c-d237a53a7d30', 'instanceNumber': 0}],
+ payload='ar')
+ )
+ res.payload.mutations['+counter_middleware'] = 2
+
+ # Then
+ full_state = await ar.handle_state_content(session_id)
+ assert full_state.payload.state['counter_middleware'] == 2
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("setup_app_runner")
+async def test_middleware_should_apply_after_event_handler_invocation(setup_app_runner):
+ """
+ Test that a middleware executes after an event
+ """
+ # Given
+ ar: AppRunner
+ with setup_app_runner(test_app_dir, 'run', load=True) as ar:
+ session_id = await init_app_session(ar)
+
+ # When
+ res = await ar.handle_event(session_id, WriterEvent(
+ type='click',
+ instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}],
+ payload={})
+ )
+ res.payload.mutations['+counter_post_middleware'] = 1
+
+ # Then
+ full_state = await ar.handle_state_content(session_id)
+ assert full_state.payload.state['counter'] == 3
+ assert full_state.payload.state['counter_post_middleware'] == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("setup_app_runner")
+async def test_middleware_should_apply_on_middleware_without_yield(setup_app_runner):
+ """
+ Test that a middleware executes after an event
+ """
+ # Given
+ ar: AppRunner
+ with setup_app_runner(test_app_dir, 'run', load=True) as ar:
+ session_id = await init_app_session(ar)
+
+ # When
+ res = await ar.handle_event(session_id, WriterEvent(
+ type='click',
+ instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}],
+ payload={})
+ )
+ res.payload.mutations['+counter_middleware_without_yield'] = 1
+
+ # Then
+ full_state = await ar.handle_state_content(session_id)
+ assert full_state.payload.state['counter_middleware_without_yield'] == 1
diff --git a/tests/backend/testapp/main.py b/tests/backend/testapp/main.py
index 171ac01b0..53df2ac09 100644
--- a/tests/backend/testapp/main.py
+++ b/tests/backend/testapp/main.py
@@ -10,6 +10,20 @@
import writer.core
+@wf.middleware()
+def my_middleware(state):
+ state['counter_middleware'] += 1
+ yield
+
+@wf.middleware()
+def no_yield_middleware(state):
+ state['counter_middleware_without_yield'] += 1
+
+@wf.middleware()
+def post_middleware(state):
+ yield
+ state['counter_post_middleware'] += 1
+
@wf.session_verifier
def check_headers(headers):
if headers.get("x-fail") is not None:
@@ -43,6 +57,7 @@ def update_cities(state, payload):
def increment(state):
state["counter"] += 1*my_var
+ return 1
# EVENT HANDLERS
@@ -220,6 +235,9 @@ def _get_altair_chart():
"min_weight": 300,
},
"counter": 0,
+ "counter_middleware": 0,
+ "counter_post_middleware": 0,
+ "counter_middleware_without_yield": 0,
"metrics": {},
"b": {
"pet_count": 8
diff --git a/tests/conftest.py b/tests/conftest.py
index 21036f36b..f75f0b495 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,10 @@
+import contextlib
+from typing import Literal
+
+import pytest
+from writer.app_runner import AppRunner
+
+
def pytest_collection_modifyitems(config, items):
if not config.getoption("--full-run"):
deselected = []
@@ -15,3 +22,32 @@ def pytest_addoption(parser):
parser.addoption(
"--full-run", action="store_true", default=False, help="Include explicit-marked tests in the run (those are exluded from regular runs)"
)
+
+@pytest.fixture
+def setup_app_runner():
+ @contextlib.contextmanager
+ def _manage_launch_args(app_dir: str, app_command: Literal["run", "edit"], load: bool = False):
+ """
+ Fixture to instantiate a writer application for testing.
+
+ >>> with setup_app_runner("app_dir", "run", load=True) as ar:
+ >>> pass
+
+ When the load flag is True, the application is loaded.
+
+ >>> with setup_app_runner("app_dir", "run", load=True) as ar:
+ >>> pass
+
+ :param app_dir: the folder that contains the application
+ :param app_command: the execution mode of the application, either edit or run
+ :param load: load the application if True
+ """
+ ar = AppRunner(app_dir, app_command)
+ try:
+ if load is True:
+ ar.load()
+
+ yield ar
+ finally:
+ ar.shut_down()
+ return _manage_launch_args