diff --git a/CHANGELOG.md b/CHANGELOG.md index 261d2ee6..49dd78c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,11 @@ -## [Unreleased] +## [0.3.2] +### Added +- Added full QuAM documentation, including web hosting + ### Fixed - Fix error where a numpy array of integration weights raises an error - Fix instantiation of a dictionary where the value is a reference +- Fix optional parameters of a quam component parent class were sometimes categorized as a required parameter (ReadoutPulse) ## [0.3.1] ### Added diff --git a/quam/utils/dataclass.py b/quam/utils/dataclass.py index 37f29553..2fe954a9 100644 --- a/quam/utils/dataclass.py +++ b/quam/utils/dataclass.py @@ -16,7 +16,7 @@ class REQUIRED: def get_dataclass_attr_annotations( - cls_or_obj: Union[type, object] + cls_or_obj: Union[type, object], ) -> Dict[str, Dict[str, type]]: """Get the attributes and annotations of a dataclass @@ -80,13 +80,6 @@ def dataclass_field_has_default(field: dataclasses.field) -> bool: return False -def dataclass_has_default_fields(cls) -> bool: - """Check if dataclass has any default fields""" - fields = dataclasses.fields(cls) - fields_default = any(dataclass_field_has_default(field) for field in fields) - return fields_default - - def handle_inherited_required_fields(cls): """Adds a default REQUIRED flag for dataclass fields when necessary @@ -95,13 +88,24 @@ def handle_inherited_required_fields(cls): if not is_dataclass(cls): return - # Check if dataclass has default fields - fields_required = dataclass_has_default_fields(cls) - if not fields_required: + # Check if dataclass has fields with default value + optional_fields = [ + field.name + for field in dataclasses.fields(cls) + if dataclass_field_has_default(field) + ] + if not optional_fields: + # All fields of the dataclass are required, we don't have to handle situations + # where the parent class has fields with default values and the subclass has + # required fields. return # Check if class (not parents) has required fields - required_attrs = [attr for attr in cls.__annotations__ if attr not in cls.__dict__] + required_attrs = [ + attr + for attr in cls.__annotations__ + if attr not in cls.__dict__ and attr not in optional_fields + ] for attr in required_attrs: setattr(cls, attr, REQUIRED) diff --git a/tests/components/pulses/test_pulses.py b/tests/components/pulses/test_pulses.py index b0ba6319..b525f1e3 100644 --- a/tests/components/pulses/test_pulses.py +++ b/tests/components/pulses/test_pulses.py @@ -4,6 +4,7 @@ from quam.core import * from quam.components import * from quam.components.channels import Channel, IQChannel, SingleChannel +from quam.utils.dataclass import get_dataclass_attr_annotations def test_drag_pulse(): @@ -183,3 +184,11 @@ def test_pulses_referenced(): machine_loaded.channel.operations.get_unreferenced_value("pulse_referenced") == "#./pulse" ) + + +def test_pulse_attr_annotations(): + from quam.components import pulses + + attr_annotations = get_dataclass_attr_annotations(pulses.SquareReadoutPulse) + + assert list(attr_annotations["required"]) == ["length", "amplitude"] diff --git a/tests/quam_base/test_quam_dataclass.py b/tests/quam_base/test_quam_dataclass.py index 756975f6..ec3a2d61 100644 --- a/tests/quam_base/test_quam_dataclass.py +++ b/tests/quam_base/test_quam_dataclass.py @@ -207,3 +207,21 @@ class C: assert len(f) == 2 assert f[0].name == "int_val" assert f[1].name == "int_val_optional" + + +def test_quam_dataclass_optional_field(): + from quam.core import QuamComponent + + @quam_dataclass + class RootClass(QuamComponent): + optional_root_attr: int = None + + @quam_dataclass + class DerivedClass(RootClass): + pass + + from quam.utils.dataclass import get_dataclass_attr_annotations + + attr_annotations = get_dataclass_attr_annotations(DerivedClass) + + assert list(attr_annotations["required"]) == []