From ea210ac7a3004bbd61ee38a05ee7021eaa083454 Mon Sep 17 00:00:00 2001 From: Serwan Asaad Date: Sun, 7 Jul 2024 09:46:54 +0300 Subject: [PATCH 1/2] state updates --- quam/core/quam_classes.py | 100 ++++++++++++++++++++++------- quam/serialisation/json.py | 4 ++ quam/utils/state_update_tracker.py | 46 +++++++++++++ 3 files changed, 127 insertions(+), 23 deletions(-) create mode 100644 quam/utils/state_update_tracker.py diff --git a/quam/core/quam_classes.py b/quam/core/quam_classes.py index a0d65969..ce153a46 100644 --- a/quam/core/quam_classes.py +++ b/quam/core/quam_classes.py @@ -32,6 +32,7 @@ generate_config_final_actions, ) from quam.core.quam_instantiation import instantiate_quam_class +from quam.utils.state_update_tracker import StateTracker from .qua_config_template import qua_config_template @@ -81,7 +82,9 @@ def convert_dict_and_list(value, cls_or_obj=None, attr=None): return value -def sort_quam_components(components: List["QuamComponent"], max_attempts=5) -> List["QuamComponent"]: +def sort_quam_components( + components: List["QuamComponent"], max_attempts=5 +) -> List["QuamComponent"]: """Sort QuamComponent objects based on their config_settings. Args: @@ -211,7 +214,8 @@ def __set__(self, instance, value): if "parent" in instance.__dict__ and instance.__dict__["parent"] is not value: cls = instance.__class__.__name__ raise AttributeError( - f"Cannot overwrite parent attribute of {cls}. " f"To modify {cls}.parent, first set {cls}.parent = None" + f"Cannot overwrite parent attribute of {cls}. " + f"To modify {cls}.parent, first set {cls}.parent = None" ) instance.__dict__["parent"] = value @@ -249,7 +253,10 @@ def __init__(self): "Please create a subclass and make it a dataclass." ) else: - raise TypeError(f"Cannot instantiate {self.__class__.__name__}. " "Please make it a dataclass.") + raise TypeError( + f"Cannot instantiate {self.__class__.__name__}. " + "Please make it a dataclass." + ) def _get_attr_names(self) -> List[str]: """Get names of all dataclass attributes of this object. @@ -280,7 +287,9 @@ def get_attr_name(self, attr_val: Any) -> str: return attr_name else: raise AttributeError( - "Could not find name corresponding to attribute.\n" f"attribute: {attr_val}\n" f"obj: {self}" + "Could not find name corresponding to attribute.\n" + f"attribute: {attr_val}\n" + f"obj: {self}" ) def _attr_val_is_default(self, attr: str, val: Any) -> bool: @@ -346,13 +355,17 @@ def get_reference(self, attr=None) -> Optional[str]: """ if self.parent is None: - raise AttributeError("Unable to extract reference path. Parent must be defined for {self}") + raise AttributeError( + "Unable to extract reference path. Parent must be defined for {self}" + ) reference = f"{self.parent.get_reference()}/{self.parent.get_attr_name(self)}" if attr is not None: reference = f"{reference}/{attr}" return reference - def get_attrs(self, follow_references: bool = False, include_defaults: bool = True) -> Dict[str, Any]: + def get_attrs( + self, follow_references: bool = False, include_defaults: bool = True + ) -> Dict[str, Any]: """Get all attributes and corresponding values of this object. Args: @@ -376,10 +389,16 @@ def get_attrs(self, follow_references: bool = False, include_defaults: bool = Tr attrs = {attr: getattr(self, attr) for attr in attr_names} if not include_defaults: - attrs = {attr: val for attr, val in attrs.items() if not self._attr_val_is_default(attr, val)} + attrs = { + attr: val + for attr, val in attrs.items() + if not self._attr_val_is_default(attr, val) + } return attrs - def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> Dict[str, Any]: + def to_dict( + self, follow_references: bool = False, include_defaults: bool = False + ) -> Dict[str, Any]: """Convert this object to a dictionary. Args: @@ -396,7 +415,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals `"__class__"` key will be added to the dictionary. This is to ensure that the object can be reconstructed when loading from a file. """ - attrs = self.get_attrs(follow_references=follow_references, include_defaults=include_defaults) + attrs = self.get_attrs( + follow_references=follow_references, include_defaults=include_defaults + ) quam_dict = {} for attr, val in attrs.items(): if isinstance(val, QuamBase): @@ -411,7 +432,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals quam_dict[attr] = val return quam_dict - def iterate_components(self, skip_elems: bool = None) -> Generator["QuamBase", None, None]: + def iterate_components( + self, skip_elems: bool = None + ) -> Generator["QuamBase", None, None]: """Iterate over all QuamBase objects in this object, including nested objects. Args: @@ -473,12 +496,15 @@ def _get_referenced_value(self, reference: str) -> Any: if string_reference.is_absolute_reference(reference) and self._root is None: warnings.warn( - f"No QuamRoot initialized, cannot retrieve reference {reference}" f" from {self.__class__.__name__}" + f"No QuamRoot initialized, cannot retrieve reference {reference}" + f" from {self.__class__.__name__}" ) return reference try: - return string_reference.get_referenced_value(self, reference, root=self._root) + return string_reference.get_referenced_value( + self, reference, root=self._root + ) except ValueError as e: try: ref = f"{self.__class__.__name__}: {self.get_reference()}" @@ -546,6 +572,7 @@ class QuamRoot(QuamBase): def __post_init__(self): QuamBase._root = self + self._state_tracker = StateTracker() super().__post_init__() def __setattr__(self, name, value): @@ -585,8 +612,11 @@ def save( include_defaults=include_defaults, ignore=ignore, ) + self._state_tracker.change_state(serialiser.contents) - def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> Dict[str, Any]: + def to_dict( + self, follow_references: bool = False, include_defaults: bool = False + ) -> Dict[str, Any]: """Convert this object to a dictionary. Args: @@ -626,13 +656,17 @@ def load( serialiser = cls.serialiser() contents, _ = serialiser.load(filepath_or_dict) - return instantiate_quam_class( + quam_obj = instantiate_quam_class( quam_class=cls, contents=contents, fix_attrs=fix_attrs, validate_type=validate_type, ) + quam_obj._state_tracker.clear() + quam_obj._state_tracker.state = contents + return quam_obj + def generate_config(self) -> Dict[str, Any]: """Generate the QUA configuration from the QuAM object. @@ -658,6 +692,9 @@ def generate_config(self) -> Dict[str, Any]: def get_unreferenced_value(self, attr: str): return getattr(self, attr) + def print_state_changes(self): + self._state_tracker.print_state_changes(self.to_dict(), mode="save") + class QuamComponent(QuamBase): """Base class for any QuAM component class. @@ -750,7 +787,9 @@ def __getitem__(self, i): repr = f"{self.__class__.__name__}: {self.get_reference()}" except Exception: repr = self.__class__.__name__ - raise KeyError(f"Could not get referenced value {elem} from {repr}") from e + raise KeyError( + f"Could not get referenced value {elem} from {repr}" + ) from e return elem # Overriding methods from UserDict @@ -774,7 +813,9 @@ def __repr__(self) -> str: def _get_attr_names(self): return list(self.data.keys()) - def get_attrs(self, follow_references=False, include_defaults=True) -> Dict[str, Any]: + def get_attrs( + self, follow_references=False, include_defaults=True + ) -> Dict[str, Any]: # TODO implement reference kwargs return self.data @@ -795,7 +836,9 @@ def get_attr_name(self, attr_val: Any) -> Union[str, int]: return attr_name else: raise AttributeError( - "Could not find name corresponding to attribute.\n" f"attribute: {attr_val}\n" f"obj: {self}" + "Could not find name corresponding to attribute.\n" + f"attribute: {attr_val}\n" + f"obj: {self}" ) def _val_matches_attr_annotation(self, attr: str, val: Any) -> bool: @@ -841,10 +884,13 @@ def get_unreferenced_value(self, attr: str) -> bool: return self.__dict__["data"][attr] except KeyError as e: raise AttributeError( - "Cannot get unreferenced value from attribute {attr} that does not" " exist in {self}" + "Cannot get unreferenced value from attribute {attr} that does not" + " exist in {self}" ) from e - def iterate_components(self, skip_elems: Sequence[QuamBase] = None) -> Generator["QuamBase", None, None]: + def iterate_components( + self, skip_elems: Sequence[QuamBase] = None + ) -> Generator["QuamBase", None, None]: """Iterate over all QuamBase objects in this object, including nested objects. Args: @@ -969,10 +1015,14 @@ def get_attr_name(self, attr_val: Any) -> str: return str(k) else: raise AttributeError( - "Could not find name corresponding to attribute" f"attribute: {attr_val}\n" f"obj: {self}" + "Could not find name corresponding to attribute" + f"attribute: {attr_val}\n" + f"obj: {self}" ) - def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> list: + def to_dict( + self, follow_references: bool = False, include_defaults: bool = False + ) -> list: """Convert this object to a list, usually as part of a dictionary representation. Args: @@ -1006,7 +1056,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals quam_list.append(val) return quam_list - def iterate_components(self, skip_elems: List[QuamBase] = None) -> Generator["QuamBase", None, None]: + def iterate_components( + self, skip_elems: List[QuamBase] = None + ) -> Generator["QuamBase", None, None]: """Iterate over all QuamBase objects in this object, including nested objects. Args: @@ -1027,7 +1079,9 @@ def iterate_components(self, skip_elems: List[QuamBase] = None) -> Generator["Qu if isinstance(attr_val, QuamBase): yield from attr_val.iterate_components(skip_elems=skip_elems) - def get_attrs(self, follow_references: bool = False, include_defaults: bool = True) -> Dict[str, Any]: + def get_attrs( + self, follow_references: bool = False, include_defaults: bool = True + ) -> Dict[str, Any]: raise NotImplementedError("QuamList does not have attributes") def print_summary(self, indent: int = 0): diff --git a/quam/serialisation/json.py b/quam/serialisation/json.py index f78df5e3..e2ad74b0 100644 --- a/quam/serialisation/json.py +++ b/quam/serialisation/json.py @@ -27,6 +27,7 @@ class JSONSerialiser(AbstractSerialiser): default_filename = "state.json" default_foldername = "quam" content_mapping = {} + contents = {} def _save_dict_to_json(self, contents: Dict[str, Any], path: Path): """Save a dictionary to a JSON file. @@ -104,6 +105,7 @@ def save( content_mapping = content_mapping.copy() contents = quam_obj.to_dict(include_defaults=include_defaults) + self.contents = contents.copy() # TODO This should ideally go to the QuamRoot.to_dict method for key in ignore or []: @@ -182,6 +184,8 @@ def load( else: metadata["content_mapping"][file.name] = list(file_contents.keys()) + self.contents = contents + return contents, metadata diff --git a/quam/utils/state_update_tracker.py b/quam/utils/state_update_tracker.py new file mode 100644 index 00000000..ec833e09 --- /dev/null +++ b/quam/utils/state_update_tracker.py @@ -0,0 +1,46 @@ +from typing import Any, Dict, List, Optional + + +def compare_dicts(old_dict, new_dict) -> List[dict]: # TODO + return [] + + +def value_to_str(value: Any) -> str: + if value in [str, bool, int, float, type(None)]: + return str(value) + else: + str_repr = str(value)[:20] + if len(str_repr) == 20: + str_repr += "..." + return str_repr + + +def print_state_changes(state_changes: list, indent=0): + for state_change in state_changes: + old_value_str = value_to_str(state_change["old"]) + new_value_str = value_to_str(state_change["new"]) + print(" " * indent, end="") + print(f"{state_change['path']}: {old_value_str} → {new_value_str}") + + +class StateTracker: + def __init__(self, state: Optional[Dict[str, Any]] = None): + self.last_state: Dict[str, Any] = state or {} + self.last_state_changes: list = [] + + def change_state(self, state): + self.last_state_changes = compare_dicts(self.last_state, state) + self.last_state = state + + def print_state_changes(self, new_state, indent=2, mode: str = "state update"): + print(f"State updates before last {mode}") + print_state_changes(self.last_state_changes, indent=indent) + print("") + + print(f"State updates after last {mode}") + state_changes = compare_dicts(self.last_state, new_state) + print_state_changes(state_changes, indent=indent) + + def clear(self): + self.last_state = {} + self.last_state_changes = [] From fb5e77901c1e4182e750d60caa101b6ae2c13f60 Mon Sep 17 00:00:00 2001 From: Serwan Asaad Date: Thu, 11 Jul 2024 10:29:14 +0200 Subject: [PATCH 2/2] changes --- quam/core/quam_classes.py | 10 +-- quam/utils/state_tracker.py | 101 +++++++++++++++++++++++++++++ quam/utils/state_update_tracker.py | 46 ------------- tests/utils/test_state_tracker.py | 13 ++++ 4 files changed, 120 insertions(+), 50 deletions(-) create mode 100644 quam/utils/state_tracker.py delete mode 100644 quam/utils/state_update_tracker.py create mode 100644 tests/utils/test_state_tracker.py diff --git a/quam/core/quam_classes.py b/quam/core/quam_classes.py index ce153a46..c52336f8 100644 --- a/quam/core/quam_classes.py +++ b/quam/core/quam_classes.py @@ -32,7 +32,7 @@ generate_config_final_actions, ) from quam.core.quam_instantiation import instantiate_quam_class -from quam.utils.state_update_tracker import StateTracker +from quam.utils.state_tracker import StateTracker from .qua_config_template import qua_config_template @@ -612,7 +612,7 @@ def save( include_defaults=include_defaults, ignore=ignore, ) - self._state_tracker.change_state(serialiser.contents) + self._state_tracker.update_state(serialiser.contents) def to_dict( self, follow_references: bool = False, include_defaults: bool = False @@ -663,8 +663,7 @@ def load( validate_type=validate_type, ) - quam_obj._state_tracker.clear() - quam_obj._state_tracker.state = contents + quam_obj._state_tracker.update_.state = contents return quam_obj def generate_config(self) -> Dict[str, Any]: @@ -730,6 +729,9 @@ def apply_to_config(self, config: dict) -> None: """ ... + # def print_state_changes(self): + # self._root.print_state_changes(self.to_dict(), mode="update") + @quam_dataclass class QuamDict(UserDict, QuamBase): diff --git a/quam/utils/state_tracker.py b/quam/utils/state_tracker.py new file mode 100644 index 00000000..9894b9b5 --- /dev/null +++ b/quam/utils/state_tracker.py @@ -0,0 +1,101 @@ +from typing import Any, Dict, List, Optional, Sequence, Mapping +import jsonpatch +from jsonpointer import resolve_pointer + + +class _Placeholder: + pass + + +def jsonpatch_to_mapping( + old: Mapping[str, Any], + patch: Sequence[Mapping[str, Any]], + use_preceding_hash: bool = True, +) -> Mapping[str, Mapping[str, Any]]: + diff: dict[str, dict[str, Any]] = {} + for item in patch: + op = item["op"] + if op == "replace": + old_value = resolve_pointer(old, item["path"]) + diff[item["path"]] = { + "old": old_value, + "new": item["value"], + } + elif op == "remove": + old_value = resolve_pointer(old, item["path"]) + diff[item["path"]] = { + "old": old_value, + } + elif op == "add": + diff[item["path"]] = { + "new": item["value"], + } + elif op == "copy": + new_dst_value = resolve_pointer(old, item["from"]) + old_dst_value = resolve_pointer(old, item["path"], _Placeholder()) + diff[item["path"]] = { + "new": new_dst_value, + } + if not isinstance(old_dst_value, _Placeholder): + diff[item["path"]]["old"] = old_dst_value + elif op == "move": + old_src_value = resolve_pointer(old, item["from"]) + old_dst_value = resolve_pointer(old, item["path"], _Placeholder()) + if item["from"] not in diff: + diff[item["from"]] = { + "old": old_src_value, + } + diff[item["path"]] = { + "new": old_src_value, + } + if not isinstance(old_dst_value, _Placeholder): + diff[item["path"]]["old"] = old_dst_value + + if use_preceding_hash: + diff = {f"#{key}": val for key, val in diff.items()} + return diff + + +def compare_dicts(old_dict, new_dict) -> Dict[str, Dict[str, Any]]: # TODO + json_patches = jsonpatch.make_patch(old_dict, new_dict) + json_mapping = jsonpatch_to_mapping(old_dict, json_patches) + return json_mapping + + +def value_to_str(value: Any) -> str: + if value in [str, bool, int, float, type(None)]: + return str(value) + else: + str_repr = str(value)[:40] + if len(str_repr) == 40: + str_repr += "..." + return str_repr + + +def print_state_changes(state_changes: Dict[str, Dict[str, Any]], indent=0): + for path, state_change in state_changes.items(): + old_value_str = value_to_str(state_change.get("old", None)) + new_value_str = value_to_str(state_change["new"]) + print(" " * indent, end="") + print(f"{path}: {old_value_str} → {new_value_str}") + + +class StateTracker: + def __init__(self, state: Optional[Dict[str, Any]] = None): + self.last_state: Dict[str, Any] = state or {} + self.last_state_changes: Dict[str, Dict[str, Any]] = {} + + def update_state(self, state): + self.last_state_changes = compare_dicts(self.last_state, state) + self.last_state = state + + def print_state_changes(self, new_state, indent=3, mode: str = "update"): + if self.last_state_changes: + print(f"State changes before last {mode}") + print_state_changes(self.last_state_changes, indent=indent) + print("") + + state_changes = compare_dicts(self.last_state, new_state) + if state_changes: + print(f"State changes after last {mode}") + print_state_changes(state_changes, indent=indent) diff --git a/quam/utils/state_update_tracker.py b/quam/utils/state_update_tracker.py deleted file mode 100644 index ec833e09..00000000 --- a/quam/utils/state_update_tracker.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Any, Dict, List, Optional - - -def compare_dicts(old_dict, new_dict) -> List[dict]: # TODO - return [] - - -def value_to_str(value: Any) -> str: - if value in [str, bool, int, float, type(None)]: - return str(value) - else: - str_repr = str(value)[:20] - if len(str_repr) == 20: - str_repr += "..." - return str_repr - - -def print_state_changes(state_changes: list, indent=0): - for state_change in state_changes: - old_value_str = value_to_str(state_change["old"]) - new_value_str = value_to_str(state_change["new"]) - print(" " * indent, end="") - print(f"{state_change['path']}: {old_value_str} → {new_value_str}") - - -class StateTracker: - def __init__(self, state: Optional[Dict[str, Any]] = None): - self.last_state: Dict[str, Any] = state or {} - self.last_state_changes: list = [] - - def change_state(self, state): - self.last_state_changes = compare_dicts(self.last_state, state) - self.last_state = state - - def print_state_changes(self, new_state, indent=2, mode: str = "state update"): - print(f"State updates before last {mode}") - print_state_changes(self.last_state_changes, indent=indent) - print("") - - print(f"State updates after last {mode}") - state_changes = compare_dicts(self.last_state, new_state) - print_state_changes(state_changes, indent=indent) - - def clear(self): - self.last_state = {} - self.last_state_changes = [] diff --git a/tests/utils/test_state_tracker.py b/tests/utils/test_state_tracker.py new file mode 100644 index 00000000..0870cd5b --- /dev/null +++ b/tests/utils/test_state_tracker.py @@ -0,0 +1,13 @@ +from quam.utils.state_tracker import StateTracker + + +def test_empty_state_tracker(capsys): + state_tracker = StateTracker() + assert state_tracker.last_state == {} + assert state_tracker.last_state_changes == {} + + new_dict = {"hi": "bye"} + state_tracker.update_state(new_dict) + + assert state_tracker.last_state == new_dict + assert state_tracker.last_state_changes == {"#/hi": {"new": "bye"}}