Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend functionality for working with codes. #288

Merged
merged 20 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/mqt/qecc/codes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from .bb_codes import construct_bb_code
from .color_code import ColorCode, LatticeType
from .concatenation import ConcatenatedCode, ConcatenatedCSSCode
from .constructions import construct_iceberg_code, construct_many_hypercube_code, construct_quantum_hamming_code
from .css_code import CSSCode, InvalidCSSCodeError
from .hexagonal_color_code import HexagonalColorCode
from .square_octagon_color_code import SquareOctagonColorCode
Expand All @@ -12,11 +14,16 @@
__all__ = [
"CSSCode",
"ColorCode",
"ConcatenatedCSSCode",
"ConcatenatedCode",
"HexagonalColorCode",
"InvalidCSSCodeError",
"InvalidStabilizerCodeError",
"LatticeType",
"SquareOctagonColorCode",
"StabilizerCode",
"construct_bb_code",
"construct_iceberg_code",
"construct_many_hypercube_code",
"construct_quantum_hamming_code",
]
152 changes: 152 additions & 0 deletions src/mqt/qecc/codes/concatenation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Concatenated quantum codes."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from .css_code import CSSCode
from .pauli import Pauli
from .stabilizer_code import InvalidStabilizerCodeError, StabilizerCode
from .symplectic import SymplecticVector

if TYPE_CHECKING:
from collections.abc import Sequence

import numpy.typing as npt


class ConcatenatedCode(StabilizerCode):
Fixed Show fixed Hide fixed
"""A concatenated quantum code."""

def __init__(self, outer_code: StabilizerCode, inner_code: StabilizerCode | Sequence[StabilizerCode]) -> None:
"""Initialize a concatenated quantum code.

Args:
outer_code: The outer code.
inner_code: The inner code. If a list of codes is provided, the qubits of the outer code are encoded by the different inner codes in the list.
"""
self.outer_code = outer_code
Fixed Show fixed Hide fixed
if isinstance(inner_code, list):
self.inner_codes = inner_code
Fixed Show fixed Hide fixed
else:
self.inner_codes = [inner_code] * outer_code.n
Fixed Show fixed Hide fixed
if not all(code.k == 1 for code in self.inner_codes):
msg = "The inner codes must be stabilizer codes with a single logical qubit."
raise InvalidStabilizerCodeError(msg)

self.n = sum(code.n for code in self.inner_codes)
Fixed Show fixed Hide fixed
generators = [self._outer_pauli_to_physical(p) for p in outer_code.generators]

x_logicals = None
z_logicals = None
if outer_code.x_logicals is not None:
x_logicals = [self._outer_pauli_to_physical(p) for p in outer_code.x_logicals]
if outer_code.z_logicals is not None:
z_logicals = [self._outer_pauli_to_physical(p) for p in outer_code.z_logicals]

d = min(code.distance * outer_code.distance for code in self.inner_codes)
StabilizerCode.__init__(self, generators, d, x_logicals, z_logicals)

def __eq__(self, other: object) -> bool:
"""Check if two concatenated codes are equal."""
if not isinstance(other, ConcatenatedCode):
return NotImplemented
return self.outer_code == other.outer_code and all(
c1 == c2 for c1, c2 in zip(self.inner_codes, other.inner_codes)
)

def __hash__(self) -> int:
"""Compute the hash of the concatenated code."""
return hash((self.outer_code, tuple(self.inner_codes)))

def _outer_pauli_to_physical(self, p: Pauli) -> Pauli:
"""Convert a Pauli operator on the outer code to the operator on the concatenated code.

Args:
p: The Pauli operator.

Returns:
The Pauli operator on the physical qubits.
"""
if len(p) != self.outer_code.n:
msg = "The Pauli operator must have the same number of qubits as the outer code."
raise InvalidStabilizerCodeError(msg)
concatenated = SymplecticVector.zeros(self.n)
phase = 0
offset = 0
for i in range(self.outer_code.n):
c = self.inner_codes[i]
new_offset = offset + c.n
assert c.x_logicals is not None
assert c.z_logicals is not None
if p[i] == "X":
concatenated[offset:new_offset] = c.x_logicals[0].x_part()
concatenated[offset + self.n : new_offset + self.n] = c.x_logicals[0].z_part()
phase += c.x_logicals[0].phase
elif p[i] == "Z":
concatenated[offset:new_offset] = c.z_logicals[0].x_part()
concatenated[offset + self.n : new_offset + self.n] = c.z_logicals[0].z_part()
phase += c.z_logicals[0].phase

