Skip to content

Commit

Permalink
Merge pull request #47 from qua-platform/fix/instantiate-union-pipe
Browse files Browse the repository at this point in the history
feat: Add support for instantiating union types
  • Loading branch information
nulinspiratie authored Jul 3, 2024
2 parents 249dc2d + 9e3b0ac commit 5e55155
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
- Allow `QuamBase.get_reference(attr)` to return a reference of one of its attributes

### Fixed
- Fix quam object instantiation error when a parameter type uses pipe operator
- Allow int keys to be serialised / loaded in QuAM using JSONSerialiser


## [0.3.3]
### Added
- Added the following parameters to `IQChannel`: `RF_frequency`, `LO_frequency`, `intermediate_frequency`
Expand Down
10 changes: 9 additions & 1 deletion quam/core/quam_instantiation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from __future__ import annotations
import sys
import types
import typing
from typing import TYPE_CHECKING, Dict, Any
from inspect import isclass
Expand All @@ -16,6 +18,12 @@
from quam.core import QuamBase


if sys.version_info < (3, 10):
union_types = (typing.Union,)
else:
union_types = [typing.Union, types.UnionType]


def instantiate_attrs_from_dict(
attr_dict: dict,
required_type: type,
Expand Down Expand Up @@ -224,7 +232,7 @@ def instantiate_attr(
)
if typing.get_origin(expected_type) == dict:
expected_type = dict
elif typing.get_origin(expected_type) == typing.Union:
elif typing.get_origin(expected_type) in union_types:
for union_type in typing.get_args(expected_type):
try:
instantiated_attr = instantiate_attr(
Expand Down
19 changes: 19 additions & 0 deletions tests/instantiation/test_instantiation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
from typing import List, Literal, Optional, Tuple, Union

from pytest_cov.engine import sys

from quam.core import QuamRoot, QuamComponent, quam_dataclass
from quam.core.quam_classes import QuamDict
from quam.examples.superconducting_qubits.components import Transmon
Expand Down Expand Up @@ -339,6 +341,7 @@ def test_instantiate_dict_referenced():

assert attrs == {"test_attr": "#./reference"}


@quam_dataclass
class TestQuamComponent(QuamComponent):
a: int
Expand All @@ -357,3 +360,19 @@ class TestQuamUnion(QuamComponent):

with pytest.raises(TypeError):
instantiate_quam_class(TestQuamUnion, {"union_val": {"a": "42"}})


@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
def test_instantiation_pipe_union_type():
@quam_dataclass
class TestQuamUnion(QuamComponent):
union_val: int | TestQuamComponent

obj = instantiate_quam_class(TestQuamUnion, {"union_val": 42})
assert obj.union_val == 42

obj = instantiate_quam_class(TestQuamUnion, {"union_val": {"a": 42}})
assert obj.union_val.a == 42

with pytest.raises(TypeError):
instantiate_quam_class(TestQuamUnion, {"union_val": {"a": "42"}})

0 comments on commit 5e55155

Please sign in to comment.