diff --git a/quam/components/quantum_components/qubit.py b/quam/components/quantum_components/qubit.py index 8fca4e3..424d228 100644 --- a/quam/components/quantum_components/qubit.py +++ b/quam/components/quantum_components/qubit.py @@ -1,6 +1,6 @@ from collections import UserDict from collections.abc import Iterable -from typing import Dict, Union, TYPE_CHECKING, Any +from typing import Dict, List, Optional, Union, TYPE_CHECKING, Any from dataclasses import field from qm import qua @@ -77,11 +77,29 @@ def get_pulse(self, pulse_name: str) -> Pulse: else: return pulses[0] - def align(self, *other_qubits: "Qubit"): + @QuantumComponent.register_macro + def align( + self, + other_qubits: Optional[Union["Qubit", Iterable["Qubit"]]] = None, + *args: "Qubit", + ): """Aligns the execution of all channels of this qubit and all other qubits""" - channel_names = [channel.name for channel in self.channels.values()] - for qubit in other_qubits: - channel_names.extend([channel.name for channel in qubit.channels.values()]) + quantum_components = [self] + + if isinstance(other_qubits, Qubit): + quantum_components.append(other_qubits) + elif isinstance(other_qubits, Iterable): + quantum_components.extend(other_qubits) + elif other_qubits is not None: + raise ValueError(f"Invalid type for other_qubits: {type(other_qubits)}") + + if args: + assert all(isinstance(arg, Qubit) for arg in args) + quantum_components.extend(args) + + channel_names = { + ch.name for qubit in quantum_components for ch in qubit.channels.values() + } align(*channel_names) diff --git a/tests/components/quantum_components/test_qubit.py b/tests/components/quantum_components/test_qubit.py index 270fd05..e6704ef 100644 --- a/tests/components/quantum_components/test_qubit.py +++ b/tests/components/quantum_components/test_qubit.py @@ -63,7 +63,7 @@ def test_qubit_align(mock_qubit_with_resonator, mock_qubit, mocker): from quam.components.quantum_components.qubit import align - align.assert_called_once_with("q1.xy", "q1.resonator", "q0.xy") + align.assert_called_once_with(*{"q1.xy", "q1.resonator", "q0.xy"}) def test_qubit_get_macros(mock_qubit): @@ -77,7 +77,7 @@ def test_qubit_apply_align(mock_qubit_with_resonator, mocker): from quam.components.quantum_components.qubit import align - align.assert_called_once_with("q1.xy", "q1.resonator") + align.assert_called_once_with(*{"q1.xy", "q1.resonator"}) def test_qubit_inferred_id_direct():