diff --git a/CHANGELOG.md b/CHANGELOG.md index dc3496ae..12186a34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,10 @@ - Allow `QuamBase.get_reference(attr)` to return a reference of one of its attributes ### Fixed +- Fix quam object instantiation error when a parameter type uses pipe operator - Allow int keys to be serialised / loaded in QuAM using JSONSerialiser + ## [0.3.3] ### Added - Added the following parameters to `IQChannel`: `RF_frequency`, `LO_frequency`, `intermediate_frequency` diff --git a/quam/core/quam_instantiation.py b/quam/core/quam_instantiation.py index 57680cf4..e1bd0720 100644 --- a/quam/core/quam_instantiation.py +++ b/quam/core/quam_instantiation.py @@ -1,4 +1,6 @@ from __future__ import annotations +import sys +import types import typing from typing import TYPE_CHECKING, Dict, Any from inspect import isclass @@ -16,6 +18,12 @@ from quam.core import QuamBase +if sys.version_info < (3, 10): + union_types = (typing.Union,) +else: + union_types = [typing.Union, types.UnionType] + + def instantiate_attrs_from_dict( attr_dict: dict, required_type: type, @@ -224,7 +232,7 @@ def instantiate_attr( ) if typing.get_origin(expected_type) == dict: expected_type = dict - elif typing.get_origin(expected_type) == typing.Union: + elif typing.get_origin(expected_type) in union_types: for union_type in typing.get_args(expected_type): try: instantiated_attr = instantiate_attr( diff --git a/tests/instantiation/test_instantiation.py b/tests/instantiation/test_instantiation.py index 7aa293ff..f6c0da63 100644 --- a/tests/instantiation/test_instantiation.py +++ b/tests/instantiation/test_instantiation.py @@ -1,6 +1,8 @@ import pytest from typing import List, Literal, Optional, Tuple, Union +from pytest_cov.engine import sys + from quam.core import QuamRoot, QuamComponent, quam_dataclass from quam.core.quam_classes import QuamDict from quam.examples.superconducting_qubits.components import Transmon @@ -339,6 +341,7 @@ def test_instantiate_dict_referenced(): assert attrs == {"test_attr": "#./reference"} + @quam_dataclass class TestQuamComponent(QuamComponent): a: int @@ -357,3 +360,19 @@ class TestQuamUnion(QuamComponent): with pytest.raises(TypeError): instantiate_quam_class(TestQuamUnion, {"union_val": {"a": "42"}}) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") +def test_instantiation_pipe_union_type(): + @quam_dataclass + class TestQuamUnion(QuamComponent): + union_val: int | TestQuamComponent + + obj = instantiate_quam_class(TestQuamUnion, {"union_val": 42}) + assert obj.union_val == 42 + + obj = instantiate_quam_class(TestQuamUnion, {"union_val": {"a": 42}}) + assert obj.union_val.a == 42 + + with pytest.raises(TypeError): + instantiate_quam_class(TestQuamUnion, {"union_val": {"a": "42"}})