Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/quam dataclass default value #35

Merged
merged 2 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
## [Unreleased]
## [0.3.2]
### Added
- Added full QuAM documentation, including web hosting

### Fixed
- Fix error where a numpy array of integration weights raises an error
- Fix instantiation of a dictionary where the value is a reference
- Fix optional parameters of a quam component parent class were sometimes categorized as a required parameter (ReadoutPulse)

## [0.3.1]
### Added
Expand Down
28 changes: 16 additions & 12 deletions quam/utils/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class REQUIRED:


def get_dataclass_attr_annotations(
cls_or_obj: Union[type, object]
cls_or_obj: Union[type, object],
) -> Dict[str, Dict[str, type]]:
"""Get the attributes and annotations of a dataclass

Expand Down Expand Up @@ -80,13 +80,6 @@ def dataclass_field_has_default(field: dataclasses.field) -> bool:
return False


def dataclass_has_default_fields(cls) -> bool:
"""Check if dataclass has any default fields"""
fields = dataclasses.fields(cls)
fields_default = any(dataclass_field_has_default(field) for field in fields)
return fields_default


def handle_inherited_required_fields(cls):
"""Adds a default REQUIRED flag for dataclass fields when necessary

Expand All @@ -95,13 +88,24 @@ def handle_inherited_required_fields(cls):
if not is_dataclass(cls):
return

# Check if dataclass has default fields
fields_required = dataclass_has_default_fields(cls)
if not fields_required:
# Check if dataclass has fields with default value
optional_fields = [
field.name
for field in dataclasses.fields(cls)
if dataclass_field_has_default(field)
]
if not optional_fields:
# All fields of the dataclass are required, we don't have to handle situations
# where the parent class has fields with default values and the subclass has
# required fields.
return

# Check if class (not parents) has required fields
required_attrs = [attr for attr in cls.__annotations__ if attr not in cls.__dict__]
required_attrs = [
attr
for attr in cls.__annotations__
if attr not in cls.__dict__ and attr not in optional_fields
]
for attr in required_attrs:
setattr(cls, attr, REQUIRED)

Expand Down
9 changes: 9 additions & 0 deletions tests/components/pulses/test_pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from quam.core import *
from quam.components import *
from quam.components.channels import Channel, IQChannel, SingleChannel
from quam.utils.dataclass import get_dataclass_attr_annotations


def test_drag_pulse():
Expand Down Expand Up @@ -183,3 +184,11 @@ def test_pulses_referenced():
machine_loaded.channel.operations.get_unreferenced_value("pulse_referenced")
== "#./pulse"
)


def test_pulse_attr_annotations():
from quam.components import pulses

attr_annotations = get_dataclass_attr_annotations(pulses.SquareReadoutPulse)

assert list(attr_annotations["required"]) == ["length", "amplitude"]
18 changes: 18 additions & 0 deletions tests/quam_base/test_quam_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,21 @@ class C:
assert len(f) == 2
assert f[0].name == "int_val"
assert f[1].name == "int_val_optional"


def test_quam_dataclass_optional_field():
from quam.core import QuamComponent

@quam_dataclass
class RootClass(QuamComponent):
optional_root_attr: int = None

@quam_dataclass
class DerivedClass(RootClass):
pass

from quam.utils.dataclass import get_dataclass_attr_annotations

attr_annotations = get_dataclass_attr_annotations(DerivedClass)

assert list(attr_annotations["required"]) == []