Skip to content

Commit

Permalink
use middleware to intercept event handlers (for database session mana…
Browse files Browse the repository at this point in the history
…gement, ...)

* feat: implement middleware execution
  • Loading branch information
FabienArcellier committed Jun 19, 2024
1 parent 109e723 commit 894607a
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 19 deletions.
69 changes: 53 additions & 16 deletions src/writer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,28 @@ def __init__(self):
def register(self, middleware: Callable):
self.registry.append(middleware)

def executors(self, writer_args: dict) -> List[Generator]:
"""
Retrieves middleware ready to be executed in the form of an iterator.
"""
executors = []
for m in self.registry:
handler_args = build_writer_func_arguments(m, writer_args)
def wrapper(middleware):
it = middleware(*handler_args)
try:
next(it)
yield
next(it)
except StopIteration:
# This part manages the end of the middleware which will throw a StopIteration exception
# because there is only one yield in the middleware.
yield

executors.append(wrapper(m))

return executors

class EventHandlerRegistry:
"""
Maps functions registered as event handlers from the user app's core
Expand Down Expand Up @@ -1331,7 +1353,7 @@ def _call_handler_callable(self, event_type, target_component, instance_path, pa

# Preparation of arguments
from writer.ui import WriterUIManager
all_args = {
writer_args = {
'state': self.session_state,
'payload': payload,
'context': self.evaluator.get_context_data(instance_path),
Expand All @@ -1349,14 +1371,16 @@ def _call_handler_callable(self, event_type, target_component, instance_path, pa
captured_stdout = None
with core_ui.use_component_tree(self.session.session_component_tree), \
contextlib.redirect_stdout(io.StringIO()) as f:
# middlewares = middleware.all()
# for m in middlewares:
# middleware_request_executor(m, all_args)
middlewares_executors = current_app_process.middleware_registry.executors(writer_args)
# before executing the event
for me in middlewares_executors:
next(me)

result = handler_executor(callable_handler, all_args)
result = handler_executor(callable_handler, writer_args)

# for m in middlewares:
# middleware_response_executor(m, all_args)
# after executing the event
for me in middlewares_executors:
next(me)

captured_stdout = f.getvalue()

Expand Down Expand Up @@ -1389,7 +1413,7 @@ def handle(self, ev: WriterEvent) -> WriterEventResult:

self._handle_binding(ev.type, target_component, instance_path, ev.payload)
result = self._call_handler_callable(ev.type, target_component, instance_path, ev.payload)
except BaseException:
except BaseException as e:
ok = False
self.session_state.add_notification("error", "Runtime Error", f"An error occurred when processing event '{ ev.type }'.",
)
Expand Down Expand Up @@ -1487,7 +1511,7 @@ def reset_base_component_tree() -> None:
base_component_tree = core_ui.build_base_component_tree()


def handler_executor(callable_handler: Callable, args: dict) -> Any:
def handler_executor(callable_handler: Callable, writer_args: dict) -> Any:
"""
Runs a handler based on its signature.
Expand All @@ -1503,21 +1527,34 @@ def handler_executor(callable_handler: Callable, args: dict) -> Any:
if (not callable(callable_handler) and not is_async_handler):
raise ValueError("Invalid handler. The handler isn't a callable object.")

handler_args = inspect.getfullargspec(callable_handler).args
arg_values = []
for arg in handler_args:
if arg in args:
arg_values.append(args[arg])
handler_args = build_writer_func_arguments(callable_handler, writer_args)

if is_async_handler:
async_wrapper = _async_wrapper_internal(callable_handler, arg_values)
async_wrapper = _async_wrapper_internal(callable_handler, handler_args)
result = asyncio.run(async_wrapper)
else:
result = callable_handler(*arg_values)
result = callable_handler(*handler_args)

return result


