From da81c48ce371f2a40a23a8ae9c0e867c422311f8 Mon Sep 17 00:00:00 2001 From: Serwan Asaad Date: Fri, 25 Oct 2024 09:52:45 +0200 Subject: [PATCH 1/3] Feat: Add WaveformPulse --- CHANGELOG.md | 5 +++ quam/components/pulses.py | 44 ++++++++++++++++++- .../components/pulses/test_waveform_pulse.py | 36 +++++++++++++++ 3 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 tests/components/pulses/test_waveform_pulse.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 8293a1d5..c5b392bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## [Unreleased] +### Added +- Added `WaveformPulse` to allow for pre-defined waveforms. + + ## [0.3.6] ### Changed - Modified `MWChannel` to also have `RF_frequency` and `LO_frequency` to match the signature of `IQChannel`. diff --git a/quam/components/pulses.py b/quam/components/pulses.py index e6f48985..1987bc31 100644 --- a/quam/components/pulses.py +++ b/quam/components/pulses.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod import numbers import warnings -from typing import Any, ClassVar, Dict, List, Union, Tuple +from typing import Any, ClassVar, Dict, List, Optional, Union, Tuple import numpy as np from quam.core import QuamComponent, quam_dataclass @@ -398,6 +398,48 @@ def integration_weights_function(self) -> List[Tuple[Union[complex, float], int] } +@quam_dataclass +class WaveformPulse(Pulse): + """Pulse that uses a pre-defined waveform, as opposed to a function. + + For a single channel, only `waveform_I` is required. + For an IQ channel, both `waveform_I` and `waveform_Q` are required. + + The length of the pulse is derived from the length of `waveform_I`. + + Args: + waveform_I (list[float]): The in-phase waveform. + waveform_Q (list[float], optional): The quadrature waveform. + """ + + waveform_I: List[float] # pyright: ignore + waveform_Q: Optional[List[float]] = None + # Length is derived from the waveform_I length, but still needs to be declared + # to satisfy the dataclass, but we'll override its behavior + length: Optional[int] = None # pyright: ignore + + @property + def length(self): # noqa: 811 + return len(self.waveform_I) + + @length.setter + def length(self, length: Optional[int]): + if length is not None and not isinstance(length, property): + raise AttributeError(f"length is not writable with value {length}") + + def waveform_function(self): + if self.waveform_Q is None: + return np.array(self.waveform_I) + return np.array(self.waveform_I) + 1.0j * np.array(self.waveform_Q) + + def to_dict( + self, follow_references: bool = False, include_defaults: bool = False + ) -> Dict[str, Any]: + d = super().to_dict(follow_references, include_defaults) + d.pop("length") + return d + + @quam_dataclass class DragGaussianPulse(Pulse): """Gaussian-based DRAG pulse that compensate for the leakage and AC stark shift. diff --git a/tests/components/pulses/test_waveform_pulse.py b/tests/components/pulses/test_waveform_pulse.py new file mode 100644 index 00000000..4642afff --- /dev/null +++ b/tests/components/pulses/test_waveform_pulse.py @@ -0,0 +1,36 @@ +import numpy as np +import pytest +from quam.components.pulses import WaveformPulse + + +def test_waveform_pulse_length(): + pulse = WaveformPulse(waveform_I=[1, 2, 3]) + assert pulse.length == 3 + + pulse.waveform_I = [1, 2, 3, 4] + + with pytest.raises(AttributeError): + pulse.length = 5 + + assert pulse.length == 4 + + +def test_waveform_pulse_IQ(): + pulse = WaveformPulse(waveform_I=[1, 2, 3], waveform_Q=[4, 5, 6]) + assert np.all( + pulse.waveform_function() == np.array([1, 2, 3]) + 1.0j * np.array([4, 5, 6]) + ) + + +def test_waveform_pulse_IQ_mismatch(): + pulse = WaveformPulse(waveform_I=[1, 2, 3], waveform_Q=[4, 5]) + with pytest.raises(ValueError): + pulse.waveform_function() + + +def test_waveform_pulse_to_dict(): + pulse = WaveformPulse(waveform_I=[1, 2, 3], waveform_Q=[4, 5, 6]) + assert pulse.to_dict() == { + "waveform_I": [1, 2, 3], + "waveform_Q": [4, 5, 6], + } From 7d85052b79ede020676122549e38674360906144 Mon Sep 17 00:00:00 2001 From: Serwan Asaad Date: Fri, 25 Oct 2024 14:48:24 +0200 Subject: [PATCH 2/3] Add WaveFormPulse to pulses.__all__ --- quam/components/pulses.py | 1 + 1 file changed, 1 insertion(+) diff --git a/quam/components/pulses.py b/quam/components/pulses.py index 1987bc31..d6f9609b 100644 --- a/quam/components/pulses.py +++ b/quam/components/pulses.py @@ -12,6 +12,7 @@ "Pulse", "BaseReadoutPulse", "ReadoutPulse", + "WaveformPulse", "DragGaussianPulse", "DragCosinePulse", "DragPulse", From e9412b1f4a83f89baa43442baa19fb633040cdaf Mon Sep 17 00:00:00 2001 From: Serwan Asaad Date: Mon, 28 Oct 2024 12:49:33 +0100 Subject: [PATCH 3/3] Fix length issue for python < 3.10 --- quam/components/pulses.py | 3 +++ tests/components/pulses/test_waveform_pulse.py | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/quam/components/pulses.py b/quam/components/pulses.py index d6f9609b..e3b2a005 100644 --- a/quam/components/pulses.py +++ b/quam/components/pulses.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from collections.abc import Iterable import numbers import warnings from typing import Any, ClassVar, Dict, List, Optional, Union, Tuple @@ -421,6 +422,8 @@ class WaveformPulse(Pulse): @property def length(self): # noqa: 811 + if not isinstance(self.waveform_I, Iterable): + return None return len(self.waveform_I) @length.setter diff --git a/tests/components/pulses/test_waveform_pulse.py b/tests/components/pulses/test_waveform_pulse.py index 4642afff..a60d5522 100644 --- a/tests/components/pulses/test_waveform_pulse.py +++ b/tests/components/pulses/test_waveform_pulse.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable import numpy as np import pytest from quam.components.pulses import WaveformPulse @@ -20,6 +21,7 @@ def test_waveform_pulse_IQ(): assert np.all( pulse.waveform_function() == np.array([1, 2, 3]) + 1.0j * np.array([4, 5, 6]) ) + assert pulse.length def test_waveform_pulse_IQ_mismatch(): @@ -34,3 +36,12 @@ def test_waveform_pulse_to_dict(): "waveform_I": [1, 2, 3], "waveform_Q": [4, 5, 6], } + + +def test_waveform_pulse_length_error(): + with pytest.raises(AttributeError): + pulse = WaveformPulse(waveform_I=[1, 2, 3], length=11) + + pulse = WaveformPulse(waveform_I=[1, 2, 3]) + with pytest.raises(AttributeError): + pulse.length = 11