Skip to content

Commit

Permalink
Feat: Add metadata to QuAM
Browse files Browse the repository at this point in the history
  • Loading branch information
nulinspiratie committed Jul 11, 2024
1 parent 2c79eaa commit 2293569
Showing 1 changed file with 78 additions and 22 deletions.
100 changes: 78 additions & 22 deletions quam/core/quam_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def convert_dict_and_list(value, cls_or_obj=None, attr=None):
return value


def sort_quam_components(components: List["QuamComponent"], max_attempts=5) -> List["QuamComponent"]:
def sort_quam_components(
components: List["QuamComponent"], max_attempts=5
) -> List["QuamComponent"]:
"""Sort QuamComponent objects based on their config_settings.
Args:
Expand Down Expand Up @@ -211,7 +213,8 @@ def __set__(self, instance, value):
if "parent" in instance.__dict__ and instance.__dict__["parent"] is not value:
cls = instance.__class__.__name__
raise AttributeError(
f"Cannot overwrite parent attribute of {cls}. " f"To modify {cls}.parent, first set {cls}.parent = None"
f"Cannot overwrite parent attribute of {cls}. "
f"To modify {cls}.parent, first set {cls}.parent = None"
)
instance.__dict__["parent"] = value

Expand Down Expand Up @@ -249,7 +252,10 @@ def __init__(self):
"Please create a subclass and make it a dataclass."
)
else:
raise TypeError(f"Cannot instantiate {self.__class__.__name__}. " "Please make it a dataclass.")
raise TypeError(
f"Cannot instantiate {self.__class__.__name__}. "
"Please make it a dataclass."
)

def _get_attr_names(self) -> List[str]:
"""Get names of all dataclass attributes of this object.
Expand Down Expand Up @@ -280,7 +286,9 @@ def get_attr_name(self, attr_val: Any) -> str:
return attr_name
else:
raise AttributeError(
"Could not find name corresponding to attribute.\n" f"attribute: {attr_val}\n" f"obj: {self}"
"Could not find name corresponding to attribute.\n"
f"attribute: {attr_val}\n"
f"obj: {self}"
)

def _attr_val_is_default(self, attr: str, val: Any) -> bool:
Expand Down Expand Up @@ -334,6 +342,17 @@ def _val_matches_attr_annotation(cls, attr: str, val: Any) -> bool:
return isinstance(val, (list, QuamList))
return type(val) == required_type

def get_metadata(self, attr=None):
if isinstance(self, QuamRoot):
return self._metadata.setdefault("#/", {})

reference = self.get_reference(attr=attr)
if reference is None:
raise AttributeError(
"Unable to extract reference path. Parent must be defined for {self}"
)
return self._root._metadata.setdefault(reference, {})

def get_reference(self, attr=None) -> Optional[str]:
"""Get the reference path of this object.
Expand All @@ -346,13 +365,17 @@ def get_reference(self, attr=None) -> Optional[str]:
"""

if self.parent is None:
raise AttributeError("Unable to extract reference path. Parent must be defined for {self}")
raise AttributeError(
"Unable to extract reference path. Parent must be defined for {self}"
)
reference = f"{self.parent.get_reference()}/{self.parent.get_attr_name(self)}"
if attr is not None:
reference = f"{reference}/{attr}"
return reference

def get_attrs(self, follow_references: bool = False, include_defaults: bool = True) -> Dict[str, Any]:
def get_attrs(
self, follow_references: bool = False, include_defaults: bool = True
) -> Dict[str, Any]:
"""Get all attributes and corresponding values of this object.
Args:
Expand All @@ -376,10 +399,16 @@ def get_attrs(self, follow_references: bool = False, include_defaults: bool = Tr
attrs = {attr: getattr(self, attr) for attr in attr_names}

if not include_defaults:
attrs = {attr: val for attr, val in attrs.items() if not self._attr_val_is_default(attr, val)}
attrs = {
attr: val
for attr, val in attrs.items()
if not self._attr_val_is_default(attr, val)
}
return attrs

def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> Dict[str, Any]:
def to_dict(
self, follow_references: bool = False, include_defaults: bool = False
) -> Dict[str, Any]:
"""Convert this object to a dictionary.
Args:
Expand All @@ -396,7 +425,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals
`"__class__"` key will be added to the dictionary. This is to ensure
that the object can be reconstructed when loading from a file.
"""
attrs = self.get_attrs(follow_references=follow_references, include_defaults=include_defaults)
attrs = self.get_attrs(
follow_references=follow_references, include_defaults=include_defaults
)
quam_dict = {}
for attr, val in attrs.items():
if isinstance(val, QuamBase):
Expand All @@ -411,7 +442,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals
quam_dict[attr] = val
return quam_dict

def iterate_components(self, skip_elems: bool = None) -> Generator["QuamBase", None, None]:
def iterate_components(
self, skip_elems: bool = None
) -> Generator["QuamBase", None, None]:
"""Iterate over all QuamBase objects in this object, including nested objects.
Args:
Expand Down Expand Up @@ -473,12 +506,15 @@ def _get_referenced_value(self, reference: str) -> Any:

