Skip to content

Commit

Permalink
Merge pull request #35 from qua-platform/fix/quam_dataclass_default_v…
Browse files Browse the repository at this point in the history
…alue

Fix/quam dataclass default value
  • Loading branch information
nulinspiratie authored May 7, 2024
2 parents df832ff + 292f15c commit 8ffd218
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 13 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
## [Unreleased]
## [0.3.2]
### Added
- Added full QuAM documentation, including web hosting

### Fixed
- Fix error where a numpy array of integration weights raises an error
- Fix instantiation of a dictionary where the value is a reference
- Fix optional parameters of a quam component parent class were sometimes categorized as a required parameter (ReadoutPulse)

## [0.3.1]
### Added
Expand Down
28 changes: 16 additions & 12 deletions quam/utils/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class REQUIRED:


def get_dataclass_attr_annotations(
cls_or_obj: Union[type, object]
cls_or_obj: Union[type, object],
) -> Dict[str, Dict[str, type]]:
"""Get the attributes and annotations of a dataclass
Expand Down Expand Up @@ -80,13 +80,6 @@ def dataclass_field_has_default(field: dataclasses.field) -> bool:
return False


def dataclass_has_default_fields(cls) -> bool:
"""Check if dataclass has any default fields"""
fields = dataclasses.fields(cls)
fields_default = any(dataclass_field_has_default(field) for field in fields)
return fields_default


def handle_inherited_required_fields(cls):
"""Adds a default REQUIRED flag for dataclass fields when necessary
Expand All @@ -95,13 +88,24 @@ def handle_inherited_required_fields(cls):
if not is_dataclass(cls):
return

# Check if dataclass has default fields
fields_required = dataclass_has_default_fields(cls)
if not fields_required:
# Check if dataclass has fields with default value
optional_fields = [
field.name
for field in dataclasses.fields(cls)
if dataclass_field_has_default(field)
]
if not optional_fields:
# All fields of the dataclass are required, we don't have to handle situations
# where the parent class has fields with default values and the subclass has
# required fields.
return

# Check if class (not parents) has required fields
required_attrs = [attr for attr in cls.__annotations__ if attr not in cls.__dict__]
required_attrs = [
attr
for attr in cls.__annotations__
if attr not in cls.__dict__ and attr not in optional_fields
]
for attr in required_attrs:
setattr(cls, attr, REQUIRED)

Expand Down
9 changes: 9 additions & 0 deletions tests/components/pulses/test_pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from quam.core import *
from quam.components import *
from quam.components.channels import Channel, IQChannel, SingleChannel
from quam.utils.dataclass import get_dataclass_attr_annotations


def test_drag_pulse():
Expand Down Expand Up @@ -183,3 +184,11 @@ def test_pulses_referenced():
machine_loaded.channel.operations.get_unreferenced_value("pulse_referenced")
== "#./pulse"
)


def test_pulse_attr_annotations():
from quam.components import pulses

attr_annotations = get_dataclass_attr_annotations(pulses.SquareReadoutPulse)

assert list(attr_annotations["required"]) == ["length", "amplitude"]
18 changes: 18 additions & 0 deletions tests/quam_base/test_quam_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,21 @@ class C:
assert len(f) == 2
assert f[0].name == "int_val"
assert f[1].name == "int_val_optional"


def test_quam_dataclass_optional_field():
from quam.core import QuamComponent

@quam_dataclass
class RootClass(QuamComponent):
optional_root_attr: int = None

@quam_dataclass
class DerivedClass(RootClass):
pass

from quam.utils.dataclass import get_dataclass_attr_annotations

attr_annotations = get_dataclass_attr_annotations(DerivedClass)

assert list(attr_annotations["required"]) == []

0 comments on commit 8ffd218

Please sign in to comment.