diff --git a/quam/core/quam_classes.py b/quam/core/quam_classes.py index a0d65969..85ae02a5 100644 --- a/quam/core/quam_classes.py +++ b/quam/core/quam_classes.py @@ -81,7 +81,9 @@ def convert_dict_and_list(value, cls_or_obj=None, attr=None): return value -def sort_quam_components(components: List["QuamComponent"], max_attempts=5) -> List["QuamComponent"]: +def sort_quam_components( + components: List["QuamComponent"], max_attempts=5 +) -> List["QuamComponent"]: """Sort QuamComponent objects based on their config_settings. Args: @@ -211,7 +213,8 @@ def __set__(self, instance, value): if "parent" in instance.__dict__ and instance.__dict__["parent"] is not value: cls = instance.__class__.__name__ raise AttributeError( - f"Cannot overwrite parent attribute of {cls}. " f"To modify {cls}.parent, first set {cls}.parent = None" + f"Cannot overwrite parent attribute of {cls}. " + f"To modify {cls}.parent, first set {cls}.parent = None" ) instance.__dict__["parent"] = value @@ -249,7 +252,10 @@ def __init__(self): "Please create a subclass and make it a dataclass." ) else: - raise TypeError(f"Cannot instantiate {self.__class__.__name__}. " "Please make it a dataclass.") + raise TypeError( + f"Cannot instantiate {self.__class__.__name__}. " + "Please make it a dataclass." + ) def _get_attr_names(self) -> List[str]: """Get names of all dataclass attributes of this object. @@ -280,7 +286,9 @@ def get_attr_name(self, attr_val: Any) -> str: return attr_name else: raise AttributeError( - "Could not find name corresponding to attribute.\n" f"attribute: {attr_val}\n" f"obj: {self}" + "Could not find name corresponding to attribute.\n" + f"attribute: {attr_val}\n" + f"obj: {self}" ) def _attr_val_is_default(self, attr: str, val: Any) -> bool: @@ -334,6 +342,17 @@ def _val_matches_attr_annotation(cls, attr: str, val: Any) -> bool: return isinstance(val, (list, QuamList)) return type(val) == required_type + def get_metadata(self, attr=None): + if isinstance(self, QuamRoot): + return self._metadata.setdefault("#/", {}) + + reference = self.get_reference(attr=attr) + if reference is None: + raise AttributeError( + "Unable to extract reference path. Parent must be defined for {self}" + ) + return self._root._metadata.setdefault(reference, {}) + def get_reference(self, attr=None) -> Optional[str]: """Get the reference path of this object. @@ -346,13 +365,17 @@ def get_reference(self, attr=None) -> Optional[str]: """ if self.parent is None: - raise AttributeError("Unable to extract reference path. Parent must be defined for {self}") + raise AttributeError( + "Unable to extract reference path. Parent must be defined for {self}" + ) reference = f"{self.parent.get_reference()}/{self.parent.get_attr_name(self)}" if attr is not None: reference = f"{reference}/{attr}" return reference - def get_attrs(self, follow_references: bool = False, include_defaults: bool = True) -> Dict[str, Any]: + def get_attrs( + self, follow_references: bool = False, include_defaults: bool = True + ) -> Dict[str, Any]: """Get all attributes and corresponding values of this object. Args: @@ -376,10 +399,16 @@ def get_attrs(self, follow_references: bool = False, include_defaults: bool = Tr attrs = {attr: getattr(self, attr) for attr in attr_names} if not include_defaults: - attrs = {attr: val for attr, val in attrs.items() if not self._attr_val_is_default(attr, val)} + attrs = { + attr: val + for attr, val in attrs.items() + if not self._attr_val_is_default(attr, val) + } return attrs - def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> Dict[str, Any]: + def to_dict( + self, follow_references: bool = False, include_defaults: bool = False + ) -> Dict[str, Any]: """Convert this object to a dictionary. Args: @@ -396,7 +425,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals `"__class__"` key will be added to the dictionary. This is to ensure that the object can be reconstructed when loading from a file. """ - attrs = self.get_attrs(follow_references=follow_references, include_defaults=include_defaults) + attrs = self.get_attrs( + follow_references=follow_references, include_defaults=include_defaults + ) quam_dict = {} for attr, val in attrs.items(): if isinstance(val, QuamBase): @@ -411,7 +442,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals quam_dict[attr] = val return quam_dict - def iterate_components(self, skip_elems: bool = None) -> Generator["QuamBase", None, None]: + def iterate_components( + self, skip_elems: bool = None + ) -> Generator["QuamBase", None, None]: """Iterate over all QuamBase objects in this object, including nested objects. Args: @@ -473,12 +506,15 @@ def _get_referenced_value(self, reference: str) -> Any: if string_reference.is_absolute_reference(reference) and self._root is None: warnings.warn( - f"No QuamRoot initialized, cannot retrieve reference {reference}" f" from {self.__class__.__name__}" + f"No QuamRoot initialized, cannot retrieve reference {reference}" + f" from {self.__class__.__name__}" ) return reference try: - return string_reference.get_referenced_value(self, reference, root=self._root) + return string_reference.get_referenced_value( + self, reference, root=self._root + ) except ValueError as e: try: ref = f"{self.__class__.__name__}: {self.get_reference()}" @@ -546,6 +582,7 @@ class QuamRoot(QuamBase): def __post_init__(self): QuamBase._root = self + self._metadata: Dict[str, Dict[str, Any]] = {} super().__post_init__() def __setattr__(self, name, value): @@ -586,7 +623,9 @@ def save( ignore=ignore, ) - def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> Dict[str, Any]: + def to_dict( + self, follow_references: bool = False, include_defaults: bool = False + ) -> Dict[str, Any]: """Convert this object to a dictionary. Args: @@ -750,7 +789,9 @@ def __getitem__(self, i): repr = f"{self.__class__.__name__}: {self.get_reference()}" except Exception: repr = self.__class__.__name__ - raise KeyError(f"Could not get referenced value {elem} from {repr}") from e + raise KeyError( + f"Could not get referenced value {elem} from {repr}" + ) from e return elem # Overriding methods from UserDict @@ -774,7 +815,9 @@ def __repr__(self) -> str: def _get_attr_names(self): return list(self.data.keys()) - def get_attrs(self, follow_references=False, include_defaults=True) -> Dict[str, Any]: + def get_attrs( + self, follow_references=False, include_defaults=True + ) -> Dict[str, Any]: # TODO implement reference kwargs return self.data @@ -795,7 +838,9 @@ def get_attr_name(self, attr_val: Any) -> Union[str, int]: return attr_name else: raise AttributeError( - "Could not find name corresponding to attribute.\n" f"attribute: {attr_val}\n" f"obj: {self}" + "Could not find name corresponding to attribute.\n" + f"attribute: {attr_val}\n" + f"obj: {self}" ) def _val_matches_attr_annotation(self, attr: str, val: Any) -> bool: @@ -841,10 +886,13 @@ def get_unreferenced_value(self, attr: str) -> bool: return self.__dict__["data"][attr] except KeyError as e: raise AttributeError( - "Cannot get unreferenced value from attribute {attr} that does not" " exist in {self}" + "Cannot get unreferenced value from attribute {attr} that does not" + " exist in {self}" ) from e - def iterate_components(self, skip_elems: Sequence[QuamBase] = None) -> Generator["QuamBase", None, None]: + def iterate_components( + self, skip_elems: Sequence[QuamBase] = None + ) -> Generator["QuamBase", None, None]: """Iterate over all QuamBase objects in this object, including nested objects. Args: @@ -969,10 +1017,14 @@ def get_attr_name(self, attr_val: Any) -> str: return str(k) else: raise AttributeError( - "Could not find name corresponding to attribute" f"attribute: {attr_val}\n" f"obj: {self}" + "Could not find name corresponding to attribute" + f"attribute: {attr_val}\n" + f"obj: {self}" ) - def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> list: + def to_dict( + self, follow_references: bool = False, include_defaults: bool = False + ) -> list: """Convert this object to a list, usually as part of a dictionary representation. Args: @@ -1006,7 +1058,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals quam_list.append(val) return quam_list - def iterate_components(self, skip_elems: List[QuamBase] = None) -> Generator["QuamBase", None, None]: + def iterate_components( + self, skip_elems: List[QuamBase] = None + ) -> Generator["QuamBase", None, None]: """Iterate over all QuamBase objects in this object, including nested objects. Args: @@ -1027,7 +1081,9 @@ def iterate_components(self, skip_elems: List[QuamBase] = None) -> Generator["Qu if isinstance(attr_val, QuamBase): yield from attr_val.iterate_components(skip_elems=skip_elems) - def get_attrs(self, follow_references: bool = False, include_defaults: bool = True) -> Dict[str, Any]: + def get_attrs( + self, follow_references: bool = False, include_defaults: bool = True + ) -> Dict[str, Any]: raise NotImplementedError("QuamList does not have attributes") def print_summary(self, indent: int = 0):