Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/Register state updates #57

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 79 additions & 23 deletions quam/core/quam_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
generate_config_final_actions,
)
from quam.core.quam_instantiation import instantiate_quam_class
from quam.utils.state_tracker import StateTracker
from .qua_config_template import qua_config_template


Expand Down Expand Up @@ -81,7 +82,9 @@ def convert_dict_and_list(value, cls_or_obj=None, attr=None):
return value


def sort_quam_components(components: List["QuamComponent"], max_attempts=5) -> List["QuamComponent"]:
def sort_quam_components(
components: List["QuamComponent"], max_attempts=5
) -> List["QuamComponent"]:
"""Sort QuamComponent objects based on their config_settings.

Args:
Expand Down Expand Up @@ -211,7 +214,8 @@ def __set__(self, instance, value):
if "parent" in instance.__dict__ and instance.__dict__["parent"] is not value:
cls = instance.__class__.__name__
raise AttributeError(
f"Cannot overwrite parent attribute of {cls}. " f"To modify {cls}.parent, first set {cls}.parent = None"
f"Cannot overwrite parent attribute of {cls}. "
f"To modify {cls}.parent, first set {cls}.parent = None"
)
instance.__dict__["parent"] = value

Expand Down Expand Up @@ -249,7 +253,10 @@ def __init__(self):
"Please create a subclass and make it a dataclass."
)
else:
raise TypeError(f"Cannot instantiate {self.__class__.__name__}. " "Please make it a dataclass.")
raise TypeError(
f"Cannot instantiate {self.__class__.__name__}. "
"Please make it a dataclass."
)

def _get_attr_names(self) -> List[str]:
"""Get names of all dataclass attributes of this object.
Expand Down Expand Up @@ -280,7 +287,9 @@ def get_attr_name(self, attr_val: Any) -> str:
return attr_name
else:
raise AttributeError(
"Could not find name corresponding to attribute.\n" f"attribute: {attr_val}\n" f"obj: {self}"
"Could not find name corresponding to attribute.\n"
f"attribute: {attr_val}\n"
f"obj: {self}"
)

def _attr_val_is_default(self, attr: str, val: Any) -> bool:
Expand Down Expand Up @@ -346,13 +355,17 @@ def get_reference(self, attr=None) -> Optional[str]:
"""

if self.parent is None:
raise AttributeError("Unable to extract reference path. Parent must be defined for {self}")
raise AttributeError(
"Unable to extract reference path. Parent must be defined for {self}"
)
reference = f"{self.parent.get_reference()}/{self.parent.get_attr_name(self)}"
if attr is not None:
reference = f"{reference}/{attr}"
return reference

def get_attrs(self, follow_references: bool = False, include_defaults: bool = True) -> Dict[str, Any]:
def get_attrs(
self, follow_references: bool = False, include_defaults: bool = True
) -> Dict[str, Any]:
"""Get all attributes and corresponding values of this object.

