diff --git a/CHANGELOG.md b/CHANGELOG.md index 8293a1d5..c5b392bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## [Unreleased] +### Added +- Added `WaveformPulse` to allow for pre-defined waveforms. + + ## [0.3.6] ### Changed - Modified `MWChannel` to also have `RF_frequency` and `LO_frequency` to match the signature of `IQChannel`. diff --git a/quam/components/pulses.py b/quam/components/pulses.py index e6f48985..e3b2a005 100644 --- a/quam/components/pulses.py +++ b/quam/components/pulses.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod +from collections.abc import Iterable import numbers import warnings -from typing import Any, ClassVar, Dict, List, Union, Tuple +from typing import Any, ClassVar, Dict, List, Optional, Union, Tuple import numpy as np from quam.core import QuamComponent, quam_dataclass @@ -12,6 +13,7 @@ "Pulse", "BaseReadoutPulse", "ReadoutPulse", + "WaveformPulse", "DragGaussianPulse", "DragCosinePulse", "DragPulse", @@ -398,6 +400,50 @@ def integration_weights_function(self) -> List[Tuple[Union[complex, float], int] } +@quam_dataclass +class WaveformPulse(Pulse): + """Pulse that uses a pre-defined waveform, as opposed to a function. + + For a single channel, only `waveform_I` is required. + For an IQ channel, both `waveform_I` and `waveform_Q` are required. + + The length of the pulse is derived from the length of `waveform_I`. + + Args: + waveform_I (list[float]): The in-phase waveform. + waveform_Q (list[float], optional): The quadrature waveform. + """ + + waveform_I: List[float] # pyright: ignore + waveform_Q: Optional[List[float]] = None + # Length is derived from the waveform_I length, but still needs to be declared + # to satisfy the dataclass, but we'll override its behavior + length: Optional[int] = None # pyright: ignore + + @property + def length(self): # noqa: 811 + if not isinstance(self.waveform_I, Iterable): + return None + return len(self.waveform_I) + + @length.setter + def length(self, length: Optional[int]): + if length is not None and not isinstance(length, property): + raise AttributeError(f"length is not writable with value {length}") + + def waveform_function(self): + if self.waveform_Q is None: + return np.array(self.waveform_I) + return np.array(self.waveform_I) + 1.0j * np.array(self.waveform_Q) + + def to_dict( + self, follow_references: bool = False, include_defaults: bool = False + ) -> Dict[str, Any]: + d = super().to_dict(follow_references, include_defaults) + d.pop("length") + return d + + @quam_dataclass class DragGaussianPulse(Pulse): """Gaussian-based DRAG pulse that compensate for the leakage and AC stark shift. diff --git a/tests/components/pulses/test_waveform_pulse.py b/tests/components/pulses/test_waveform_pulse.py new file mode 100644 index 00000000..a60d5522 --- /dev/null +++ b/tests/components/pulses/test_waveform_pulse.py @@ -0,0 +1,47 @@ +from collections.abc import Iterable +import numpy as np +import pytest +from quam.components.pulses import WaveformPulse + + +def test_waveform_pulse_length(): + pulse = WaveformPulse(waveform_I=[1, 2, 3]) + assert pulse.length == 3 + + pulse.waveform_I = [1, 2, 3, 4] + + with pytest.raises(AttributeError): + pulse.length = 5 + + assert pulse.length == 4 + + +def test_waveform_pulse_IQ(): + pulse = WaveformPulse(waveform_I=[1, 2, 3], waveform_Q=[4, 5, 6]) + assert np.all( + pulse.waveform_function() == np.array([1, 2, 3]) + 1.0j * np.array([4, 5, 6]) + ) + assert pulse.length + + +def test_waveform_pulse_IQ_mismatch(): + pulse = WaveformPulse(waveform_I=[1, 2, 3], waveform_Q=[4, 5]) + with pytest.raises(ValueError): + pulse.waveform_function() + + +def test_waveform_pulse_to_dict(): + pulse = WaveformPulse(waveform_I=[1, 2, 3], waveform_Q=[4, 5, 6]) + assert pulse.to_dict() == { + "waveform_I": [1, 2, 3], + "waveform_Q": [4, 5, 6], + } + + +def test_waveform_pulse_length_error(): + with pytest.raises(AttributeError): + pulse = WaveformPulse(waveform_I=[1, 2, 3], length=11) + + pulse = WaveformPulse(waveform_I=[1, 2, 3]) + with pytest.raises(AttributeError): + pulse.length = 11