diff --git a/CHANGELOG.md b/CHANGELOG.md index e95edbf7..99f46e16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,9 @@ - Added `Channel.frame_rotation_2pi` to allow for frame rotation in multiples of 2pi - Added `Channel.update_frequency` to allow for updating the frequency of a channel +### Changed +- Allow `QuamBase.get_reference(attr)` to return a reference of one of its attributes + ## [0.3.3] ### Added - Added the following parameters to `IQChannel`: `RF_frequency`, `LO_frequency`, `intermediate_frequency` diff --git a/quam/core/quam_classes.py b/quam/core/quam_classes.py index 20a55013..a0d65969 100644 --- a/quam/core/quam_classes.py +++ b/quam/core/quam_classes.py @@ -81,9 +81,7 @@ def convert_dict_and_list(value, cls_or_obj=None, attr=None): return value -def sort_quam_components( - components: List["QuamComponent"], max_attempts=5 -) -> List["QuamComponent"]: +def sort_quam_components(components: List["QuamComponent"], max_attempts=5) -> List["QuamComponent"]: """Sort QuamComponent objects based on their config_settings. Args: @@ -213,8 +211,7 @@ def __set__(self, instance, value): if "parent" in instance.__dict__ and instance.__dict__["parent"] is not value: cls = instance.__class__.__name__ raise AttributeError( - f"Cannot overwrite parent attribute of {cls}. " - f"To modify {cls}.parent, first set {cls}.parent = None" + f"Cannot overwrite parent attribute of {cls}. " f"To modify {cls}.parent, first set {cls}.parent = None" ) instance.__dict__["parent"] = value @@ -252,10 +249,7 @@ def __init__(self): "Please create a subclass and make it a dataclass." ) else: - raise TypeError( - f"Cannot instantiate {self.__class__.__name__}. " - "Please make it a dataclass." - ) + raise TypeError(f"Cannot instantiate {self.__class__.__name__}. " "Please make it a dataclass.") def _get_attr_names(self) -> List[str]: """Get names of all dataclass attributes of this object. @@ -286,9 +280,7 @@ def get_attr_name(self, attr_val: Any) -> str: return attr_name else: raise AttributeError( - "Could not find name corresponding to attribute.\n" - f"attribute: {attr_val}\n" - f"obj: {self}" + "Could not find name corresponding to attribute.\n" f"attribute: {attr_val}\n" f"obj: {self}" ) def _attr_val_is_default(self, attr: str, val: Any) -> bool: @@ -342,22 +334,25 @@ def _val_matches_attr_annotation(cls, attr: str, val: Any) -> bool: return isinstance(val, (list, QuamList)) return type(val) == required_type - def get_reference(self) -> Optional[str]: + def get_reference(self, attr=None) -> Optional[str]: """Get the reference path of this object. + Args: + attr: The attribute to get the reference path for. If None, the reference + path of the object itself is returned. + Returns: The reference path of this object. """ if self.parent is None: - raise AttributeError( - "Unable to extract reference path. Parent must be defined for {self}" - ) - return f"{self.parent.get_reference()}/{self.parent.get_attr_name(self)}" + raise AttributeError("Unable to extract reference path. Parent must be defined for {self}") + reference = f"{self.parent.get_reference()}/{self.parent.get_attr_name(self)}" + if attr is not None: + reference = f"{reference}/{attr}" + return reference - def get_attrs( - self, follow_references: bool = False, include_defaults: bool = True - ) -> Dict[str, Any]: + def get_attrs(self, follow_references: bool = False, include_defaults: bool = True) -> Dict[str, Any]: """Get all attributes and corresponding values of this object. Args: @@ -381,16 +376,10 @@ def get_attrs( attrs = {attr: getattr(self, attr) for attr in attr_names} if not include_defaults: - attrs = { - attr: val - for attr, val in attrs.items() - if not self._attr_val_is_default(attr, val) - } + attrs = {attr: val for attr, val in attrs.items() if not self._attr_val_is_default(attr, val)} return attrs - def to_dict( - self, follow_references: bool = False, include_defaults: bool = False - ) -> Dict[str, Any]: + def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> Dict[str, Any]: """Convert this object to a dictionary. Args: @@ -407,9 +396,7 @@ def to_dict( `"__class__"` key will be added to the dictionary. This is to ensure that the object can be reconstructed when loading from a file. """ - attrs = self.get_attrs( - follow_references=follow_references, include_defaults=include_defaults - ) + attrs = self.get_attrs(follow_references=follow_references, include_defaults=include_defaults) quam_dict = {} for attr, val in attrs.items(): if isinstance(val, QuamBase): @@ -424,9 +411,7 @@ def to_dict( quam_dict[attr] = val return quam_dict - def iterate_components( - self, skip_elems: bool = None - ) -> Generator["QuamBase", None, None]: + def iterate_components(self, skip_elems: bool = None) -> Generator["QuamBase", None, None]: """Iterate over all QuamBase objects in this object, including nested objects. Args: @@ -488,15 +473,12 @@ def _get_referenced_value(self, reference: str) -> Any: if string_reference.is_absolute_reference(reference) and self._root is None: warnings.warn( - f"No QuamRoot initialized, cannot retrieve reference {reference}" - f" from {self.__class__.__name__}" + f"No QuamRoot initialized, cannot retrieve reference {reference}" f" from {self.__class__.__name__}" ) return reference try: - return string_reference.get_referenced_value( - self, reference, root=self._root - ) + return string_reference.get_referenced_value(self, reference, root=self._root) except ValueError as e: try: ref = f"{self.__class__.__name__}: {self.get_reference()}" @@ -604,9 +586,7 @@ def save( ignore=ignore, ) - def to_dict( - self, follow_references: bool = False, include_defaults: bool = False - ) -> Dict[str, Any]: + def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> Dict[str, Any]: """Convert this object to a dictionary. Args: @@ -770,9 +750,7 @@ def __getitem__(self, i): repr = f"{self.__class__.__name__}: {self.get_reference()}" except Exception: repr = self.__class__.__name__ - raise KeyError( - f"Could not get referenced value {elem} from {repr}" - ) from e + raise KeyError(f"Could not get referenced value {elem} from {repr}") from e return elem # Overriding methods from UserDict @@ -796,9 +774,7 @@ def __repr__(self) -> str: def _get_attr_names(self): return list(self.data.keys()) - def get_attrs( - self, follow_references=False, include_defaults=True - ) -> Dict[str, Any]: + def get_attrs(self, follow_references=False, include_defaults=True) -> Dict[str, Any]: # TODO implement reference kwargs return self.data @@ -819,9 +795,7 @@ def get_attr_name(self, attr_val: Any) -> Union[str, int]: return attr_name else: raise AttributeError( - "Could not find name corresponding to attribute.\n" - f"attribute: {attr_val}\n" - f"obj: {self}" + "Could not find name corresponding to attribute.\n" f"attribute: {attr_val}\n" f"obj: {self}" ) def _val_matches_attr_annotation(self, attr: str, val: Any) -> bool: @@ -867,13 +841,10 @@ def get_unreferenced_value(self, attr: str) -> bool: return self.__dict__["data"][attr] except KeyError as e: raise AttributeError( - "Cannot get unreferenced value from attribute {attr} that does not" - " exist in {self}" + "Cannot get unreferenced value from attribute {attr} that does not" " exist in {self}" ) from e - def iterate_components( - self, skip_elems: Sequence[QuamBase] = None - ) -> Generator["QuamBase", None, None]: + def iterate_components(self, skip_elems: Sequence[QuamBase] = None) -> Generator["QuamBase", None, None]: """Iterate over all QuamBase objects in this object, including nested objects. Args: @@ -998,14 +969,10 @@ def get_attr_name(self, attr_val: Any) -> str: return str(k) else: raise AttributeError( - "Could not find name corresponding to attribute" - f"attribute: {attr_val}\n" - f"obj: {self}" + "Could not find name corresponding to attribute" f"attribute: {attr_val}\n" f"obj: {self}" ) - def to_dict( - self, follow_references: bool = False, include_defaults: bool = False - ) -> list: + def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> list: """Convert this object to a list, usually as part of a dictionary representation. Args: @@ -1039,9 +1006,7 @@ def to_dict( quam_list.append(val) return quam_list - def iterate_components( - self, skip_elems: List[QuamBase] = None - ) -> Generator["QuamBase", None, None]: + def iterate_components(self, skip_elems: List[QuamBase] = None) -> Generator["QuamBase", None, None]: """Iterate over all QuamBase objects in this object, including nested objects. Args: @@ -1062,9 +1027,7 @@ def iterate_components( if isinstance(attr_val, QuamBase): yield from attr_val.iterate_components(skip_elems=skip_elems) - def get_attrs( - self, follow_references: bool = False, include_defaults: bool = True - ) -> Dict[str, Any]: + def get_attrs(self, follow_references: bool = False, include_defaults: bool = True) -> Dict[str, Any]: raise NotImplementedError("QuamList does not have attributes") def print_summary(self, indent: int = 0): diff --git a/tests/quam_base/referencing/test_reference_path.py b/tests/quam_base/referencing/test_reference_path.py index 228ba85f..af6f0a3e 100644 --- a/tests/quam_base/referencing/test_reference_path.py +++ b/tests/quam_base/referencing/test_reference_path.py @@ -52,3 +52,10 @@ def test_quam_list_reference(): assert ( root.quam_elem_list[1].test_str.get_reference() == "#/quam_elem_list/1/test_str" ) + + +def test_get_reference_attr(): + component = QuamComponentTest(test_str="hi") + root = QuamRootTest(quam_elem=component, quam_elem_list=[]) + assert component.get_reference() == "#/quam_elem" + assert component.get_reference("test_str") == "#/quam_elem/test_str"