elif p[i] == "Y":
concatenated[offset:new_offset] = c.x_logicals[0].x_part ^ c.z_logicals[0].x_part()
concatenated[offset + self.n : new_offset + self.n] = c.x_logicals[0].z_part ^ c.z_logicals[0].z_part()
phase += c.x_logicals[0].phase + c.z_logicals[0].phase

offset = new_offset
return Pauli(concatenated, phase)


# def _valid_logicals(lst: list[StabilizerTableau | None]) -> TypeGuard[list[StabilizerTableau]]:
# return None not in lst
Fixed Show fixed Hide fixed


class ConcatenatedCSSCode(ConcatenatedCode, CSSCode):
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
"""A concatenated CSS code."""

def __init__(self, outer_code: CSSCode, inner_codes: CSSCode | Sequence[CSSCode]) -> None:
"""Initialize a concatenated CSS code.

Args:
outer_code: The outer code.
inner_codes: The inner code. If a list of codes is provided, the qubits of the outer code are encoded by the different inner codes in the list.
"""
# self.outer_code = outer_code
if isinstance(inner_codes, CSSCode):
inner_codes = [inner_codes] * outer_code.n

if not all(code.k == 1 for code in inner_codes):
msg = "The inner codes must be CSS codes with a single logical qubit."
raise InvalidStabilizerCodeError(msg)

ConcatenatedCode.__init__(self, outer_code, inner_codes)
hx = np.array([self._outer_checks_to_physical(check, "X") for check in outer_code.Hx], dtype=np.int8)
hz = np.array([self._outer_checks_to_physical(check, "Z") for check in outer_code.Hz], dtype=np.int8)
d = min(code.distance * outer_code.distance for code in inner_codes)
CSSCode.__init__(self, d, hx, hz)

def _outer_checks_to_physical(self, check: npt.NDArray[np.int8], operator: str) -> npt.NDArray[np.int8]:
"""Convert a check operator on the outer code to the operator on the concatenated code.

Args:
check: The check operator.
operator: The type of operator to be converted. Either 'X' or 'Z'.

Returns:
The check operator on the physical qubits.
"""
if check.shape[0] != self.outer_code.n:
msg = "The check operator must have the same number of qubits as the outer code."
raise InvalidStabilizerCodeError(msg)
concatenated = np.zeros((self.n), dtype=np.int8)
offset = 0
for i in range(self.outer_code.n):
c = self.inner_codes[i]
new_offset = offset + c.n
if check[i] == 1:
logical = c.Lx if operator == "X" else c.Lz
concatenated[offset:new_offset] = logical
offset = new_offset
return concatenated
56 changes: 56 additions & 0 deletions src/mqt/qecc/codes/constructions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Constructions of various known stabilizer codes."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from .css_code import CSSCode

if TYPE_CHECKING:
import numpy.typing as npt


def construct_quantum_hamming_code(r: int) -> CSSCode:
"""Return the [[2^r, 2^r-r-1, 3]] quantum Hamming code."""
h = _hamming_code_checks(r)
return CSSCode(3, h, h)


def construct_iceberg_code(m: int) -> CSSCode:
"""Return the [[2m, 2m-2, 2]] Iceberg code.

The Iceberg code is a CSS code with stabilizer generators X^2m and Z^2m.
https://errorcorrectionzoo.org/c/iceberg
"""
n = 2 * m
h = np.array([[1] * n], dtype=np.int8)
return CSSCode(2, h, h)


def construct_many_hypercube_code(level: int) -> CSSCode:
"""Return the [[6^l, 4^l, 2^l]] level l many-hypercube code (https://arxiv.org/abs/2403.16054).

This code is obtained by (l-1)-fold concatenation of the [[6,4,2]] iceberg code with itself.
"""
code = construct_iceberg_code(3)

for _ in range(1, level):
sx = np.hstack([code.Lx] * 6, dtype=np.int8)
sx_rem = np.kron(np.eye(6, dtype=np.int8), code.Hx)
sx = np.vstack((sx, sx_rem), dtype=np.int8)
sz = sx
code = CSSCode(code.distance * 2, sx, sz)
return code


