Skip to content

Commit

Permalink
Added tests for CSS Code class.
Browse files Browse the repository at this point in the history
  • Loading branch information
pehamTom committed Jul 1, 2024
1 parent 3a8e8e8 commit 7a4113b
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 44 deletions.
6 changes: 3 additions & 3 deletions src/mqt/qecc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from ._version import version as __version__
from .analog_information_decoding.simulators.analog_tannergraph_decoding import AnalogTannergraphDecoder, AtdSimulator
from .analog_information_decoding.simulators.quasi_single_shot_v2 import QssSimulator
from .code import CSSCode

from .code import CSSCode, InvalidCSSCodeError
from .pyqecc import (
Code,
Decoder,
Expand All @@ -27,18 +26,19 @@
__all__ = [
"AnalogTannergraphDecoder",
"AtdSimulator",
"CSSCode",
"Code",
"Decoder",
"DecodingResult",
"DecodingResultStatus",
"DecodingRunInformation",
"GrowthVariant",
"InvalidCSSCodeError",
# "SoftInfoDecoder",
"QssSimulator",
"UFDecoder",
"UFHeuristic",
"__version__",
"apply_ecc",
"sample_iid_pauli_err",
"CSSCode",
]
118 changes: 77 additions & 41 deletions src/mqt/qecc/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,44 +19,85 @@ class CSSCode:
def __init__(
self,
distance: int,
Hx: npt.NDArray[np.int8],
Hz: npt.NDArray[np.int8],
Hx: npt.NDArray[np.int8] | None = None, # noqa: N803
Hz: npt.NDArray[np.int8] | None = None, # noqa: N803
x_distance: int | None = None,
z_distance: int | None = None,
) -> None: # noqa: N803
) -> None:
"""Initialize the code."""
self.distance = distance
self.x_distance = x_distance if x_distance is not None else distance
self.z_distance = z_distance if z_distance is not None else distance

assert self.distance <= min(
self.x_distance, self.z_distance
), "The distance must be less than or equal to the x and z distances"
assert Hx.shape[1] == Hz.shape[1], "Hx and Hz must have the same number of columns"
if self.distance < 0:
msg = "The distance must be a non-negative integer"
raise InvalidCSSCodeError(msg)
if Hx is None and Hz is None:
msg = "At least one of the check matrices must be provided"
raise InvalidCSSCodeError(msg)
if self.x_distance < self.distance or self.z_distance < self.distance:
msg = "The x and z distances must be greater than or equal to the distance"
raise InvalidCSSCodeError(msg)
if Hx is not None and Hz is not None:
if Hx.shape[1] != Hz.shape[1]:
msg = "Check matrices must have the same number of columns"
raise InvalidCSSCodeError(msg)
if np.any(Hx @ Hz.T % 2 != 0):
msg = "The check matrices must be orthogonal"
raise InvalidCSSCodeError(msg)

self.Hx = Hx
self.Hz = Hz
self.n = Hx.shape[1]
self.k = self.n - Hx.shape[0] - Hz.shape[0]
self.n = Hx.shape[1] if Hx is not None else Hz.shape[1] # type: ignore[union-attr]
self.k = self.n - (Hx.shape[0] if Hx is not None else 0) - (Hz.shape[0] if Hz is not None else 0)
self.Lx = CSSCode._compute_logical(self.Hz, self.Hx)
self.Lz = CSSCode._compute_logical(self.Hx, self.Hz)

def __hash__(self) -> int:
"""Compute a hash for the CSS code."""
return hash(int.from_bytes(self.Hx.tobytes(), sys.byteorder) ^ int.from_bytes(self.Hz.tobytes(), sys.byteorder))
x_hash = int.from_bytes(self.Hx.tobytes(), sys.byteorder) if self.Hx is not None else 0
z_hash = int.from_bytes(self.Hz.tobytes(), sys.byteorder) if self.Hz is not None else 0
return hash(x_hash ^ z_hash)

