From 620f41080d37b809762410af2dfd1960c5749da2 Mon Sep 17 00:00:00 2001 From: Fabien Arcellier Date: Thu, 7 Mar 2024 15:48:11 +0100 Subject: [PATCH] feat: declare optional schema on streamsync state * fix: issue on substate assignation --- src/streamsync/__init__.py | 5 +- src/streamsync/core.py | 95 +++++++++++++++++++++++++++++++------- tests/test_core.py | 29 ++++++++++++ 3 files changed, 110 insertions(+), 19 deletions(-) diff --git a/src/streamsync/__init__.py b/src/streamsync/__init__.py index fc867c6e7..343132dd3 100644 --- a/src/streamsync/__init__.py +++ b/src/streamsync/__init__.py @@ -31,7 +31,7 @@ def pack_bytes(raw_data, mime_type: Optional[str] = None): S = TypeVar('S', bound=StreamsyncState) -def init_state(state_dict: Dict[str, Any], schema: Optional[Type[S]] = None) -> Union[S, StreamsyncState]: +def init_state(raw_state: Dict[str, Any], schema: Optional[Type[S]] = None) -> Union[S, StreamsyncState]: """ Sets the initial state, which will be used as the starting point for every session. @@ -46,6 +46,5 @@ def init_state(state_dict: Dict[str, Any], schema: Optional[Type[S]] = None) -> if not issubclass(concrete_schema, StreamsyncState): raise ValueError("Root schema must inherit from StreamsyncState") - _initial_state: S = new_initial_state(concrete_schema) - _initial_state.ingest(state_dict) + _initial_state: S = new_initial_state(concrete_schema, raw_state) return _initial_state diff --git a/src/streamsync/core.py b/src/streamsync/core.py index 04a9dca93..0b020babe 100644 --- a/src/streamsync/core.py +++ b/src/streamsync/core.py @@ -8,7 +8,8 @@ import sys import time import traceback -from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union, TypeVar, Type, Sequence, cast +from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union, TypeVar, Type, Sequence, cast, \ + Generator import urllib.request import base64 import io @@ -321,6 +322,24 @@ def to_dict(self) -> Dict[str, Any]: serialised[key] = serialised_value return serialised + def to_raw_state(self): + """ + Converts a StateProxy and its children into a python dictionary. + + >>> state = State({'a': 1, 'c': {'a': 1, 'b': 3}}) + >>> _raw_state = state._state_proxy.to_raw_state() + >>> {'a': 1, 'c': {'a': 1, 'b': 3}} + + :return: a python dictionary that represents the raw state + """ + raw_state = {} + for key, value in self.state.items(): + if isinstance(value, StateProxy): + value = value.to_raw_state() + raw_state[key] = value + + return raw_state + def get_annotations(instance) -> Dict[str, Any]: """ @@ -399,9 +418,16 @@ def __init__(self, raw_state: Dict[str, Any] = {}): def ingest(self, raw_state: Dict[str, Any]) -> None: """ hydrates a state from raw data by applying a schema when it is provided. + The existing content in the state is erased. + + + >>> state = StreamsyncState({'message': "hello world"}) + >>> state.ingest({'a': 1, 'b': 2}) + >>> {'a': 1, 'b': 2} """ self._state_proxy.state = {} for key, value in raw_state.items(): + assert not isinstance(value, StateProxy), f"state proxy datatype is not expected in ingest operation, {locals()}" self._set_state_item(key, value) def to_dict(self) -> dict: @@ -415,18 +441,37 @@ def to_dict(self) -> dict: """ return self._state_proxy.to_dict() + + def to_raw_state(self) -> dict: + """ + Converts a StateProxy and its children into a python dictionary that can be used to recreate the + state from scratch. + + >>> state = StreamsyncState({'a': 1, 'c': {'a': 1, 'b': 3}}) + >>> raw_state = state.to_raw_state() + >>> "{'a': 1, 'c': {'a': 1, 'b': 3}}" + + :return: a python dictionary that represents the raw state + """ + return self._state_proxy.to_raw_state() + def __repr__(self) -> str: return self._state_proxy.__repr__() def __getitem__(self, key: str) -> Any: - annotations = get_annotations(self) - expected_type = annotations.get(key) - if expected_type is not None and inspect.isclass(expected_type) and issubclass(expected_type, State): - return getattr(self, key) - else: - return self._state_proxy.__getitem__(key) + + # Essential to support operation like + # state['item']['a'] = state['item']['b'] + if hasattr(self, key): + value = getattr(self, key) + if isinstance(value, State): + return value + + return self._state_proxy.__getitem__(key) def __setitem__(self, key: str, raw_value: Any) -> None: + assert not isinstance(raw_value, StateProxy), f"state proxy datatype is not expected, {locals()}" + self._set_state_item(key, raw_value) def __delitem__(self, key: str) -> Any: @@ -435,12 +480,26 @@ def __delitem__(self, key: str) -> Any: def remove(self, key: str) -> Any: return self.__delitem__(key) + def items(self) -> Generator[Tuple[str, Any], None, None]: + for k, v in self._state_proxy.items(): + if isinstance(v, StateProxy): + # We don't want to expose StateProxy to the user, so + # we replace it with relative State + yield k, getattr(self, k) + else: + yield k, v + def __contains__(self, key: str) -> bool: return self._state_proxy.__contains__(key) def _set_state_item(self, key: str, value: Any): """ """ + + """ + At this level, the values that arrive are either States which encapsulate a StateProxy, or another datatype. + If there is a StateProxy, it is a fault in the code. + """ annotations = get_annotations(self) expected_type = annotations.get(key, None) expect_dict = expected_type is not None and inspect.isclass(expected_type) and issubclass(expected_type, dict) @@ -457,10 +516,11 @@ def _set_state_item(self, key: str, value: Any): state.ingest(value) self._state_proxy[key] = state._state_proxy else: - if isinstance(value, StateProxy): - value.apply_mutation_marker(recursive=True) - - self._state_proxy[key] = value + if isinstance(value, State): + value._state_proxy.apply_mutation_marker(recursive=True) + self._state_proxy[key] = value._state_proxy + else: + self._state_proxy[key] = value class StreamsyncState(State): @@ -495,11 +555,11 @@ def get_clone(self) -> 'StreamsyncState': >>> class AppSchema(StreamsyncState): >>> counter: int >>> - >>> root_state = AppSchema() + >>> root_state = AppSchema({'counter': 1}) >>> clone_state = root_state.get_clone() # instance of AppSchema """ try: - cloned_user_state = copy.deepcopy(self.user_state.state) + cloned_user_state = copy.deepcopy(self.user_state.to_raw_state()) cloned_mail = copy.deepcopy(self.mail) except BaseException: substitute_state = StreamsyncState() @@ -1251,7 +1311,7 @@ def __set__(self, instance, value): S = TypeVar("S", bound=StreamsyncState) -def new_initial_state(klass: Type[S]) -> S: +def new_initial_state(klass: Type[S], raw_state: dict) -> S: """ Initializes the initial state of the application and makes it globally accessible. @@ -1260,10 +1320,13 @@ def new_initial_state(klass: Type[S]) -> S: >>> class MyState(StreamsyncState): >>> pass >>> - >>> initial_state = new_initial_state(MyState) + >>> initial_state = new_initial_state(MyState, {}) """ global initial_state - initial_state = klass() + if raw_state is None: + raw_state = {} + + initial_state = klass(raw_state) return initial_state diff --git a/tests/test_core.py b/tests/test_core.py index ed6f65f92..5ae3cf81c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -177,6 +177,14 @@ def test_private_members(self) -> None: assert d.get("_private") is None assert d.get("_private_unserialisable") is None + def test_to_raw_state(self) -> None: + """ + Test that `to_raw_state` returns the state in its original format + """ + assert self.sp.to_raw_state() == raw_state_dict + assert self.sp_simple_dict.to_raw_state() == simple_dict + + class TestState: @@ -288,6 +296,27 @@ class ComplexSchema(State): '+app.title': 'world', } + def test_remove_then_replace_nested_dictionary_should_trigger_mutation(self): + """ + Tests that deleting a key from a substate, then replacing it, triggers the expected mutations + """ + # Assign + _state = State({"nested": {"a": 1, "b": 2, "c": {"d": 3, "e": 4}}}) + m = _state._state_proxy.get_mutations_as_dict() + + # Acts + del _state["nested"]["c"]["e"] + _state['nested']['c'] = _state['nested']['c'] + + # Assert + m = _state._state_proxy.get_mutations_as_dict() + assert m == { + '+nested.c': None, + '+nested.c.d': 3, + '-nested.c.e': None + } + assert _state.to_dict() == {"nested": {"a": 1, "b": 2, "c": {"d": 3}}} + class TestStreamsyncState: