Skip to content

Commit

Permalink
use middleware to intercept event handlers (for database session mana…
Browse files Browse the repository at this point in the history
…gement, ...)

* feat: use pattern to handle exceptions
* feat: allow middleware without yield
  • Loading branch information
FabienArcellier committed Jun 24, 2024
1 parent 410cc1b commit 01712b7
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 57 deletions.
1 change: 0 additions & 1 deletion src/writer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def middleware():
>>> state['processing'] -= 1
"""
def inner(func):
# enregistre la fonction en tant que middlewares
_app_process = get_app_process()
_app_process.middleware_registry.register(func)

Expand Down
96 changes: 65 additions & 31 deletions src/writer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,36 +743,58 @@ def call_frontend_function(self, module_key: str, function_name: str, args: List
"args": args
})

class MiddlewareExecutor():
"""
A MiddlewareExecutor executes middleware in a controlled context. It allows writer framework
to manage different implementations of middleware.
Case 1 : A middleware is a generator, then run before and after code
>>> @wf.middleware()
>>> def my_middleware():
>>> print("before event handler")
>>> yield()
>>> print("after event handler")
Case 2 : A middleware is just a function, then run the function before
>>> @wf.middleware()
>>> def my_middleware():
>>> print("before event handler")
"""

def __init__(self, middleware: Callable):
self.middleware = middleware

@contextlib.contextmanager
def execute(self, args: dict):
middleware_args = build_writer_func_arguments(self.middleware, args)
it = self.middleware(*middleware_args)
try:
yield from it
except StopIteration:
yield
except TypeError:
yield


class MiddlewareRegistry:

def __init__(self):
self.registry = []
self.registry: List[MiddlewareExecutor] = []

def register(self, middleware: Callable):
self.registry.append(middleware)
me = MiddlewareExecutor(middleware)
self.registry.append(me)

def executors(self, writer_args: dict) -> List[Generator]:
"""
Retrieves middleware ready to be executed in the form of an iterator.
def executors(self) -> List[MiddlewareExecutor]:
"""
executors = []
for m in self.registry:
handler_args = build_writer_func_arguments(m, writer_args)
def wrapper(middleware):
it = middleware(*handler_args)
try:
next(it)
yield
next(it)
except StopIteration:
# This part manages the end of the middleware which will throw a StopIteration exception
# because there is only one yield in the middleware.
yield

executors.append(wrapper(m))
Retrieves middlewares prepared for execution
return executors
>>> executors = middleware_registry.executors()
>>> result = handle_with_middlewares_executor(executors, lambda state: pass, {'state': {}, 'payload': {}})
"""
return self.registry

class EventHandlerRegistry:
"""
Expand Down Expand Up @@ -1382,17 +1404,9 @@ def _call_handler_callable(
captured_stdout = None
with core_ui.use_component_tree(self.session.session_component_tree), \
contextlib.redirect_stdout(io.StringIO()) as f:
middlewares_executors = current_app_process.middleware_registry.executors(writer_args)
# before executing the event
for me in middlewares_executors:
next(me)

result = handler_executor(callable_handler, writer_args)

# after executing the event
for me in middlewares_executors:
next(me)
middlewares_executors = current_app_process.middleware_registry.executors()

result = handle_with_middlewares_executor(middlewares_executors, callable_handler, writer_args)
captured_stdout = f.getvalue()

if captured_stdout:
Expand Down Expand Up @@ -1548,6 +1562,26 @@ def handler_executor(callable_handler: Callable, writer_args: dict) -> Any:

return result

def handle_with_middlewares_executor(middlewares_executors: List[MiddlewareExecutor], callable_handler: Callable, writer_args: dict) -> Any:
"""
Runs the middlewares then the handler. This function allows you to manage exceptions that are triggered in middleware
:param middlewares_executors: The list of middleware to run
:param callable_handler: The target handler
>>> @wf.middleware()
>>> def my_middleware(state, payload, context, session, ui):
>>> yield
>>> executor = MiddlewareExecutor(my_middleware, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None})
>>> handle_with_middlewares_executor([executor], my_handler, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None}
"""
if len(middlewares_executors) == 0:
return handler_executor(callable_handler, writer_args)
else:
executor = middlewares_executors[0]
with executor.execute(writer_args):
return handle_with_middlewares_executor(middlewares_executors[1:], callable_handler, writer_args)

def build_writer_func_arguments(func: Callable, writer_args: dict) -> List[Any]:
"""
Expand Down
21 changes: 0 additions & 21 deletions src/writer/tests.py