def __eq__(self, other: object) -> bool:
"""Check if two CSS codes are equal."""
if not isinstance(other, CSSCode):
return NotImplemented
if self.Hx is None and other.Hx is None:
assert self.Hz is not None
assert other.Hz is not None
return np.array_equal(self.Hz, other.Hz)
if self.Hz is None and other.Hz is None:
assert self.Hx is not None
assert other.Hx is not None
return np.array_equal(self.Hx, other.Hx)
if (self.Hx is None and other.Hx is not None) or (self.Hx is not None and other.Hx is None):
return False
if (self.Hz is None and other.Hz is not None) or (self.Hz is not None and other.Hz is None):
return False
assert self.Hx is not None
assert other.Hx is not None
assert self.Hz is not None
assert other.Hz is not None
return bool(
mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, other.Hx]))
and mod2.rank(self.Hz) == mod2.rank(np.vstack([self.Hz, other.Hz]))
)

@staticmethod
def _compute_logical(m1: npt.NDArray[np.int8], m2: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
def _compute_logical(m1: npt.NDArray[np.int8] | None, m2: npt.NDArray[np.int8] | None) -> npt.NDArray[np.int8]:
"""Compute the logical matrix L."""
if m1 is None:
ker_m2 = mod2.nullspace(m2) # compute the kernel basis of m2
pivots = mod2.row_echelon(ker_m2)[-1]
logs = np.zeros_like(ker_m2, dtype=np.int8) # type: npt.NDArray[np.int8]
for i, pivot in enumerate(pivots):
logs[i, pivot] = 1
return logs

if m2 is None:
return mod2.nullspace(m1).astype(np.int8) # type: ignore[no-any-return]

ker_m1 = mod2.nullspace(m1) # compute the kernel basis of m1
im_m2_transp = mod2.row_basis(m2) # compute the image basis of m2
log_stack = np.vstack([im_m2_transp, ker_m1])
Expand All @@ -66,43 +107,55 @@ def _compute_logical(m1: npt.NDArray[np.int8], m2: npt.NDArray[np.int8]) -> npt.

def get_x_syndrome(self, error: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
"""Compute the x syndrome of the error."""
if self.Hx is None:
return np.empty((0, error.shape[0]), dtype=np.int8)
return self.Hx @ error % 2

def get_z_syndrome(self, error: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
"""Compute the z syndrome of the error."""
if self.Hz is None:
return np.empty((0, error.shape[0]), dtype=np.int8)
return self.Hz @ error % 2

def check_if_logical_x_error(self, residual: npt.NDArray[np.int8]) -> bool:
"""Check if the residual is a logical error."""
return (self.Lz @ residual % 2).any() is True
return bool((self.Lz @ residual % 2 == 1).any())

def check_if_logical_z_error(self, residual: npt.NDArray[np.int8]) -> bool:
"""Check if the residual is a logical error."""
return (self.Lx @ residual % 2).any() is True
return bool((self.Lx @ residual % 2 == 1).any())

def stabilizer_eq_x_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDArray[np.int8]) -> bool:
"""Check if two X errors are in the same coset."""
if self.Hx is None:
return bool(np.array_equal(error_1, error_2))
m1 = np.vstack([self.Hx, error_1])
m2 = np.vstack([self.Hx, error_2])
m3 = np.vstack([self.Hx, error_1, error_2])
return mod2.rank(m1) == mod2.rank(m2) == mod2.rank(m3)
return bool(mod2.rank(m1) == mod2.rank(m2) == mod2.rank(m3))

def stabilizer_eq_z_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDArray[np.int8]) -> bool:
"""Check if two Z errors are in the same coset."""
if self.Hz is None:
return bool(np.array_equal(error_1, error_2))
m1 = np.vstack([self.Hz, error_1])
m2 = np.vstack([self.Hz, error_2])
m3 = np.vstack([self.Hz, error_1, error_2])
return mod2.rank(m1) == mod2.rank(m2) == mod2.rank(m3)
return bool(mod2.rank(m1) == mod2.rank(m2) == mod2.rank(m3))

def is_self_dual(self) -> bool:
"""Check if the code is self-dual."""
return self.Hx.shape[0] == self.Hz.shape[0] and mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, self.Hz]))
if self.Hx is None or self.Hz is None:
return False
return bool(
self.Hx.shape[0] == self.Hz.shape[0] and mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, self.Hz]))
)

def stabs_as_pauli_strings(self) -> tuple[list[str], list[str]]:
def stabs_as_pauli_strings(self) -> tuple[list[str] | None, list[str] | None]:
"""Return the stabilizers as Pauli strings."""
return ["".join(["I" if x == 0 else "X" for x in row]) for row in self.Hx], [
"".join(["I" if z == 0 else "Z" for z in row]) for row in self.Hz
]
x_str = None if self.Hx is None else ["".join(["I" if x == 0 else "X" for x in row]) for row in self.Hx]
z_str = None if self.Hz is None else ["".join(["I" if z == 0 else "Z" for z in row]) for row in self.Hz]
return x_str, z_str

def z_logicals_as_pauli_string(self) -> str:
"""Return the logical Z operator as a Pauli string."""
Expand Down Expand Up @@ -170,28 +223,11 @@ def from_code_name(code_name: str, distance: int | None = None) -> CSSCode:
distance = min(x_distance, z_distance)
elif distance is None:
msg = f"Distance is not specified for {code_name}"
raise ValueError(msg)
raise InvalidCSSCodeError(msg)
return CSSCode(distance, hx, hz, x_distance=x_distance, z_distance=z_distance)
msg = f"Unknown code name: {code_name}"
raise ValueError(msg)

raise InvalidCSSCodeError(msg)

class ClassicalCode:
"""A class for representing classical codes."""

def __init__(self, distance: int, H: npt.NDArray[np.int8]) -> None: # noqa: N803
"""Initialize the code."""
self.distance = distance
self.H = H
self.n = H.shape[1]
self.k = self.n - H.shape[0]


class HyperGraphProductCode(CSSCode):
"""A class for representing hypergraph product codes."""

def __init__(self, c1: ClassicalCode, c2: ClassicalCode) -> None:
"""Initialize the code."""
Hx = np.hstack((np.kron(c1.H.T, np.eye(c2.H.shape[0])), np.kron(np.eye(c1.n), c2.H))) # noqa: N806
Hz = np.hstack((np.kron(np.eye(c1.H.shape[0]), c2.H.T), np.kron(c1.H, np.eye(c2.n)))) # noqa: N806
super().__init__(np.min(c1.distance, c2.distance), Hx, Hz)
class InvalidCSSCodeError(ValueError):
"""Raised when the CSS code is invalid."""
138 changes: 138 additions & 0 deletions test/python/test_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Test the CSSCode class."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pytest

from mqt.qecc import CSSCode, InvalidCSSCodeError

if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt


@pytest.fixture()
def rep_code() -> tuple[npt.NDArray[np.int8] | None, npt.NDArray[np.int8] | None]:
"""Return the parity check matrices for the repetition code."""
hx = np.array([[1, 1, 0], [0, 0, 1]])
hz = None
return hx, hz


@pytest.fixture()
def steane_code() -> tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]:
"""Return the check matrices for the Steane code."""
hx = np.array([[1, 1, 1, 1, 0, 0, 0], [1, 0, 1, 0, 1, 0, 1], [0, 1, 1, 0, 1, 1, 0]])
hz = hx
return hx, hz


def test_invalid_css_codes() -> None:
"""Test that an invalid CSS code raises an error."""
# Violates CSS condition
hx = np.array([[1, 1, 1]])
hz = np.array([[1, 0, 0]])
with pytest.raises(InvalidCSSCodeError):
CSSCode(distance=3, Hx=hx, Hz=hz)

# Distances don't match
hz = np.array([[1, 1, 0]])
with pytest.raises(InvalidCSSCodeError):
CSSCode(distance=3, Hx=hx, Hz=hz, x_distance=4, z_distance=1)

# Checks not over the same number of qubits
hz = np.array([[1, 1]])
with pytest.raises(InvalidCSSCodeError):
CSSCode(distance=3, Hx=hx, Hz=hz)

# Invalid distance
with pytest.raises(InvalidCSSCodeError):
CSSCode(distance=-1, Hx=hx)

# Checks not provided
with pytest.raises(InvalidCSSCodeError):
CSSCode(distance=3)


@pytest.mark.parametrize("checks", ["steane_code", "rep_code"])
def test_logicals(checks: tuple[npt.NDArray[np.int8] | None, npt.NDArray[np.int8] | None], request) -> None: # type: ignore[no-untyped-def]
"""Test the logical operators of the CSSCode class."""
hx, hz = request.getfixturevalue(checks)
code = CSSCode(distance=3, Hx=hx, Hz=hz)
assert code.Lx is not None
assert code.Lz is not None
assert code.Lx.shape[1] == code.Lz.shape[1] == hx.shape[1]
assert code.Lx.shape[0] == code.Lz.shape[0]

# assert that logicals anticommute
assert code.Lx @ code.Lz.T % 2 != 0

# assert that logicals commute with stabilizers
if code.Hz is not None:
assert np.all(code.Lx @ code.Hz.T % 2 == 0)
if code.Hx is not None:
assert np.all(code.Lz @ code.Hx.T % 2 == 0)


def test_errors(steane_code: tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]) -> None:
"""Test error detection and symdromes."""
hx, hz = steane_code
code = CSSCode(distance=3, Hx=hx, Hz=hz)
e1 = np.array([1, 0, 0, 0, 0, 0, 0])
e2 = np.array([0, 1, 0, 0, 1, 0, 0])
e3 = np.array([0, 0, 0, 0, 0, 1, 1])
e4 = np.array([0, 1, 1, 1, 0, 0, 0])

assert np.array_equal(code.get_x_syndrome(e1), code.get_z_syndrome(e2))
assert np.array_equal(code.get_x_syndrome(e2), code.get_z_syndrome(e2))

x_syndrome_1 = code.get_x_syndrome(e1)
x_syndrome_2 = code.get_x_syndrome(e2)
x_syndrome_3 = code.get_x_syndrome(e3)
x_syndrome_4 = code.get_x_syndrome(e4)

assert np.array_equal(x_syndrome_1, x_syndrome_2)
assert not np.array_equal(x_syndrome_1, x_syndrome_3)
assert np.array_equal(x_syndrome_1, x_syndrome_4)

# e1 and e2 have same syndrome but if we add them we get a logical error
assert code.check_if_logical_x_error((e1 + e2) % 2)
assert code.check_if_logical_z_error((e1 + e2) % 2)
assert not code.stabilizer_eq_x_error(e1, e2)
assert not code.stabilizer_eq_z_error(e1, e2)

# e1 and e4 on the other hand do not induce a logical error because they are stabilizer equivalent
assert not code.check_if_logical_x_error((e1 + e4) % 2)
assert not code.check_if_logical_z_error((e1 + e4) % 2)
assert code.stabilizer_eq_x_error(e1, e4)
assert code.stabilizer_eq_z_error(e1, e4)


def test_steane(steane_code: tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]) -> None:
"""Test utility functions and correctness of the Steane code."""
hx, hz = steane_code
code = CSSCode(distance=3, Hx=hx, Hz=hz)
assert code.n == 7
assert code.k == 1
assert code.distance == 3
assert code.is_self_dual()

x_paulis, z_paulis = code.stabs_as_pauli_strings()
assert x_paulis is not None
assert z_paulis is not None
assert len(x_paulis) == len(z_paulis) == 3
assert x_paulis == ["XXXXIII", "XIXIXIX", "IXXIXXI"]
assert z_paulis == ["ZZZZIII", "ZIZIZIZ", "IZZIZZI"]

x_log = code.x_logicals_as_pauli_string()
z_log = code.z_logicals_as_pauli_string()
assert x_log.count("X") == 3
assert x_log.count("I") == 4
assert z_log.count("Z") == 3
assert z_log.count("I") == 4

hx_reordered = hx[::-1, :]
code_reordered = CSSCode(distance=3, Hx=hx_reordered, Hz=hz)
assert code == code_reordered

0 comments on commit 7a4113b

Please sign in to comment.