def build_writer_func_arguments(func: Callable, writer_args: dict) -> List[Any]:
"""
Constructs the list of arguments based on the signature of the function
which can be a handler or middleware.
:param func: the function that will be called
:param writer_args: the possible arguments in writer (state, payload, ...)
"""
handler_args = inspect.getfullargspec(func).args
func_args = []
for arg in handler_args:
if arg in writer_args:
func_args.append(writer_args[arg])

return func_args


async def _async_wrapper_internal(callable_handler: Callable, arg_values: List[Any]) -> Any:
result = await callable_handler(*arg_values)
return result
Expand Down
64 changes: 61 additions & 3 deletions tests/backend/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,79 @@
@pytest.mark.asyncio
@pytest.mark.usefixtures("setup_app_runner")
async def test_middleware_should_apply_on_every_event_handler_invocation(setup_app_runner):
# 08/06/2024 : the code that executes middleware is not yet implemented
pytest.skip('this test is not implemented')
"""
Tests that a middleware executes before an event
"""
# Given
ar: AppRunner
with setup_app_runner(test_app_dir, 'run', load=True) as ar:
session_id = await init_app_session(ar)

# When
await ar.handle_event(session_id, WriterEvent(
res = await ar.handle_event(session_id, WriterEvent(
type='click',
instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}],
payload={})
)
res.payload.mutations['+counter_middleware'] = 1

# Then
full_state = await ar.handle_state_content(session_id)
assert full_state.payload.state['counter'] == 3
assert full_state.payload.state['counter_middleware'] == 1


@pytest.mark.asyncio
@pytest.mark.usefixtures("setup_app_runner")
async def test_middleware_should_apply_on_multi_event_handler_invocation(setup_app_runner):
"""
Tests that a middleware executes twice after 2 different events
"""
# Given
ar: AppRunner
with setup_app_runner(test_app_dir, 'run', load=True) as ar:
session_id = await init_app_session(ar)

# When
res = await ar.handle_event(session_id, WriterEvent(
type='click',
instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}],
payload={})
)
res.payload.mutations['+counter_middleware'] = 1

res = await ar.handle_event(session_id, WriterEvent(
type='wf-option-change',
instancePath=[{'componentId': '2e46c38b-6405-42ad-ad9c-d237a53a7d30', 'instanceNumber': 0}],
payload='ar')
)
res.payload.mutations['+counter_middleware'] = 2

# Then
full_state = await ar.handle_state_content(session_id)
assert full_state.payload.state['counter_middleware'] == 2


@pytest.mark.asyncio
@pytest.mark.usefixtures("setup_app_runner")
async def test_middleware_should_apply_after_event_handler_invocation(setup_app_runner):
"""
Test that a middleware executes after an event
"""
# Given
ar: AppRunner
with setup_app_runner(test_app_dir, 'run', load=True) as ar:
session_id = await init_app_session(ar)

# When
res = await ar.handle_event(session_id, WriterEvent(
type='click',
instancePath=[{'componentId': '5c0df6e8-4dd8-4485-a244-8e9e7f4b4675', 'instanceNumber': 0}],
payload={})
)
res.payload.mutations['+counter_post_middleware'] = 1

# Then
full_state = await ar.handle_state_content(session_id)
assert full_state.payload.state['counter'] == 3
assert full_state.payload.state['counter_post_middleware'] == 1
6 changes: 6 additions & 0 deletions tests/backend/testapp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ def my_middleware(state):
state['counter_middleware'] += 1
yield

@wf.middleware()
def post_middleware(state):
yield
state['counter_post_middleware'] += 1

@wf.session_verifier
def check_headers(headers):
if headers.get("x-fail") is not None:
Expand Down Expand Up @@ -226,6 +231,7 @@ def _get_altair_chart():
},
"counter": 0,
"counter_middleware": 0,
"counter_post_middleware": 0,
"metrics": {},
"b": {
"pet_count": 8
Expand Down

0 comments on commit 894607a

Please sign in to comment.