Args:
Expand All @@ -376,10 +389,16 @@ def get_attrs(self, follow_references: bool = False, include_defaults: bool = Tr
attrs = {attr: getattr(self, attr) for attr in attr_names}

if not include_defaults:
attrs = {attr: val for attr, val in attrs.items() if not self._attr_val_is_default(attr, val)}
attrs = {
attr: val
for attr, val in attrs.items()
if not self._attr_val_is_default(attr, val)
}
return attrs

def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> Dict[str, Any]:
def to_dict(
self, follow_references: bool = False, include_defaults: bool = False
) -> Dict[str, Any]:
"""Convert this object to a dictionary.

Args:
Expand All @@ -396,7 +415,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals
`"__class__"` key will be added to the dictionary. This is to ensure
that the object can be reconstructed when loading from a file.
"""
attrs = self.get_attrs(follow_references=follow_references, include_defaults=include_defaults)
attrs = self.get_attrs(
follow_references=follow_references, include_defaults=include_defaults
)
quam_dict = {}
for attr, val in attrs.items():
if isinstance(val, QuamBase):
Expand All @@ -411,7 +432,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals
quam_dict[attr] = val
return quam_dict

def iterate_components(self, skip_elems: bool = None) -> Generator["QuamBase", None, None]:
def iterate_components(
self, skip_elems: bool = None
) -> Generator["QuamBase", None, None]:
"""Iterate over all QuamBase objects in this object, including nested objects.

Args:
Expand Down Expand Up @@ -473,12 +496,15 @@ def _get_referenced_value(self, reference: str) -> Any:

if string_reference.is_absolute_reference(reference) and self._root is None:
warnings.warn(
f"No QuamRoot initialized, cannot retrieve reference {reference}" f" from {self.__class__.__name__}"
f"No QuamRoot initialized, cannot retrieve reference {reference}"
f" from {self.__class__.__name__}"
)
return reference

try:
return string_reference.get_referenced_value(self, reference, root=self._root)
return string_reference.get_referenced_value(
self, reference, root=self._root
)
except ValueError as e:
try:
ref = f"{self.__class__.__name__}: {self.get_reference()}"
Expand Down Expand Up @@ -546,6 +572,7 @@ class QuamRoot(QuamBase):

def __post_init__(self):
QuamBase._root = self
self._state_tracker = StateTracker()
super().__post_init__()

def __setattr__(self, name, value):
Expand Down Expand Up @@ -585,8 +612,11 @@ def save(
include_defaults=include_defaults,
ignore=ignore,
)
self._state_tracker.update_state(serialiser.contents)

def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> Dict[str, Any]:
def to_dict(
self, follow_references: bool = False, include_defaults: bool = False
) -> Dict[str, Any]:
"""Convert this object to a dictionary.

Args:
Expand Down Expand Up @@ -626,13 +656,16 @@ def load(
serialiser = cls.serialiser()
contents, _ = serialiser.load(filepath_or_dict)

return instantiate_quam_class(
quam_obj = instantiate_quam_class(
quam_class=cls,
contents=contents,
fix_attrs=fix_attrs,
validate_type=validate_type,
)

quam_obj._state_tracker.update_.state = contents
return quam_obj

def generate_config(self) -> Dict[str, Any]:
"""Generate the QUA configuration from the QuAM object.

Expand All @@ -658,6 +691,9 @@ def generate_config(self) -> Dict[str, Any]:
def get_unreferenced_value(self, attr: str):
return getattr(self, attr)

def print_state_changes(self):
self._state_tracker.print_state_changes(self.to_dict(), mode="save")


class QuamComponent(QuamBase):
"""Base class for any QuAM component class.
Expand Down Expand Up @@ -693,6 +729,9 @@ def apply_to_config(self, config: dict) -> None:
"""
...

# def print_state_changes(self):
# self._root.print_state_changes(self.to_dict(), mode="update")


@quam_dataclass
class QuamDict(UserDict, QuamBase):
Expand Down Expand Up @@ -750,7 +789,9 @@ def __getitem__(self, i):
repr = f"{self.__class__.__name__}: {self.get_reference()}"
except Exception:
repr = self.__class__.__name__
raise KeyError(f"Could not get referenced value {elem} from {repr}") from e
raise KeyError(
f"Could not get referenced value {elem} from {repr}"
) from e
return elem

# Overriding methods from UserDict
Expand All @@ -774,7 +815,9 @@ def __repr__(self) -> str:
def _get_attr_names(self):
return list(self.data.keys())

def get_attrs(self, follow_references=False, include_defaults=True) -> Dict[str, Any]:
def get_attrs(
self, follow_references=False, include_defaults=True
) -> Dict[str, Any]:
# TODO implement reference kwargs
return self.data

Expand All @@ -795,7 +838,9 @@ def get_attr_name(self, attr_val: Any) -> Union[str, int]:
return attr_name
else:
raise AttributeError(
"Could not find name corresponding to attribute.\n" f"attribute: {attr_val}\n" f"obj: {self}"
"Could not find name corresponding to attribute.\n"
f"attribute: {attr_val}\n"
f"obj: {self}"
)

def _val_matches_attr_annotation(self, attr: str, val: Any) -> bool:
Expand Down Expand Up @@ -841,10 +886,13 @@ def get_unreferenced_value(self, attr: str) -> bool:
return self.__dict__["data"][attr]
except KeyError as e:
raise AttributeError(
"Cannot get unreferenced value from attribute {attr} that does not" " exist in {self}"
"Cannot get unreferenced value from attribute {attr} that does not"
" exist in {self}"
) from e

def iterate_components(self, skip_elems: Sequence[QuamBase] = None) -> Generator["QuamBase", None, None]:
def iterate_components(
self, skip_elems: Sequence[QuamBase] = None
) -> Generator["QuamBase", None, None]:
"""Iterate over all QuamBase objects in this object, including nested objects.

Args:
Expand Down Expand Up @@ -969,10 +1017,14 @@ def get_attr_name(self, attr_val: Any) -> str:
return str(k)
else:
raise AttributeError(
"Could not find name corresponding to attribute" f"attribute: {attr_val}\n" f"obj: {self}"
"Could not find name corresponding to attribute"
f"attribute: {attr_val}\n"
f"obj: {self}"
)

def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> list:
def to_dict(
self, follow_references: bool = False, include_defaults: bool = False
) -> list:
"""Convert this object to a list, usually as part of a dictionary representation.

Args:
Expand Down Expand Up @@ -1006,7 +1058,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals
quam_list.append(val)
return quam_list

def iterate_components(self, skip_elems: List[QuamBase] = None) -> Generator["QuamBase", None, None]:
def iterate_components(
self, skip_elems: List[QuamBase] = None
) -> Generator["QuamBase", None, None]:
"""Iterate over all QuamBase objects in this object, including nested objects.

Args:
Expand All @@ -1027,7 +1081,9 @@ def iterate_components(self, skip_elems: List[QuamBase] = None) -> Generator["Qu
if isinstance(attr_val, QuamBase):
yield from attr_val.iterate_components(skip_elems=skip_elems)

def get_attrs(self, follow_references: bool = False, include_defaults: bool = True) -> Dict[str, Any]:
def get_attrs(
self, follow_references: bool = False, include_defaults: bool = True
) -> Dict[str, Any]:
raise NotImplementedError("QuamList does not have attributes")

def print_summary(self, indent: int = 0):
Expand Down
4 changes: 4 additions & 0 deletions quam/serialisation/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class JSONSerialiser(AbstractSerialiser):
default_filename = "state.json"
default_foldername = "quam"
content_mapping = {}
contents = {}

def _save_dict_to_json(self, contents: Dict[str, Any], path: Path):
"""Save a dictionary to a JSON file.
Expand Down Expand Up @@ -104,6 +105,7 @@ def save(
content_mapping = content_mapping.copy()

contents = quam_obj.to_dict(include_defaults=include_defaults)
self.contents = contents.copy()

# TODO This should ideally go to the QuamRoot.to_dict method
for key in ignore or []:
Expand Down Expand Up @@ -182,6 +184,8 @@ def load(
else:
metadata["content_mapping"][file.name] = list(file_contents.keys())

self.contents = contents

return contents, metadata


Expand Down
Loading
Loading