def _hamming_code_checks(r: int) -> npt.NDArray[np.int8]:
"""Return the check matrix for the [2^r-1, 2^r-r-1, 3] Hamming code."""
n = 2**r - 1
h = np.zeros((r, n), dtype=int)
# columns are all binary strings up to 2^r
for i in range(1, n + 1):
h[:, i - 1] = np.array([int(x) for x in f"{i:b}".zfill(r)])

return h
49 changes: 28 additions & 21 deletions src/mqt/qecc/codes/css_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import numpy as np
from ldpc import mod2

from .pauli import StabilizerTableau
from .stabilizer_code import StabilizerCode

if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt


Check warning

Code scanning / CodeQL

`__eq__` not overridden when adding attributes Warning

The class 'CSSCode' does not override '__eq__', but adds the new attribute Hx.
The class 'CSSCode' does not override '__eq__', but adds the new attribute Hz.
The class 'CSSCode' does not override '__eq__', but adds the new attribute Lx.
The class 'CSSCode' does not override '__eq__', but adds the new attribute Lz.
The class 'CSSCode' does not override '__eq__', but adds the new attribute n.
The class 'CSSCode' does not override '__eq__', but adds the new attribute Hx.
The class 'CSSCode' does not override '__eq__', but adds the new attribute Hx.
The class 'CSSCode' does not override '__eq__', but adds the new attribute n.
The class 'CSSCode' does not override '__eq__', but adds the new attribute Hz.
The class 'CSSCode' does not override '__eq__', but adds the new attribute Hz.
The class 'CSSCode' does not override '__eq__', but adds the new attribute distance.
The class 'CSSCode' does not override '__eq__', but adds the new attribute x_distance.
The class 'CSSCode' does not override '__eq__', but adds the new attribute z_distance.
The class 'CSSCode' does not override '__eq__', but adds the new attribute Lx.
The class 'CSSCode' does not override '__eq__', but adds the new attribute Lz.
class CSSCode(StabilizerCode):
"""A class for representing CSS codes."""

Expand All @@ -24,8 +25,21 @@
Hz: npt.NDArray[np.int8] | None = None, # noqa: N803
x_distance: int | None = None,
z_distance: int | None = None,
n: int | None = None,
) -> None:
"""Initialize the code."""
if Hx is None and Hz is None:
if n is None:
msg = "If no check matrices are provided, the code size must be specified."
raise InvalidCSSCodeError(msg)
self.Hx = np.zeros((0, n), dtype=np.int8)
self.Hz = np.zeros((0, n), dtype=np.int8)
self.Lx = np.eye(n, dtype=np.int8)
self.Lz = np.eye(n, dtype=np.int8)
triv = StabilizerCode.get_trivial_code(n)
super().__init__(triv.generators, triv.distance, triv.x_logicals, triv.z_logicals)
return

self._check_valid_check_matrices(Hx, Hz)

if Hx is None:
Expand All @@ -46,8 +60,8 @@

x_padded = np.hstack([self.Hx, z_padding])
z_padded = np.hstack([x_padding, self.Hz])
phases = np.zeros((x_padded.shape[0] + z_padded.shape[0], 1), dtype=np.int8)
super().__init__(np.hstack((np.vstack((x_padded, z_padded)), phases)), distance)
phases = np.zeros((x_padded.shape[0] + z_padded.shape[0]), dtype=np.int8)
super().__init__(StabilizerTableau(np.vstack((x_padded, z_padded)), phases), distance)

self.distance = distance
self.x_distance = x_distance if x_distance is not None else distance
Expand Down Expand Up @@ -88,14 +102,10 @@

