diff --git a/src/streamsync/__init__.py b/src/streamsync/__init__.py index de17e3259..1177be510 100644 --- a/src/streamsync/__init__.py +++ b/src/streamsync/__init__.py @@ -1,10 +1,11 @@ import importlib.metadata from typing import Union, Optional, Dict, Any from streamsync.core import Readable, FileWrapper, BytesWrapper, Config -from streamsync.core import initial_state, session_manager, session_verifier +from streamsync.core import initial_state, base_component_tree, session_manager, session_verifier VERSION = importlib.metadata.version("streamsync") +base_component_tree session_manager Config session_verifier diff --git a/src/streamsync/app_runner.py b/src/streamsync/app_runner.py index 1e95e9b5d..52f488b97 100644 --- a/src/streamsync/app_runner.py +++ b/src/streamsync/app_runner.py @@ -18,7 +18,7 @@ from watchdog.observers.polling import PollingObserver from pydantic import ValidationError -from streamsync.core import ComponentManager, StreamsyncSession +from streamsync.core import ComponentTree, StreamsyncSession from streamsync.ss_types import (AppProcessServerRequest, AppProcessServerRequestPacket, AppProcessServerResponse, AppProcessServerResponsePacket, ComponentUpdateRequest, ComponentUpdateRequestPayload, EventRequest, EventResponsePayload, InitSessionRequest, InitSessionRequestPayload, InitSessionResponse, InitSessionResponsePayload, StateEnquiryRequest, StateEnquiryResponsePayload, StreamsyncEvent) import watchdog.observers @@ -69,7 +69,7 @@ def __init__(self, app_path: str, mode: str, run_code: str, - components: Dict, + bmc_components: Dict, is_app_process_server_ready: multiprocessing.synchronize.Event, is_app_process_server_failed: multiprocessing.synchronize.Event): super().__init__(name="AppProcess") @@ -78,7 +78,7 @@ def __init__(self, self.app_path = app_path self.mode = mode self.run_code = run_code - self.components = components + self.bmc_components = bmc_components self.is_app_process_server_ready = is_app_process_server_ready self.is_app_process_server_failed = is_app_process_server_failed self.logger = logging.getLogger("app") @@ -134,7 +134,7 @@ def _handle_session_init(self, payload: InitSessionRequestPayload) -> InitSessio import traceback as tb session = streamsync.session_manager.get_new_session( - payload.cookies, payload.headers, payload.proposedSessionId, self.components) + payload.cookies, payload.headers, payload.proposedSessionId) if session is None: raise MessageHandlingException("Session rejected.") @@ -149,7 +149,7 @@ def _handle_session_init(self, payload: InitSessionRequestPayload) -> InitSessio userState=user_state, sessionId=session.session_id, mail=session.session_state.mail, - components=session.component_manager.to_dict(), + components=session.session_component_tree.to_dict(), userFunctions=self._get_user_functions() ) @@ -207,6 +207,10 @@ def _handle_state_enquiry(self, session: StreamsyncSession) -> StateEnquiryRespo session.session_state.clear_mail() return res_payload + + def _handle_component_update(self, payload: ComponentUpdateRequestPayload) -> None: + import streamsync + streamsync.base_component_tree.ingest(payload.components) def _handle_message(self, session_id: str, request: AppProcessServerRequest) -> AppProcessServerResponse: """ @@ -332,14 +336,14 @@ def _main(self) -> None: if self.mode == "run": terminate_early = True - # try: - # streamsync.component_manager.ingest(self.components) - # except BaseException: - # streamsync.initial_state.add_log_entry( - # "error", "UI Components Error", "Couldn't load components. An exception was raised.", tb.format_exc()) - # if self.mode == "run": - # terminate_early = True - + try: + streamsync.base_component_tree.ingest(self.bmc_components) + except BaseException: + streamsync.initial_state.add_log_entry( + "error", "UI Components Error", "Couldn't load components. An exception was raised.", tb.format_exc()) + if self.mode == "run": + terminate_early = True + if terminate_early: self._terminate_early() return @@ -531,7 +535,7 @@ def __init__(self, app_path: str, mode: str): self.client_conn: Optional[multiprocessing.connection.Connection] = None self.app_process: Optional[AppProcess] = None self.run_code: Optional[str] = None - self.components: Optional[Dict] = None + self.bmc_components: Optional[Dict] = None self.is_app_process_server_ready = multiprocessing.Event() self.is_app_process_server_failed = multiprocessing.Event() self.app_process_listener: Optional[AppProcessListener] = None @@ -585,7 +589,7 @@ def signal_handler(sig, frame): pass self.run_code = self._load_persisted_script() - self.components = self._load_persisted_components() + self.bmc_components = self._load_persisted_components() if self.mode == "edit": self._set_observer() @@ -672,7 +676,7 @@ async def update_components(self, session_id: str, payload: ComponentUpdateReque if self.mode != "edit": raise PermissionError( "Cannot update components in non-update mode.") - self.components = payload.components + self.bmc_components = payload.components file_contents = { "metadata": { "streamsync_version": VERSION @@ -740,7 +744,7 @@ def shut_down(self) -> None: def _start_app_process(self) -> None: if self.run_code is None: raise ValueError("Cannot start app process. Code hasn't been set.") - if self.components is None: + if self.bmc_components is None: raise ValueError( "Cannot start app process. Components haven't been set.") self.is_app_process_server_ready.clear() @@ -754,7 +758,7 @@ def _start_app_process(self) -> None: app_path=self.app_path, mode=self.mode, run_code=self.run_code, - components=self.components, + bmc_components=self.bmc_components, is_app_process_server_ready=self.is_app_process_server_ready, is_app_process_server_failed=self.is_app_process_server_failed) self.app_process.start() diff --git a/src/streamsync/core.py b/src/streamsync/core.py index 31d385b10..042b8957b 100644 --- a/src/streamsync/core.py +++ b/src/streamsync/core.py @@ -505,7 +505,7 @@ def to_dict(self) -> Dict: return c_dict -class ComponentManager: +class ComponentTree: def __init__(self) -> None: self.counter: int = 0 @@ -513,6 +513,9 @@ def __init__(self) -> None: root_component = Component("root", "root", {}) self.attach(root_component) + def get_component(self, component_id: str) -> Component: + return self.components.get(component_id) + def get_descendents(self, parent_id: str) -> List[Component]: children = list(filter(lambda c: c.parentId == parent_id, self.components.values())) @@ -548,6 +551,25 @@ def to_dict(self) -> Dict: for id, component in self.components.items(): active_components[id] = component.to_dict() return active_components + + +class SessionComponentTree(ComponentTree): + + def __init__(self, base_component_tree: ComponentTree): + super().__init__() + self.base_component_tree = base_component_tree + + def get_component(self, component_id: str) -> Component: + base_component = self.base_component_tree.get_component(component_id) + if base_component: + return base_component + return self.components.get(component_id) + + def to_dict(self) -> Dict: + active_components = {} + for id, component in {**self.components, **self.base_component_tree.components}.items(): + active_components[id] = component.to_dict() + return active_components class EventDeserialiser: @@ -559,8 +581,8 @@ class EventDeserialiser: Its main goal is to deserialise incoming content in a controlled and predictable way, applying sanitisation of inputs where relevant.""" - def __init__(self, session_id: str, session_state: StreamsyncState): - self.evaluator = Evaluator(session_id, session_state) + def __init__(self, session_state: StreamsyncState, session_component_tree: SessionComponentTree): + self.evaluator = Evaluator(session_state, session_component_tree) def transform(self, ev: StreamsyncEvent) -> None: # Events without payloads are safe @@ -724,9 +746,9 @@ class Evaluator: template_regex = re.compile(r"[\\]?@{([\w\s.]*)}") - def __init__(self, session_id: str, session_state: StreamsyncState): - self.session_id = session_id + def __init__(self, session_state: StreamsyncState, session_component_tree: ComponentTree): self.ss = session_state + self.ct = session_component_tree def evaluate_field(self, instance_path: InstancePath, field_key: str, as_json=False, default_field_value="") -> Any: def replacer(matched): @@ -746,53 +768,48 @@ def replacer(matched): return json.dumps(serialised_value) return str(serialised_value) - session = session_manager.get_session(self.session_id) - if session: - component_id = instance_path[-1]["componentId"] - component = session.component_manager.components[component_id] - field_value = component.content.get(field_key) or default_field_value - replaced = self.template_regex.sub(replacer, field_value) + component_id = instance_path[-1]["componentId"] + component = self.ct.get_component(component_id) + field_value = component.content.get(field_key) or default_field_value + replaced = self.template_regex.sub(replacer, field_value) - if as_json: - return json.loads(replaced) - else: - return replaced - return None + if as_json: + return json.loads(replaced) + else: + return replaced def get_context_data(self, instance_path: InstancePath) -> Dict[str, Any]: context: Dict[str, Any] = {} - session = session_manager.get_session(self.session_id) - if session: - for i in range(len(instance_path)): - path_item = instance_path[i] - component_id = path_item["componentId"] - component = session.component_manager.components[component_id] - if component.type != "repeater": - continue - if i + 1 >= len(instance_path): - continue - repeater_instance_path = instance_path[0:i+1] - next_instance_path = instance_path[0:i+2] - instance_number = next_instance_path[-1]["instanceNumber"] - repeater_object = self.evaluate_field( - repeater_instance_path, "repeaterObject", True, """{ "a": { "desc": "Option A" }, "b": { "desc": "Option B" } }""") - key_variable = self.evaluate_field( - repeater_instance_path, "keyVariable", False, "itemId") - value_variable = self.evaluate_field( - repeater_instance_path, "valueVariable", False, "item") - - repeater_items: List[Tuple[Any, Any]] = [] - if isinstance(repeater_object, dict): - repeater_items = list(repeater_object.items()) - elif isinstance(repeater_object, list): - repeater_items = [(k, v) - for (k, v) in enumerate(repeater_object)] - else: - raise ValueError( - "Cannot produce context. Repeater object must evaluate to a dictionary.") + for i in range(len(instance_path)): + path_item = instance_path[i] + component_id = path_item["componentId"] + component = self.ct.get_component(component_id) + if component.type != "repeater": + continue + if i + 1 >= len(instance_path): + continue + repeater_instance_path = instance_path[0:i+1] + next_instance_path = instance_path[0:i+2] + instance_number = next_instance_path[-1]["instanceNumber"] + repeater_object = self.evaluate_field( + repeater_instance_path, "repeaterObject", True, """{ "a": { "desc": "Option A" }, "b": { "desc": "Option B" } }""") + key_variable = self.evaluate_field( + repeater_instance_path, "keyVariable", False, "itemId") + value_variable = self.evaluate_field( + repeater_instance_path, "valueVariable", False, "item") + + repeater_items: List[Tuple[Any, Any]] = [] + if isinstance(repeater_object, dict): + repeater_items = list(repeater_object.items()) + elif isinstance(repeater_object, list): + repeater_items = [(k, v) + for (k, v) in enumerate(repeater_object)] + else: + raise ValueError( + "Cannot produce context. Repeater object must evaluate to a dictionary.") - context[key_variable] = repeater_items[instance_number][0] - context[value_variable] = repeater_items[instance_number][1] + context[key_variable] = repeater_items[instance_number][0] + context[value_variable] = repeater_items[instance_number][1] return context @@ -876,7 +893,7 @@ class StreamsyncSession: Represents a session. """ - def __init__(self, session_id: str, cookies: Optional[Dict[str, str]], headers: Optional[Dict[str, str]], components: Optional[Dict[str, Any]]) -> None: + def __init__(self, session_id: str, cookies: Optional[Dict[str, str]], headers: Optional[Dict[str, str]]) -> None: self.session_id = session_id self.cookies = cookies self.headers = headers @@ -884,14 +901,8 @@ def __init__(self, session_id: str, cookies: Optional[Dict[str, str]], headers: new_state = StreamsyncState.get_new() new_state.user_state.mutated = set() self.session_state = new_state + self.session_component_tree = SessionComponentTree(base_component_tree) self.event_handler = EventHandler(self) - self.component_manager = ComponentManager() - if components: - try: - self.component_manager.ingest(components) - except BaseException: - self.session_state.add_log_entry( - "error", "UI Components Error", "Couldn't load components. An exception was raised.", tb.format_exc()) def update_last_active_timestamp(self) -> None: self.last_active_timestamp = int(time.time()) @@ -941,7 +952,7 @@ def _check_proposed_session_id(self, proposed_session_id: Optional[str]) -> bool return True return False - def get_new_session(self, cookies: Optional[Dict] = None, headers: Optional[Dict] = None, proposed_session_id: Optional[str] = None, components: Optional[Dict[str, Any]] = None) -> Optional[StreamsyncSession]: + def get_new_session(self, cookies: Optional[Dict] = None, headers: Optional[Dict] = None, proposed_session_id: Optional[str] = None) -> Optional[StreamsyncSession]: if not self._check_proposed_session_id(proposed_session_id): return None if not self._verify_before_new_session(cookies, headers): @@ -952,7 +963,7 @@ def get_new_session(self, cookies: Optional[Dict] = None, headers: Optional[Dict else: new_id = proposed_session_id new_session = StreamsyncSession( - new_id, cookies, headers, components) + new_id, cookies, headers) self.sessions[new_id] = new_session return new_session @@ -990,8 +1001,9 @@ class EventHandler: def __init__(self, session: StreamsyncSession) -> None: self.session = session self.session_state = session.session_state - self.deser = EventDeserialiser(self.session.session_id, self.session_state) - self.evaluator = Evaluator(self.session.session_id, self.session_state) + self.session_component_tree = session.session_component_tree + self.deser = EventDeserialiser(self.session_state, self.session_component_tree) + self.evaluator = Evaluator(self.session_state, self.session_component_tree) def _handle_binding(self, event_type, target_component, instance_path, payload) -> None: @@ -1088,7 +1100,7 @@ def handle(self, ev: StreamsyncEvent) -> StreamsyncEventResult: try: instance_path = ev.instancePath target_id = instance_path[-1]["componentId"] - target_component = self.session.component_manager.components[target_id] + target_component = self.session_component_tree.get_component(target_id) self._handle_binding(ev.type, target_component, instance_path, ev.payload) result = self._call_handler_callable( @@ -1105,6 +1117,7 @@ def handle(self, ev: StreamsyncEvent) -> StreamsyncEventResult: state_serialiser = StateSerialiser() initial_state = StreamsyncState() +base_component_tree = ComponentTree() session_manager = SessionManager() diff --git a/tests/test_core.py b/tests/test_core.py index 9b561a05a..f1d1783c3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -3,7 +3,7 @@ from typing import Dict import numpy as np -from streamsync.core import (BytesWrapper, ComponentManager, Evaluator, EventDeserialiser, +from streamsync.core import (BytesWrapper, ComponentTree, Evaluator, EventDeserialiser, FileWrapper, SessionManager, StateProxy, StateSerialiser, StateSerialiserException, StreamsyncState) import streamsync as ss from streamsync.ss_types import StreamsyncEvent @@ -49,7 +49,7 @@ ss.Config.is_mail_enabled_for_log = True ss.init_state(raw_state_dict) session = ss.session_manager.get_new_session() -session.component_manager.ingest(sc) +session.session_component_tree.ingest(sc) class TestStateProxy: @@ -216,18 +216,18 @@ def test_unpickable_members(self) -> None: json.dumps(cloned.mail) -class TestComponentManager: +class TestComponentTree: - cm = ComponentManager() + ct = ComponentTree() def test_ingest(self) -> None: - self.cm.ingest(sc) - d = self.cm.to_dict() + self.ct.ingest(sc) + d = self.ct.to_dict() assert d.get( "84378aea-b64c-49a3-9539-f854532279ee").get("type") == "header" def test_descendents(self) -> None: - desc = self.cm.get_descendents("root") + desc = self.ct.get_descendents("root") desc_ids = list(map(lambda x: x.id, desc)) assert "84378aea-b64c-49a3-9539-f854532279ee" in desc_ids assert "bb4d0e86-619e-4367-a180-be28ab6059f4" in desc_ids @@ -238,7 +238,8 @@ class TestEventDeserialiser: root_instance_path = [{"componentId": "root", "instanceNumber": 0}] session_state = StreamsyncState(raw_state_dict) - ed = EventDeserialiser(session.session_id, session_state) + component_tree = session.session_component_tree + ed = EventDeserialiser(session_state, component_tree) def test_unknown_no_payload(self) -> None: ev = StreamsyncEvent( @@ -599,7 +600,8 @@ def test_evaluate_field_simple(self) -> None: st = StreamsyncState({ "counter": 8 }) - e = Evaluator(session.session_id, st) + ct = session.session_component_tree + e = Evaluator(st, ct) evaluated = e.evaluate_field(instance_path, "text") assert evaluated == "The counter is 8" @@ -623,7 +625,8 @@ def test_evaluate_field_repeater(self) -> None: "ts": "TypeScript" } }) - e = Evaluator(session.session_id, st) + ct = session.session_component_tree + e = Evaluator(st, ct) assert e.evaluate_field( instance_path_0, "text") == "The id is c and the name is C" assert e.evaluate_field( @@ -634,7 +637,8 @@ def test_set_state(self) -> None: {"componentId": "root", "instanceNumber": 0} ] st = StreamsyncState(raw_state_dict) - e = Evaluator(session.session_id, st) + ct = session.session_component_tree + e = Evaluator(st, ct) e.set_state("name", instance_path, "Roger") e.set_state("dynamic_prop", instance_path, "height") e.set_state("features[dynamic_prop]", instance_path, "toddler height") @@ -648,7 +652,8 @@ def test_evaluate_expression(self) -> None: {"componentId": "root", "instanceNumber": 0} ] st = StreamsyncState(raw_state_dict) - e = Evaluator(session.session_id, st) + ct = session.session_component_tree + e = Evaluator(st, ct) assert e.evaluate_expression("features.eyes", instance_path) == "green" assert e.evaluate_expression("best_feature", instance_path) == "eyes" assert e.evaluate_expression("features[best_feature]", instance_path) == "green"