This file was deleted.

8 changes: 4 additions & 4 deletions tests/backend/fixtures/app_runner_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@
from writer.app_runner import AppRunner
from writer.ss_types import InitSessionRequestPayload

FIXED_SESSION_ID = "0000000000000000000000000000000000000000000000000000000000000000"
FIXED_SESSION_ID = "0000000000000000000000000000000000000000000000000000000000000000" # Compliant session number

async def init_app_session(app_runner: AppRunner,
session_id: str = None,
cookies: Optional[dict] = None,
headers: Optional[dict] = None) -> str:
"""
Creates a session in writer framework for automatic testing
Fixture to initialize a session and be able to use it in tests.
Creates a session with a random ID.
If the `session_id` is missing from the parameters, the fixture creates a session with a random ID.
>>> with setup_app_runner(test_app_dir, 'run') as ar:
>>> # When
>>> ar.load()
>>> session_id = await init_app_session(ar)
Creates a session with a fixed identifier.
If the `session_id` is missing from the parameters, the fixture creates a session for this identifier.
>>> session_id = await init_app_session(ar, session_id=FIXED_SESSION_ID)
"""
Expand Down
22 changes: 22 additions & 0 deletions tests/backend/test_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,25 @@ async def test_code_update(self, setup_app_runner) -> None:
mail = list(si_res.payload.model_dump().get("mail"))

assert mail[0].get("payload").get("message") == "188542\n"

@pytest.mark.asyncio
@pytest.mark.usefixtures("setup_app_runner")
async def test_handle_event_should_return_result_of_event_handler_execution(self, setup_app_runner):
"""
Tests that an event handler should result the result of function execution in
payload.result['result'].
"""
# Given
ar: AppRunner
with setup_app_runner(test_app_dir, 'run', load=True) as ar:
session_id = await init_app_session(ar)

# When
res = await ar.handle_event(session_id, WriterEvent(
type='click',
instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}],
payload={})
)

# Then
assert res.payload.result['result'] is not None
24 changes: 24 additions & 0 deletions tests/backend/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,27 @@ async def test_middleware_should_apply_after_event_handler_invocation(setup_app_
full_state = await ar.handle_state_content(session_id)
assert full_state.payload.state['counter'] == 3
assert full_state.payload.state['counter_post_middleware'] == 1


@pytest.mark.asyncio
@pytest.mark.usefixtures("setup_app_runner")
async def test_middleware_should_apply_on_middleware_without_yield(setup_app_runner):
"""
Test that a middleware executes after an event
"""
# Given
ar: AppRunner
with setup_app_runner(test_app_dir, 'run', load=True) as ar:
session_id = await init_app_session(ar)

# When
res = await ar.handle_event(session_id, WriterEvent(
type='click',
instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}],
payload={})
)
res.payload.mutations['+counter_middleware_without_yield'] = 1

# Then
full_state = await ar.handle_state_content(session_id)
assert full_state.payload.state['counter_middleware_without_yield'] == 1
6 changes: 6 additions & 0 deletions tests/backend/testapp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def my_middleware(state):
state['counter_middleware'] += 1
yield

@wf.middleware()
def no_yield_middleware(state):
state['counter_middleware_without_yield'] += 1

@wf.middleware()
def post_middleware(state):
yield
Expand Down Expand Up @@ -53,6 +57,7 @@ def update_cities(state, payload):

def increment(state):
state["counter"] += 1*my_var
return 1

# EVENT HANDLERS

Expand Down Expand Up @@ -232,6 +237,7 @@ def _get_altair_chart():
"counter": 0,
"counter_middleware": 0,
"counter_post_middleware": 0,
"counter_middleware_without_yield": 0,
"metrics": {},
"b": {
"pet_count": 8
Expand Down

0 comments on commit 01712b7

Please sign in to comment.