def get_x_syndrome(self, error: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
"""Compute the x syndrome of the error."""
if self.Hx is None:
return np.empty((0, error.shape[0]), dtype=np.int8)
return self.Hx @ error % 2

def get_z_syndrome(self, error: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
"""Compute the z syndrome of the error."""
if self.Hz is None:
return np.empty((0, error.shape[0]), dtype=np.int8)
return self.Hz @ error % 2

def check_if_logical_x_error(self, residual: npt.NDArray[np.int8]) -> bool:
Expand All @@ -104,21 +114,19 @@

def check_if_x_stabilizer(self, pauli: npt.NDArray[np.int8]) -> bool:
"""Check if the Pauli is a stabilizer."""
assert self.Hx is not None
return bool(mod2.rank(np.vstack((self.Hx, pauli))) == mod2.rank(self.Hx))

def check_if_logical_z_error(self, residual: npt.NDArray[np.int8]) -> bool:
"""Check if the residual is a logical error."""
return bool((self.Lx @ residual % 2 == 1).any())
return (self.Hx.shape[0] != 0) and bool((self.Lx @ residual % 2 == 1).any())

def check_if_z_stabilizer(self, pauli: npt.NDArray[np.int8]) -> bool:
"""Check if the Pauli is a stabilizer."""
assert self.Hz is not None
return bool(mod2.rank(np.vstack((self.Hz, pauli))) == mod2.rank(self.Hz))
return (self.Hz.shape[0] != 0) and bool(mod2.rank(np.vstack((self.Hz, pauli))) == mod2.rank(self.Hz))

def stabilizer_eq_x_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDArray[np.int8]) -> bool:
"""Check if two X errors are in the same coset."""
if self.Hx is None:
if self.Hx.shape[0] == 0:
return bool(np.array_equal(error_1, error_2))
m1 = np.vstack([self.Hx, error_1])
m2 = np.vstack([self.Hx, error_2])
Expand All @@ -127,7 +135,7 @@

def stabilizer_eq_z_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDArray[np.int8]) -> bool:
"""Check if two Z errors are in the same coset."""
if self.Hz is None:
if self.Hz.shape[0] == 0:
return bool(np.array_equal(error_1, error_2))
m1 = np.vstack([self.Hz, error_1])
m2 = np.vstack([self.Hz, error_2])
Expand All @@ -136,19 +144,13 @@

def is_self_dual(self) -> bool:
"""Check if the code is self-dual."""
if self.Hx is None or self.Hz is None:
return False
return bool(
self.Hx.shape[0] == self.Hz.shape[0] and mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, self.Hz]))
)

@staticmethod
def _check_valid_check_matrices(Hx: npt.NDArray[np.int8] | None, Hz: npt.NDArray[np.int8] | None) -> None: # noqa: N803
"""Check if the code is a valid CSS code."""
if Hx is None and Hz is None:
msg = "At least one of the check matrices must be provided"
raise InvalidCSSCodeError(msg)

if Hx is not None and Hz is not None:
if Hx.shape[1] != Hz.shape[1]:
msg = "Check matrices must have the same number of columns"
Expand All @@ -157,18 +159,23 @@
msg = "The check matrices must be orthogonal"
raise InvalidCSSCodeError(msg)

@classmethod
def get_trivial_code(cls, n: int) -> CSSCode:
"""Return the trivial code."""
return CSSCode(1, None, None, n=n)

@staticmethod
def from_code_name(code_name: str, distance: int | None = None) -> CSSCode:
r"""Return CSSCode object for a known code.

The following codes are supported:
- [[7, 1, 3]] Steane (\"Steane\")
- [[15, 1, 3]] tetrahedral code (\"Tetrahedral\")
- [[15, 7, 3]] Hamming code (\"Hamming\")
- [[9, 1, 3]] Shore code (\"Shor\")
- [[12, 2, 4]] Carbon Code (\"Carbon\")
- [[9, 1, 3]] rotated surface code (\"Surface, 3\"), also default when no distance is given
- [[25, 1, 5]] rotated surface code (\"Surface, 5\")
- [[15, 7, 3]] Hamming code (\"Hamming\")
- [[23, 1, 7]] golay code (\"Golay\")

Args:
Expand All @@ -179,23 +186,23 @@
paths = {
"steane": prefix / "steane/",
"tetrahedral": prefix / "tetrahedral/",
"hamming": prefix / "hamming/",
"shor": prefix / "shor/",
"surface_3": prefix / "rotated_surface_d3/",
"surface_5": prefix / "rotated_surface_d5/",
"golay": prefix / "golay/",
"carbon": prefix / "carbon/",
"hamming": prefix / "hamming_15/",
}

distances = {
"steane": (3, 3),
"tetrahedral": (7, 3),
"hamming": (3, 3),
"shor": (3, 3),
"golay": (7, 7),
"surface_3": (3, 3),
"surface_5": (5, 5),
"carbon": (4, 4),
"hamming": (3, 3),
} # X, Z distances

code_name = code_name.lower()
Expand Down
File renamed without changes.
File renamed without changes.
Loading
Loading