diff --git a/src/mqt/qecc/code.py b/src/mqt/qecc/code.py new file mode 100644 index 00000000..76d1e962 --- /dev/null +++ b/src/mqt/qecc/code.py @@ -0,0 +1,57 @@ +"""Class for representing quantum error correction codes.""" + +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING + +import numpy as np +from ldpc import mod2 + +if TYPE_CHECKING: # pragma: no cover + import numpy.typing as npt + + +class CSSCode: + """A class for representing CSS codes.""" + def __init__(self, distance: int, Hx: npt.NDArray[np.int_], Hz: npt.NDArray[np.int_]): + """Initialize the code.""" + self.distance = distance + + assert Hx.shape[1] == Hz.shape[1], "Hx and Hz must have the same number of columns" + + self.Hx = Hx + self.Hz = Hz + self.n = Hx.shape[1] + self.k = self.n - Hx.shape[0] - Hz.shape[0] + + self.Lx = CSSCode._compute_logical(self.Hx, self.Hz) + self.Lz = CSSCode._compute_logical(self.Hz, self.Hx) + + def __hash__(self) -> int: + """Compute a hash for the CSS code.""" + return hash(self.Hx.tobytes() ^ self.Hz.tobytes()) + + def __eq__(self, other: object) -> bool: + """Check if two CSS codes are equal.""" + if not isinstance(other, CSSCode): + return NotImplemented + return mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, other.Hx])) and \ + mod2.rank(self.Hz) == mod2.rank(np.vstack([self.Hz, other.Hz])) + + def _compute_logical(m1: npt.NDArray[np.int_], m2: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]: + """Compute the logical matrix L.""" + ker_m1 = mod2.nullspace(m1) # compute the kernel basis of m1 + im_m2_transp = mod2.row_basis(m2) # compute the image basis of m2 + log_stack = np.vstack([im_m2_transp, ker_m1]) + pivots = mod2.row_echelon(log_stack.T)[3] + log_op_indices = [i for i in range(im_m2_transp.shape[0], log_stack.shape[0]) if i in pivots] + return log_stack[log_op_indices] + + def get_x_syndrome(self, error: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]: + """Compute the x syndrome of the error.""" + return self.Hx @ error % 2 + + def get_z_syndrome(self, error: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]: + """Compute the z syndrome of the error.""" + return self.Hz @ error % 2 diff --git a/src/mqt/qecc/ft_stateprep/__init__.py b/src/mqt/qecc/ft_stateprep/__init__.py new file mode 100644 index 00000000..139afac8 --- /dev/null +++ b/src/mqt/qecc/ft_stateprep/__init__.py @@ -0,0 +1,3 @@ +"""Methods for synthesizing fault tolerant state preparation circuits.""" + +from __future__ import annotations diff --git a/src/mqt/qecc/ft_stateprep/state_prep.py b/src/mqt/qecc/ft_stateprep/state_prep.py new file mode 100644 index 00000000..8e0914ad --- /dev/null +++ b/src/mqt/qecc/ft_stateprep/state_prep.py @@ -0,0 +1,313 @@ +"""Synthesizing state preparation circuits for CSS codes.""" + +from __future__ import annotations + +from ldpc import mod2 +# from code import CSSCode +import numpy as np +from qiskit import QuantumCircuit +import z3 + +import multiprocessing +import time + +if TYPE_CHECKING: # pragma: no cover + import numpy.typing as npt + + +def _build_circuit_from_list_and_checks(cnots: list[tuple], checks: npt.NDArray[np.int_], zero_state=True) -> QuantumCircuit: + # Build circuit + n = checks.shape[1] + circ = QuantumCircuit(n) + + controls = [i for i in range(n) if np.sum(checks[:, i]) >= 1] + if zero_state: + for control in controls: + circ.h(control) + else: + for i in range(n): + if i not in controls: + circ.h(i) + + for i, j in reversed(cnots): + if not zero_state: + i, j = j, i + circ.cx(i, j) + + return circ + + +def heuristic_prep_circuit(code: CSSCode, optimize_depth: bool=True, zero_state: bool=True): + """Return a circuit that prepares the +1 eigenstate of the code w.r.t. the Z or X basis. + + Args: + code: The CSS code to prepare the state for. + zero_state: If True, prepare the +1 eigenstate of the Z basis. If False, prepare the +1 eigenstate of the X basis. + """ + checks = code.Hx.copy() if zero_state else code.Hz.copy() + rank = mod2.rank(checks) + + def is_reduced(): + return len(np.where(np.all(checks==0, axis=0))[0]) == checks.shape[1]-rank + + costs = np.array([[np.sum((checks[:, i] + checks[:, j]) % 2) for j in range(checks.shape[1])] for i in range(checks.shape[1])]) + costs -= np.sum(checks, axis=0) + np.fill_diagonal(costs, 1) + + used_qubits = [] + cnots = [] + while not is_reduced(): + m = np.zeros((checks.shape[1], checks.shape[1]), dtype=bool) + m[used_qubits, :] = True + m[:, used_qubits] = True + + costs_unused = np.ma.array(costs, mask=m) + if np.all(costs_unused >= 0): # no more reductions possible + if used_qubits == []: # local minimum => get out by making matrix triangular + costs = np.array([[np.sum((checks[:, i] + checks[:, j]) % 2) + for j in range(checks.shape[1])] + for i in range(checks.shape[1])]) + costs -= np.sum(checks, axis=0) + np.fill_diagonal(costs, 1) + break + checks = mod2.reduced_row_echelon(checks)[0] + else: # try to move onto the next layer + used_qubits = [] + continue + + i, j = np.unravel_index(np.argmin(costs_unused), costs.shape) + cnots.append((i, j)) + + if optimize_depth: + used_qubits.append(i) + used_qubits.append(j) + + # update checks + checks[:, j] = (checks[:, i] + checks[:, j]) % 2 + + # update costs + new_weights = np.sum((checks[:, j][:, np.newaxis] + checks) % 2, axis=0) + costs[:, j] = new_weights - np.sum(checks, axis=0) + costs[j, :] = new_weights - np.sum(checks[:, j]) + np.fill_diagonal(costs, 1) + + return _build_circuit_from_list_and_checks(cnots, checks, zero_state) + + +def _run_with_timeout(func, *args, timeout: int=10): + """Run a function with a timeout. If the function does not complete within the timeout, return None. + + Args: + func: The function to run. + args: The arguments to pass to the function. + timeout: The maximum time to allow the function to run for in seconds.""" + + manager = multiprocessing.Manager() + return_list = manager.list() + p = multiprocessing.Process(target=lambda: return_list.append(func(*args))) + p.start() + p.join(timeout) + if p.is_alive(): + p.terminate() + return "timeout" + return return_list[0] + + +def _symbolic_scalar_mult(v: npt.NDArray[np.int_], a: z3.BoolRef | bool): + """Multiply a concrete vector by a symbolic scalar.""" + return [a if s == 1 else False for s in v] + + +SymOrBool = z3.BoolRef | bool +SymVec = list[SymOrBool] + + +def _symbolic_vector_add(v1: SymVec, v2: SymVec): + """Add two symbolic vectors.""" + if v1 is None: + return v2 + if v2 is None: + return v1 + + v_new = [False for _ in range(len(v1))] + for i in range(len(v1)): + if isinstance(v1[i], bool): + if v1[i]: + v_new[i] = z3.Not(v2[i]) + else: + v_new[i] = v2[i] + + elif isinstance(v2[i], bool): + if v2[i]: + v_new[i] = z3.Not(v1[i]) + else: + v_new[i] = v1[i] + + else: + v_new[i] = z3.Xor(v1[i], v2[i]) + + return v_new + + +def _odd_overlap(v_sym: SymVec, v_con: npt.NDArray[np.int_]): + """Return True if the overlap of symbolic vector with constant vector is odd.""" + return z3.PbEq([(v_sym[i], 1) for i, c in enumerate(v_con) if c == 1], 1) + + +def _generate_circ_with_bounded_depth(checks: npt.NDArray, max_depth) -> np.array: + columns = np.array([[[z3.Bool(f'x_{d}_{i}_{j}') + for j in range(checks.shape[1])] + for i in range(checks.shape[0])] + for d in range(max_depth+1)]) + + additions = np.array([[[z3.Bool(f'add_{d}_{i}_{j}') + for j in range(checks.shape[1])] + for i in range(checks.shape[1])] + for d in range(max_depth)]) + s = z3.Solver() + + # create initial matrix + columns[0, :, :] = checks.astype(bool) + + # encode all possible column additions + for d in range(1, max_depth+1): + for col_1 in range(checks.shape[1]): + for col_2 in range(col_1+1, checks.shape[1]): + col_sum = _symbolic_vector_add(columns[d-1, :, col_1], columns[d-1, :, col_2]) + + # encode col_2 += col_1 + s.add(z3.Implies(additions[d-1, col_1, col_2], + z3.And([columns[d, i, col_2] == col_sum[i] + for i in range(checks.shape[0])] + + [columns[d, i, col_1] == columns[d-1, i, col_1] + for i in range(checks.shape[0])]))) + # encode col_1 += col_2 + s.add(z3.Implies(additions[d-1, col_2, col_1], + z3.And([columns[d, i, col_1] == col_sum[i] + for i in range(checks.shape[0])] + + [columns[d, i, col_2] == columns[d-1, i, col_2] + for i in range(checks.shape[0])]))) + + # at most one addition per column + for d in range(max_depth): + for col in range(checks.shape[1]): + s.add(z3.PbLe([(additions[d, col_1, col], 1) + for col_1 in range(checks.shape[1]) + if col != col_1] + + [(additions[d, col, col_2], 1) + for col_2 in range(checks.shape[1]) + if col != col_2], 1 + ) + ) + + # if column is not involved in any addition at certain depth, it is the same as the previous column + for d in range(1, max_depth+1): + for col in range(checks.shape[1]): + s.add(z3.Implies( + z3.Not(z3.Or([additions[d-1, col_1, col] + for col_1 in range(checks.shape[1]) if col != col_1] + + [additions[d-1, col, col_1] + for col_1 in range(checks.shape[1]) if col != col_1])), + z3.And([columns[d, i, col] == columns[d-1, i, col] + for i in range(checks.shape[0])]))) + + # assert that final check matrix has checks.shape[1]-checks.shape[0] zero columns + s.add(z3.PbEq([(z3.Not(z3.Or([columns[max_depth][i][col] + for i in range(checks.shape[0])])), + 1) for col in range(checks.shape[1])], + checks.shape[1]-checks.shape[0] + ) + ) + + if s.check() == z3.sat: + m = s.model() + additions = [(i, j) for d in range(max_depth) + for j in range(checks.shape[1]) + for i in range(checks.shape[1]) if m[additions[d, i, j]]] + + checks = np.array([[bool(m[columns[max_depth, i, j]]) for j in range(checks.shape[1])] + for i in range(checks.shape[0])]) + + return additions, checks.astype(int) + + return False + + +def iterative_search_with_timeout(fun, min_param, max_param, min_timeout, max_timeout, param_factor=2, timeout_factor=2): + """Geometrically increases the parameter and timeout until a result is found or the maximum timeout is reached. + + Args: + fun: function to run with increasing parameters and timeouts + min_param: minimum parameter to start with + max_param: maximum parameter to reach + min_timeout: minimum timeout to start with + max_timeout: maximum timeout to reach + """ + curr_timeout = min_timeout + curr_param = min_param + param_type = type(min_param) + found = False + while curr_timeout <= max_timeout: + while curr_param <= max_param: + res = _run_with_timeout(fun, curr_param, timeout=curr_timeout) + if res and res != "timeout": + return res, curr_param + curr_param = param_type(curr_param*param_factor) + curr_timeout *= 2 + curr_param = min_param + return None, max_param + + +def depth_optimal_prep_circuit(code: CSSCode, zero_state: bool=True, min_depth=1, max_depth=10, min_timeout=1, max_timeout=3600) -> QuantumCircuit: + """Synthesize a state preparation circuit for a CSS code that minimizes the depth of the circuit. + + Args: + code: The CSS code to prepare the state for. + zero_state: If True, prepare the +1 eigenstate of the Z basis. If False, prepare the +1 eigenstate of the X basis. + starting_depth: The depth of the circuit to start with. + depth_limit: The maximum depth of the circuit to search for. + max_cnots: The maximum number of CNOT gates to allow in the circuit. If None, no limit is imposed. + """ + # first try to find any circuit by exponentially increasing depth + checks = code.Hx if zero_state else code.Hz + + curr_timeout = min_timeout + curr_depth = min_depth + circ = None + + # while curr_timeout <= max_timeout: + # while curr_depth <= max_depth: + # res = _run_with_timeout(_generate_circ_with_bounded_depth, checks, curr_depth, timeout=curr_timeout) + # if res and res != "timeout": + # cnots, reduced_checks = res + # circ = _build_circuit_from_list_and_checks(cnots, reduced_checks) + # break + # curr_depth *= 2 + # if circ is not None: + # break + # curr_timeout *= 2 + # curr_depth = min_depth + + # if circ is None: + # return None + res, curr_depth = iterative_search_with_timeout(lambda depth: _generate_circ_with_bounded_depth(checks, depth), min_depth, max_depth, min_timeout, max_timeout) + if res: + cnots, reduced_checks = res + circ = _build_circuit_from_list_and_checks(cnots, reduced_checks) + else: + return None + + # Solving a SAT instance is much faster than proving unsat in this case + # so we iterate backwards until we find an unsat instance or hit a timeout + curr_depth -= 1 + while True: + res = _run_with_timeout(_generate_circ_with_bounded_depth, checks, curr_depth, timeout=max_timeout) + if res and res != "timeout": + cnots, reduced_checks = res + circ = _build_circuit_from_list_and_checks(cnots, reduced_checks) + else: + break + curr_depth -= 1 + + return circ +