From 642d6266fe8e3c4db483be593f124682122eab0c Mon Sep 17 00:00:00 2001 From: Ramiro Medina <64783088+ramedina86@users.noreply.github.com> Date: Wed, 4 Dec 2024 14:33:08 +0000 Subject: [PATCH] feat: Global events --- .../workflows/WorkflowsWorkflow.vue | 4 +- src/ui/src/core/index.ts | 3 + src/writer/core.py | 120 ++++++++++-------- src/writer/serve.py | 30 ++++- src/writer/ss_types.py | 4 +- tests/backend/test_app_runner.py | 31 +++++ tests/backend/testapp/main.py | 5 + 7 files changed, 135 insertions(+), 62 deletions(-) diff --git a/src/ui/src/components/workflows/WorkflowsWorkflow.vue b/src/ui/src/components/workflows/WorkflowsWorkflow.vue index acf754999..6d2cb2166 100644 --- a/src/ui/src/components/workflows/WorkflowsWorkflow.vue +++ b/src/ui/src/components/workflows/WorkflowsWorkflow.vue @@ -142,7 +142,6 @@ import { useDragDropComponent } from "@/builder/useDragDropComponent"; import injectionKeys from "@/injectionKeys"; const renderProxiedComponent = inject(injectionKeys.renderProxiedComponent); -const instancePath = inject(injectionKeys.instancePath); const workflowComponentId = inject(injectionKeys.componentId); const rootEl: Ref = ref(null); @@ -224,9 +223,10 @@ async function handleRun() { callback: () => { isRunning.value = false; }, + handler: `$runWorkflowById_${workflowComponentId}`, }, }), - instancePath, + null, false, ); } diff --git a/src/ui/src/core/index.ts b/src/ui/src/core/index.ts index 22ed6ebc6..8368d99ea 100644 --- a/src/ui/src/core/index.ts +++ b/src/ui/src/core/index.ts @@ -316,14 +316,17 @@ export function generateCore() { ? getPayloadFromEvent(event) : null; let callback: Function; + let handler: string; if (event instanceof CustomEvent) { callback = event.detail?.callback; + handler = event.detail?.handler; } const messagePayload = async () => ({ type: event.type, instancePath, + handler, payload: await eventPayload, }); diff --git a/src/writer/core.py b/src/writer/core.py index d62b22f86..2e33cdf30 100644 --- a/src/writer/core.py +++ b/src/writer/core.py @@ -1591,7 +1591,7 @@ def _handle_binding(self, event_type, target_component, instance_path, payload) return self.evaluator.set_state(binding["stateRef"], instance_path, payload) - def _get_workflow_callable(self, workflow_key: Optional[str], workflow_id: Optional[str]): + def _get_workflow_callable(self, workflow_key: Optional[str] = None, workflow_id: Optional[str] = None): def fn(payload, context, session): execution_environment = { "payload": payload, @@ -1604,96 +1604,104 @@ def fn(payload, context, session): self.workflow_runner.run_workflow(workflow_id, execution_environment, "Workflow execution triggered on demand") return fn - def _get_handler_callable(self, target_component: Component, event_type: str) -> Optional[Callable]: - if event_type == "wf-builtin-run" and Config.mode == "edit": - return self._get_workflow_callable(None, target_component.id) - - if not target_component.handlers: - return None - handler = target_component.handlers.get(event_type) - if not handler: - return None - + def _get_handler_callable(self, handler: str) -> Optional[Callable]: if handler.startswith("$runWorkflow_"): workflow_key = handler[13:] - return self._get_workflow_callable(workflow_key, None) + return self._get_workflow_callable(workflow_key=workflow_key) + + if handler.startswith("$runWorkflowById_"): + workflow_id = handler[17:] + return self._get_workflow_callable(workflow_id=workflow_id) current_app_process = get_app_process() handler_registry = current_app_process.handler_registry callable_handler = handler_registry.find_handler_callable(handler) return callable_handler + def _get_calling_arguments(self, ev: WriterEvent, instance_path: Optional[InstancePath] = None): + context_data = self.evaluator.get_context_data(instance_path) if instance_path else {} + context_data["event"] = ev.type + return { + "state": self.session_state, + "payload": ev.payload, + "context": context_data, + "session":_event_handler_session_info(), + "ui": _event_handler_ui_manager() + } def _call_handler_callable( self, - event_type: str, - target_component: Component, - instance_path: InstancePath, - payload: Any - ) -> Any: - - handler_callable = self._get_handler_callable(target_component, event_type) - if not handler_callable: - return - - # Preparation of arguments - context_data = self.evaluator.get_context_data(instance_path) - context_data['event'] = event_type - writer_args = { - 'state': self.session_state, - 'payload': payload, - 'context': context_data, - 'session':_event_handler_session_info(), - 'ui': _event_handler_ui_manager() - } - - # Invocation of handler + handler_callable: Callable, + calling_arguments: Dict + ) -> Any: current_app_process = get_app_process() result = None captured_stdout = None with core_ui.use_component_tree(self.session.session_component_tree), \ contextlib.redirect_stdout(io.StringIO()) as f: middlewares_executors = current_app_process.middleware_registry.executors() - result = EventHandlerExecutor.invoke_with_middlewares(middlewares_executors, handler_callable, writer_args) + result = EventHandlerExecutor.invoke_with_middlewares(middlewares_executors, handler_callable, calling_arguments) captured_stdout = f.getvalue() if captured_stdout: - self.session_state.add_log_entry( - "info", - "Stdout message", - captured_stdout - ) + self.session_state.add_log_entry("info", "Stdout message", captured_stdout) return result - def handle(self, ev: WriterEvent) -> WriterEventResult: - ok = True - + def _deserialize(self, ev: WriterEvent): try: self.deser.transform(ev) - except BaseException: - ok = False + except BaseException as e: self.session_state.add_notification( - "error", "Error", f"A deserialisation error occurred when handling event '{ ev.type }'.") - self.session_state.add_log_entry("error", "Deserialisation Failed", - f"The data sent might be corrupt. A runtime exception was raised when deserialising event '{ ev.type }'.", traceback.format_exc()) - - result = None + "error", "Error", f"A deserialization error occurred when handling event '{ ev.type }'.") + self.session_state.add_log_entry("error", "Deserialization Failed", + f"The data sent might be corrupt. A runtime exception was raised when deserializing event '{ ev.type }'.", traceback.format_exc()) + raise e + + def _handle_global_event(self, ev: WriterEvent): + if not ev.isSafe: + error_message = "Attempted executing a global event in an unsafe context." + self.session_state.add_log_entry("error", "Forbidden operation", error_message, traceback.format_exc()) + raise PermissionError(error_message) + if not ev.handler: + raise ValueError("Handler not specified when attempting to execute global event.") + handler_callable = self._get_handler_callable(ev.handler) + calling_arguments = self._get_calling_arguments(ev, instance_path=None) + return self._call_handler_callable(handler_callable, calling_arguments) + + def _handle_component_event(self, ev: WriterEvent): + instance_path = ev.instancePath try: - instance_path = ev.instancePath target_id = instance_path[-1]["componentId"] target_component = cast(Component, self.session_component_tree.get_component(target_id)) - self._handle_binding(ev.type, target_component, instance_path, ev.payload) - result = self._call_handler_callable(ev.type, target_component, instance_path, ev.payload) - except BaseException: - ok = False + if not target_component.handlers: + return None + handler = target_component.handlers.get(ev.type) + if not handler: + return None + handler_callable = self._get_handler_callable(handler) + calling_arguments = self._get_calling_arguments(ev, instance_path) + return self._call_handler_callable(handler_callable, calling_arguments) + except BaseException as e: self.session_state.add_notification("error", "Runtime Error", f"An error occurred when processing event '{ ev.type }'.", ) self.session_state.add_log_entry("error", "Runtime Exception", f"A runtime exception was raised when processing event '{ ev.type }'.", traceback.format_exc()) + raise e + + def handle(self, ev: WriterEvent) -> WriterEventResult: + try: + if not ev.isSafe and ev.handler is not None: + raise PermissionError("Unexpected handler set on event.") + self._deserialize(ev) + if not ev.instancePath: + return {"ok": True, "result": self._handle_global_event(ev)} + else: + return {"ok": True, "result": self._handle_component_event(ev)} + except BaseException as e: + return {"ok": False, "result": str(e)} - return {"ok": ok, "result": result} class EventHandlerExecutor: diff --git a/src/writer/serve.py b/src/writer/serve.py index 67ddf9580..459069d1b 100644 --- a/src/writer/serve.py +++ b/src/writer/serve.py @@ -48,9 +48,26 @@ logging.getLogger().setLevel(logging.INFO) +class JobVault: + + def __init__(self): + self.counter = 0 + self.vault = {} + + def generate_job_id(self): + self.counter += 1 + return self.counter + + def set(self, job_id: str, value: Any): + self.vault[job_id] = value + + def get(self, job_id: str): + return self.vault.get(job_id) + class WriterState(typing.Protocol): app_runner: AppRunner writer_app: bool + job_vault: JobVault is_server_static_mounted: bool meta: Union[Dict[str, Any], Callable[[], Dict[str, Any]]] # meta tags for SEO opengraph_tags: Union[Dict[str, Any], Callable[[], Dict[str, Any]]] # opengraph tags for social networks integration (facebook, discord) @@ -122,6 +139,7 @@ async def lifespan(asgi_app: FastAPI): """ app.state.writer_app = True app.state.app_runner = app_runner + app.state.job_vault = JobVault() def _get_extension_paths() -> List[str]: extensions_path = pathlib.Path(user_app_path) / "extensions" @@ -310,13 +328,19 @@ async def _handle_incoming_event(websocket: WebSocket, session_id: str, req_mess trackingId=req_message.trackingId, payload=None ) + + # Allows for global events if in edit mode (such as "Run workflow" for previewing a workflow) + + is_safe = serve_mode == "edit" res_payload: Optional[Dict[str, Any]] = None apsr: Optional[AppProcessServerResponse] = None apsr = await app_runner.handle_event( session_id, WriterEvent( - type=req_message.payload["type"], - instancePath=req_message.payload["instancePath"], - payload=req_message.payload["payload"] + type=req_message.payload.get("type"), + handler=req_message.payload.get("handler"), + isSafe=is_safe, + instancePath=req_message.payload.get("instancePath"), + payload=req_message.payload.get("payload") )) if apsr is not None and apsr.payload is not None: res_payload = typing.cast( diff --git a/src/writer/ss_types.py b/src/writer/ss_types.py index 0cad40f6b..3bd05ba76 100644 --- a/src/writer/ss_types.py +++ b/src/writer/ss_types.py @@ -101,7 +101,9 @@ class ComponentUpdateRequest(AppProcessServerRequest): class WriterEvent(BaseModel): type: str - instancePath: InstancePath + isSafe: Optional[bool] = False + handler: Optional[str] = None + instancePath: Optional[InstancePath] = None payload: Optional[Any] = None diff --git a/tests/backend/test_app_runner.py b/tests/backend/test_app_runner.py index 47ca54763..0a92429ce 100644 --- a/tests/backend/test_app_runner.py +++ b/tests/backend/test_app_runner.py @@ -209,6 +209,37 @@ async def test_bad_event_handler(self, setup_app_runner) -> None: assert ev_res.status == "ok" assert not ev_res.payload.result.get("ok") + @pytest.mark.asyncio + @pytest.mark.usefixtures("setup_app_runner") + async def test_unsafe_event(self, setup_app_runner) -> None: + with setup_app_runner(test_app_dir, "run", load = True) as ar: + await init_app_session(ar, session_id=self.proposed_session_id) + ev_req = EventRequest(type="event", payload=WriterEvent( + type="wf-built-run", + handler="nineninenine", + instancePath=None, + payload=None + )) + ev_res = await ar.dispatch_message(self.proposed_session_id, ev_req) + assert ev_res.status == "ok" + assert not ev_res.payload.result.get("ok") + + @pytest.mark.asyncio + @pytest.mark.usefixtures("setup_app_runner") + async def test_safe_global_event(self, setup_app_runner) -> None: + with setup_app_runner(test_app_dir, "run", load = True) as ar: + await init_app_session(ar, session_id=self.proposed_session_id) + ev_req = EventRequest(type="event", payload=WriterEvent( + type="wf-built-run", + isSafe=True, + handler="nineninenine", + instancePath=None, + payload=None + )) + ev_res = await ar.dispatch_message(self.proposed_session_id, ev_req) + assert ev_res.status == "ok" + assert ev_res.payload.result.get("result") == 999 + @pytest.mark.usefixtures("setup_app_runner") def test_run_code_edit(self, setup_app_runner) -> None: with setup_app_runner(test_app_dir, "run") as ar: diff --git a/tests/backend/testapp/main.py b/tests/backend/testapp/main.py index cbf88ce3d..9542bfcac 100644 --- a/tests/backend/testapp/main.py +++ b/tests/backend/testapp/main.py @@ -52,6 +52,11 @@ def update_cities(state, payload): "br": "Bristol" } + +def nineninenine(): + return 999 + + def create_text_widget(ui: WriterUIManager): with ui.find('bb4d0e86-619e-4367-a180-be28ab6059f4'): ui.Text({"text": "Hello world"})