From 7a4113b9f1f0c3d2d11edc75618fa7c7c15785ce Mon Sep 17 00:00:00 2001 From: Tom Peham Date: Mon, 1 Jul 2024 13:42:47 +0200 Subject: [PATCH] Added tests for CSS Code class. --- src/mqt/qecc/__init__.py | 6 +- src/mqt/qecc/code.py | 118 +++++++++++++++++++++------------ test/python/test_code.py | 138 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 218 insertions(+), 44 deletions(-) create mode 100644 test/python/test_code.py diff --git a/src/mqt/qecc/__init__.py b/src/mqt/qecc/__init__.py index a236ec30..bda36cf3 100644 --- a/src/mqt/qecc/__init__.py +++ b/src/mqt/qecc/__init__.py @@ -9,8 +9,7 @@ from ._version import version as __version__ from .analog_information_decoding.simulators.analog_tannergraph_decoding import AnalogTannergraphDecoder, AtdSimulator from .analog_information_decoding.simulators.quasi_single_shot_v2 import QssSimulator -from .code import CSSCode - +from .code import CSSCode, InvalidCSSCodeError from .pyqecc import ( Code, Decoder, @@ -27,12 +26,14 @@ __all__ = [ "AnalogTannergraphDecoder", "AtdSimulator", + "CSSCode", "Code", "Decoder", "DecodingResult", "DecodingResultStatus", "DecodingRunInformation", "GrowthVariant", + "InvalidCSSCodeError", # "SoftInfoDecoder", "QssSimulator", "UFDecoder", @@ -40,5 +41,4 @@ "__version__", "apply_ecc", "sample_iid_pauli_err", - "CSSCode", ] diff --git a/src/mqt/qecc/code.py b/src/mqt/qecc/code.py index 625709d1..d56a006f 100644 --- a/src/mqt/qecc/code.py +++ b/src/mqt/qecc/code.py @@ -19,44 +19,85 @@ class CSSCode: def __init__( self, distance: int, - Hx: npt.NDArray[np.int8], - Hz: npt.NDArray[np.int8], + Hx: npt.NDArray[np.int8] | None = None, # noqa: N803 + Hz: npt.NDArray[np.int8] | None = None, # noqa: N803 x_distance: int | None = None, z_distance: int | None = None, - ) -> None: # noqa: N803 + ) -> None: """Initialize the code.""" self.distance = distance self.x_distance = x_distance if x_distance is not None else distance self.z_distance = z_distance if z_distance is not None else distance - assert self.distance <= min( - self.x_distance, self.z_distance - ), "The distance must be less than or equal to the x and z distances" - assert Hx.shape[1] == Hz.shape[1], "Hx and Hz must have the same number of columns" + if self.distance < 0: + msg = "The distance must be a non-negative integer" + raise InvalidCSSCodeError(msg) + if Hx is None and Hz is None: + msg = "At least one of the check matrices must be provided" + raise InvalidCSSCodeError(msg) + if self.x_distance < self.distance or self.z_distance < self.distance: + msg = "The x and z distances must be greater than or equal to the distance" + raise InvalidCSSCodeError(msg) + if Hx is not None and Hz is not None: + if Hx.shape[1] != Hz.shape[1]: + msg = "Check matrices must have the same number of columns" + raise InvalidCSSCodeError(msg) + if np.any(Hx @ Hz.T % 2 != 0): + msg = "The check matrices must be orthogonal" + raise InvalidCSSCodeError(msg) self.Hx = Hx self.Hz = Hz - self.n = Hx.shape[1] - self.k = self.n - Hx.shape[0] - Hz.shape[0] + self.n = Hx.shape[1] if Hx is not None else Hz.shape[1] # type: ignore[union-attr] + self.k = self.n - (Hx.shape[0] if Hx is not None else 0) - (Hz.shape[0] if Hz is not None else 0) self.Lx = CSSCode._compute_logical(self.Hz, self.Hx) self.Lz = CSSCode._compute_logical(self.Hx, self.Hz) def __hash__(self) -> int: """Compute a hash for the CSS code.""" - return hash(int.from_bytes(self.Hx.tobytes(), sys.byteorder) ^ int.from_bytes(self.Hz.tobytes(), sys.byteorder)) + x_hash = int.from_bytes(self.Hx.tobytes(), sys.byteorder) if self.Hx is not None else 0 + z_hash = int.from_bytes(self.Hz.tobytes(), sys.byteorder) if self.Hz is not None else 0 + return hash(x_hash ^ z_hash) def __eq__(self, other: object) -> bool: """Check if two CSS codes are equal.""" if not isinstance(other, CSSCode): return NotImplemented + if self.Hx is None and other.Hx is None: + assert self.Hz is not None + assert other.Hz is not None + return np.array_equal(self.Hz, other.Hz) + if self.Hz is None and other.Hz is None: + assert self.Hx is not None + assert other.Hx is not None + return np.array_equal(self.Hx, other.Hx) + if (self.Hx is None and other.Hx is not None) or (self.Hx is not None and other.Hx is None): + return False + if (self.Hz is None and other.Hz is not None) or (self.Hz is not None and other.Hz is None): + return False + assert self.Hx is not None + assert other.Hx is not None + assert self.Hz is not None + assert other.Hz is not None return bool( mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, other.Hx])) and mod2.rank(self.Hz) == mod2.rank(np.vstack([self.Hz, other.Hz])) ) @staticmethod - def _compute_logical(m1: npt.NDArray[np.int8], m2: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]: + def _compute_logical(m1: npt.NDArray[np.int8] | None, m2: npt.NDArray[np.int8] | None) -> npt.NDArray[np.int8]: """Compute the logical matrix L.""" + if m1 is None: + ker_m2 = mod2.nullspace(m2) # compute the kernel basis of m2 + pivots = mod2.row_echelon(ker_m2)[-1] + logs = np.zeros_like(ker_m2, dtype=np.int8) # type: npt.NDArray[np.int8] + for i, pivot in enumerate(pivots): + logs[i, pivot] = 1 + return logs + + if m2 is None: + return mod2.nullspace(m1).astype(np.int8) # type: ignore[no-any-return] + ker_m1 = mod2.nullspace(m1) # compute the kernel basis of m1 im_m2_transp = mod2.row_basis(m2) # compute the image basis of m2 log_stack = np.vstack([im_m2_transp, ker_m1]) @@ -66,43 +107,55 @@ def _compute_logical(m1: npt.NDArray[np.int8], m2: npt.NDArray[np.int8]) -> npt. def get_x_syndrome(self, error: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]: """Compute the x syndrome of the error.""" + if self.Hx is None: + return np.empty((0, error.shape[0]), dtype=np.int8) return self.Hx @ error % 2 def get_z_syndrome(self, error: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]: """Compute the z syndrome of the error.""" + if self.Hz is None: + return np.empty((0, error.shape[0]), dtype=np.int8) return self.Hz @ error % 2 def check_if_logical_x_error(self, residual: npt.NDArray[np.int8]) -> bool: """Check if the residual is a logical error.""" - return (self.Lz @ residual % 2).any() is True + return bool((self.Lz @ residual % 2 == 1).any()) def check_if_logical_z_error(self, residual: npt.NDArray[np.int8]) -> bool: """Check if the residual is a logical error.""" - return (self.Lx @ residual % 2).any() is True + return bool((self.Lx @ residual % 2 == 1).any()) def stabilizer_eq_x_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDArray[np.int8]) -> bool: """Check if two X errors are in the same coset.""" + if self.Hx is None: + return bool(np.array_equal(error_1, error_2)) m1 = np.vstack([self.Hx, error_1]) m2 = np.vstack([self.Hx, error_2]) m3 = np.vstack([self.Hx, error_1, error_2]) - return mod2.rank(m1) == mod2.rank(m2) == mod2.rank(m3) + return bool(mod2.rank(m1) == mod2.rank(m2) == mod2.rank(m3)) def stabilizer_eq_z_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDArray[np.int8]) -> bool: """Check if two Z errors are in the same coset.""" + if self.Hz is None: + return bool(np.array_equal(error_1, error_2)) m1 = np.vstack([self.Hz, error_1]) m2 = np.vstack([self.Hz, error_2]) m3 = np.vstack([self.Hz, error_1, error_2]) - return mod2.rank(m1) == mod2.rank(m2) == mod2.rank(m3) + return bool(mod2.rank(m1) == mod2.rank(m2) == mod2.rank(m3)) def is_self_dual(self) -> bool: """Check if the code is self-dual.""" - return self.Hx.shape[0] == self.Hz.shape[0] and mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, self.Hz])) + if self.Hx is None or self.Hz is None: + return False + return bool( + self.Hx.shape[0] == self.Hz.shape[0] and mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, self.Hz])) + ) - def stabs_as_pauli_strings(self) -> tuple[list[str], list[str]]: + def stabs_as_pauli_strings(self) -> tuple[list[str] | None, list[str] | None]: """Return the stabilizers as Pauli strings.""" - return ["".join(["I" if x == 0 else "X" for x in row]) for row in self.Hx], [ - "".join(["I" if z == 0 else "Z" for z in row]) for row in self.Hz - ] + x_str = None if self.Hx is None else ["".join(["I" if x == 0 else "X" for x in row]) for row in self.Hx] + z_str = None if self.Hz is None else ["".join(["I" if z == 0 else "Z" for z in row]) for row in self.Hz] + return x_str, z_str def z_logicals_as_pauli_string(self) -> str: """Return the logical Z operator as a Pauli string.""" @@ -170,28 +223,11 @@ def from_code_name(code_name: str, distance: int | None = None) -> CSSCode: distance = min(x_distance, z_distance) elif distance is None: msg = f"Distance is not specified for {code_name}" - raise ValueError(msg) + raise InvalidCSSCodeError(msg) return CSSCode(distance, hx, hz, x_distance=x_distance, z_distance=z_distance) msg = f"Unknown code name: {code_name}" - raise ValueError(msg) - + raise InvalidCSSCodeError(msg) -class ClassicalCode: - """A class for representing classical codes.""" - def __init__(self, distance: int, H: npt.NDArray[np.int8]) -> None: # noqa: N803 - """Initialize the code.""" - self.distance = distance - self.H = H - self.n = H.shape[1] - self.k = self.n - H.shape[0] - - -class HyperGraphProductCode(CSSCode): - """A class for representing hypergraph product codes.""" - - def __init__(self, c1: ClassicalCode, c2: ClassicalCode) -> None: - """Initialize the code.""" - Hx = np.hstack((np.kron(c1.H.T, np.eye(c2.H.shape[0])), np.kron(np.eye(c1.n), c2.H))) # noqa: N806 - Hz = np.hstack((np.kron(np.eye(c1.H.shape[0]), c2.H.T), np.kron(c1.H, np.eye(c2.n)))) # noqa: N806 - super().__init__(np.min(c1.distance, c2.distance), Hx, Hz) +class InvalidCSSCodeError(ValueError): + """Raised when the CSS code is invalid.""" diff --git a/test/python/test_code.py b/test/python/test_code.py new file mode 100644 index 00000000..3a48d473 --- /dev/null +++ b/test/python/test_code.py @@ -0,0 +1,138 @@ +"""Test the CSSCode class.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from mqt.qecc import CSSCode, InvalidCSSCodeError + +if TYPE_CHECKING: # pragma: no cover + import numpy.typing as npt + + +@pytest.fixture() +def rep_code() -> tuple[npt.NDArray[np.int8] | None, npt.NDArray[np.int8] | None]: + """Return the parity check matrices for the repetition code.""" + hx = np.array([[1, 1, 0], [0, 0, 1]]) + hz = None + return hx, hz + + +@pytest.fixture() +def steane_code() -> tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]: + """Return the check matrices for the Steane code.""" + hx = np.array([[1, 1, 1, 1, 0, 0, 0], [1, 0, 1, 0, 1, 0, 1], [0, 1, 1, 0, 1, 1, 0]]) + hz = hx + return hx, hz + + +def test_invalid_css_codes() -> None: + """Test that an invalid CSS code raises an error.""" + # Violates CSS condition + hx = np.array([[1, 1, 1]]) + hz = np.array([[1, 0, 0]]) + with pytest.raises(InvalidCSSCodeError): + CSSCode(distance=3, Hx=hx, Hz=hz) + + # Distances don't match + hz = np.array([[1, 1, 0]]) + with pytest.raises(InvalidCSSCodeError): + CSSCode(distance=3, Hx=hx, Hz=hz, x_distance=4, z_distance=1) + + # Checks not over the same number of qubits + hz = np.array([[1, 1]]) + with pytest.raises(InvalidCSSCodeError): + CSSCode(distance=3, Hx=hx, Hz=hz) + + # Invalid distance + with pytest.raises(InvalidCSSCodeError): + CSSCode(distance=-1, Hx=hx) + + # Checks not provided + with pytest.raises(InvalidCSSCodeError): + CSSCode(distance=3) + + +@pytest.mark.parametrize("checks", ["steane_code", "rep_code"]) +def test_logicals(checks: tuple[npt.NDArray[np.int8] | None, npt.NDArray[np.int8] | None], request) -> None: # type: ignore[no-untyped-def] + """Test the logical operators of the CSSCode class.""" + hx, hz = request.getfixturevalue(checks) + code = CSSCode(distance=3, Hx=hx, Hz=hz) + assert code.Lx is not None + assert code.Lz is not None + assert code.Lx.shape[1] == code.Lz.shape[1] == hx.shape[1] + assert code.Lx.shape[0] == code.Lz.shape[0] + + # assert that logicals anticommute + assert code.Lx @ code.Lz.T % 2 != 0 + + # assert that logicals commute with stabilizers + if code.Hz is not None: + assert np.all(code.Lx @ code.Hz.T % 2 == 0) + if code.Hx is not None: + assert np.all(code.Lz @ code.Hx.T % 2 == 0) + + +def test_errors(steane_code: tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]) -> None: + """Test error detection and symdromes.""" + hx, hz = steane_code + code = CSSCode(distance=3, Hx=hx, Hz=hz) + e1 = np.array([1, 0, 0, 0, 0, 0, 0]) + e2 = np.array([0, 1, 0, 0, 1, 0, 0]) + e3 = np.array([0, 0, 0, 0, 0, 1, 1]) + e4 = np.array([0, 1, 1, 1, 0, 0, 0]) + + assert np.array_equal(code.get_x_syndrome(e1), code.get_z_syndrome(e2)) + assert np.array_equal(code.get_x_syndrome(e2), code.get_z_syndrome(e2)) + + x_syndrome_1 = code.get_x_syndrome(e1) + x_syndrome_2 = code.get_x_syndrome(e2) + x_syndrome_3 = code.get_x_syndrome(e3) + x_syndrome_4 = code.get_x_syndrome(e4) + + assert np.array_equal(x_syndrome_1, x_syndrome_2) + assert not np.array_equal(x_syndrome_1, x_syndrome_3) + assert np.array_equal(x_syndrome_1, x_syndrome_4) + + # e1 and e2 have same syndrome but if we add them we get a logical error + assert code.check_if_logical_x_error((e1 + e2) % 2) + assert code.check_if_logical_z_error((e1 + e2) % 2) + assert not code.stabilizer_eq_x_error(e1, e2) + assert not code.stabilizer_eq_z_error(e1, e2) + + # e1 and e4 on the other hand do not induce a logical error because they are stabilizer equivalent + assert not code.check_if_logical_x_error((e1 + e4) % 2) + assert not code.check_if_logical_z_error((e1 + e4) % 2) + assert code.stabilizer_eq_x_error(e1, e4) + assert code.stabilizer_eq_z_error(e1, e4) + + +def test_steane(steane_code: tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]) -> None: + """Test utility functions and correctness of the Steane code.""" + hx, hz = steane_code + code = CSSCode(distance=3, Hx=hx, Hz=hz) + assert code.n == 7 + assert code.k == 1 + assert code.distance == 3 + assert code.is_self_dual() + + x_paulis, z_paulis = code.stabs_as_pauli_strings() + assert x_paulis is not None + assert z_paulis is not None + assert len(x_paulis) == len(z_paulis) == 3 + assert x_paulis == ["XXXXIII", "XIXIXIX", "IXXIXXI"] + assert z_paulis == ["ZZZZIII", "ZIZIZIZ", "IZZIZZI"] + + x_log = code.x_logicals_as_pauli_string() + z_log = code.z_logicals_as_pauli_string() + assert x_log.count("X") == 3 + assert x_log.count("I") == 4 + assert z_log.count("Z") == 3 + assert z_log.count("I") == 4 + + hx_reordered = hx[::-1, :] + code_reordered = CSSCode(distance=3, Hx=hx_reordered, Hz=hz) + assert code == code_reordered