if string_reference.is_absolute_reference(reference) and self._root is None:
warnings.warn(
f"No QuamRoot initialized, cannot retrieve reference {reference}" f" from {self.__class__.__name__}"
f"No QuamRoot initialized, cannot retrieve reference {reference}"
f" from {self.__class__.__name__}"
)
return reference

try:
return string_reference.get_referenced_value(self, reference, root=self._root)
return string_reference.get_referenced_value(
self, reference, root=self._root
)
except ValueError as e:
try:
ref = f"{self.__class__.__name__}: {self.get_reference()}"
Expand Down Expand Up @@ -546,6 +582,7 @@ class QuamRoot(QuamBase):

def __post_init__(self):
QuamBase._root = self
self._metadata: Dict[str, Dict[str, Any]] = {}
super().__post_init__()

def __setattr__(self, name, value):
Expand Down Expand Up @@ -586,7 +623,9 @@ def save(
ignore=ignore,
)

def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> Dict[str, Any]:
def to_dict(
self, follow_references: bool = False, include_defaults: bool = False
) -> Dict[str, Any]:
"""Convert this object to a dictionary.
Args:
Expand Down Expand Up @@ -750,7 +789,9 @@ def __getitem__(self, i):
repr = f"{self.__class__.__name__}: {self.get_reference()}"
except Exception:
repr = self.__class__.__name__
raise KeyError(f"Could not get referenced value {elem} from {repr}") from e
raise KeyError(
f"Could not get referenced value {elem} from {repr}"
) from e
return elem

# Overriding methods from UserDict
Expand All @@ -774,7 +815,9 @@ def __repr__(self) -> str:
def _get_attr_names(self):
return list(self.data.keys())

def get_attrs(self, follow_references=False, include_defaults=True) -> Dict[str, Any]:
def get_attrs(
self, follow_references=False, include_defaults=True
) -> Dict[str, Any]:
# TODO implement reference kwargs
return self.data

Expand All @@ -795,7 +838,9 @@ def get_attr_name(self, attr_val: Any) -> Union[str, int]:
return attr_name
else:
raise AttributeError(
"Could not find name corresponding to attribute.\n" f"attribute: {attr_val}\n" f"obj: {self}"
"Could not find name corresponding to attribute.\n"
f"attribute: {attr_val}\n"
f"obj: {self}"
)

def _val_matches_attr_annotation(self, attr: str, val: Any) -> bool:
Expand Down Expand Up @@ -841,10 +886,13 @@ def get_unreferenced_value(self, attr: str) -> bool:
return self.__dict__["data"][attr]
except KeyError as e:
raise AttributeError(
"Cannot get unreferenced value from attribute {attr} that does not" " exist in {self}"
"Cannot get unreferenced value from attribute {attr} that does not"
" exist in {self}"
) from e

def iterate_components(self, skip_elems: Sequence[QuamBase] = None) -> Generator["QuamBase", None, None]:
def iterate_components(
self, skip_elems: Sequence[QuamBase] = None
) -> Generator["QuamBase", None, None]:
"""Iterate over all QuamBase objects in this object, including nested objects.
Args:
Expand Down Expand Up @@ -969,10 +1017,14 @@ def get_attr_name(self, attr_val: Any) -> str:
return str(k)
else:
raise AttributeError(
"Could not find name corresponding to attribute" f"attribute: {attr_val}\n" f"obj: {self}"
"Could not find name corresponding to attribute"
f"attribute: {attr_val}\n"
f"obj: {self}"
)

def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> list:
def to_dict(
self, follow_references: bool = False, include_defaults: bool = False
) -> list:
"""Convert this object to a list, usually as part of a dictionary representation.
Args:
Expand Down Expand Up @@ -1006,7 +1058,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals
quam_list.append(val)
return quam_list

def iterate_components(self, skip_elems: List[QuamBase] = None) -> Generator["QuamBase", None, None]:
def iterate_components(
self, skip_elems: List[QuamBase] = None
) -> Generator["QuamBase", None, None]:
"""Iterate over all QuamBase objects in this object, including nested objects.
Args:
Expand All @@ -1027,7 +1081,9 @@ def iterate_components(self, skip_elems: List[QuamBase] = None) -> Generator["Qu
if isinstance(attr_val, QuamBase):
yield from attr_val.iterate_components(skip_elems=skip_elems)

def get_attrs(self, follow_references: bool = False, include_defaults: bool = True) -> Dict[str, Any]:
def get_attrs(
self, follow_references: bool = False, include_defaults: bool = True
) -> Dict[str, Any]:
raise NotImplementedError("QuamList does not have attributes")

def print_summary(self, indent: int = 0):
Expand Down

0 comments on commit 2293569

Please sign in to comment.