Skip to content

Commit

Permalink
Heuristic and optimal state prep
Browse files Browse the repository at this point in the history
  • Loading branch information
pehamTom committed Jun 13, 2024
1 parent 424ad71 commit 222f02f
Show file tree
Hide file tree
Showing 3 changed files with 373 additions and 0 deletions.
57 changes: 57 additions & 0 deletions src/mqt/qecc/code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Class for representing quantum error correction codes."""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

import numpy as np
from ldpc import mod2

if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt


class CSSCode:
"""A class for representing CSS codes."""
def __init__(self, distance: int, Hx: npt.NDArray[np.int_], Hz: npt.NDArray[np.int_]):
"""Initialize the code."""
self.distance = distance

assert Hx.shape[1] == Hz.shape[1], "Hx and Hz must have the same number of columns"

self.Hx = Hx
self.Hz = Hz
self.n = Hx.shape[1]
self.k = self.n - Hx.shape[0] - Hz.shape[0]

self.Lx = CSSCode._compute_logical(self.Hx, self.Hz)
self.Lz = CSSCode._compute_logical(self.Hz, self.Hx)

def __hash__(self) -> int:
"""Compute a hash for the CSS code."""
return hash(self.Hx.tobytes() ^ self.Hz.tobytes())

def __eq__(self, other: object) -> bool:
"""Check if two CSS codes are equal."""
if not isinstance(other, CSSCode):
return NotImplemented
return mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, other.Hx])) and \
mod2.rank(self.Hz) == mod2.rank(np.vstack([self.Hz, other.Hz]))

def _compute_logical(m1: npt.NDArray[np.int_], m2: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]:
"""Compute the logical matrix L."""
ker_m1 = mod2.nullspace(m1) # compute the kernel basis of m1
im_m2_transp = mod2.row_basis(m2) # compute the image basis of m2
log_stack = np.vstack([im_m2_transp, ker_m1])
pivots = mod2.row_echelon(log_stack.T)[3]
log_op_indices = [i for i in range(im_m2_transp.shape[0], log_stack.shape[0]) if i in pivots]
return log_stack[log_op_indices]

def get_x_syndrome(self, error: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]:
"""Compute the x syndrome of the error."""
return self.Hx @ error % 2

def get_z_syndrome(self, error: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]:
"""Compute the z syndrome of the error."""
return self.Hz @ error % 2
3 changes: 3 additions & 0 deletions src/mqt/qecc/ft_stateprep/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Methods for synthesizing fault tolerant state preparation circuits."""

from __future__ import annotations
313 changes: 313 additions & 0 deletions src/mqt/qecc/ft_stateprep/state_prep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
"""Synthesizing state preparation circuits for CSS codes."""

from __future__ import annotations

from ldpc import mod2
# from code import CSSCode
import numpy as np
from qiskit import QuantumCircuit
import z3

import multiprocessing
import time

if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt


def _build_circuit_from_list_and_checks(cnots: list[tuple], checks: npt.NDArray[np.int_], zero_state=True) -> QuantumCircuit:
# Build circuit
n = checks.shape[1]
circ = QuantumCircuit(n)

controls = [i for i in range(n) if np.sum(checks[:, i]) >= 1]
if zero_state:
for control in controls:
circ.h(control)
else:
for i in range(n):
if i not in controls:
circ.h(i)

for i, j in reversed(cnots):
if not zero_state:
i, j = j, i
circ.cx(i, j)

return circ


def heuristic_prep_circuit(code: CSSCode, optimize_depth: bool=True, zero_state: bool=True):
"""Return a circuit that prepares the +1 eigenstate of the code w.r.t. the Z or X basis.
Args:
code: The CSS code to prepare the state for.
zero_state: If True, prepare the +1 eigenstate of the Z basis. If False, prepare the +1 eigenstate of the X basis.
"""
checks = code.Hx.copy() if zero_state else code.Hz.copy()
rank = mod2.rank(checks)

def is_reduced():
return len(np.where(np.all(checks==0, axis=0))[0]) == checks.shape[1]-rank

costs = np.array([[np.sum((checks[:, i] + checks[:, j]) % 2) for j in range(checks.shape[1])] for i in range(checks.shape[1])])
costs -= np.sum(checks, axis=0)
np.fill_diagonal(costs, 1)

used_qubits = []
cnots = []
while not is_reduced():
m = np.zeros((checks.shape[1], checks.shape[1]), dtype=bool)
m[used_qubits, :] = True
m[:, used_qubits] = True

costs_unused = np.ma.array(costs, mask=m)
if np.all(costs_unused >= 0): # no more reductions possible
if used_qubits == []: # local minimum => get out by making matrix triangular
costs = np.array([[np.sum((checks[:, i] + checks[:, j]) % 2)
for j in range(checks.shape[1])]
for i in range(checks.shape[1])])
costs -= np.sum(checks, axis=0)
np.fill_diagonal(costs, 1)
break
checks = mod2.reduced_row_echelon(checks)[0]
else: # try to move onto the next layer
used_qubits = []
continue

i, j = np.unravel_index(np.argmin(costs_unused), costs.shape)
cnots.append((i, j))

if optimize_depth:
used_qubits.append(i)
used_qubits.append(j)

# update checks
checks[:, j] = (checks[:, i] + checks[:, j]) % 2

# update costs
new_weights = np.sum((checks[:, j][:, np.newaxis] + checks) % 2, axis=0)
costs[:, j] = new_weights - np.sum(checks, axis=0)
costs[j, :] = new_weights - np.sum(checks[:, j])
np.fill_diagonal(costs, 1)

return _build_circuit_from_list_and_checks(cnots, checks, zero_state)


def _run_with_timeout(func, *args, timeout: int=10):
"""Run a function with a timeout. If the function does not complete within the timeout, return None.
Args:
func: The function to run.
args: The arguments to pass to the function.
timeout: The maximum time to allow the function to run for in seconds."""

manager = multiprocessing.Manager()
return_list = manager.list()
p = multiprocessing.Process(target=lambda: return_list.append(func(*args)))
p.start()
p.join(timeout)
if p.is_alive():
p.terminate()
return "timeout"
return return_list[0]


def _symbolic_scalar_mult(v: npt.NDArray[np.int_], a: z3.BoolRef | bool):
"""Multiply a concrete vector by a symbolic scalar."""
return [a if s == 1 else False for s in v]


SymOrBool = z3.BoolRef | bool
SymVec = list[SymOrBool]


def _symbolic_vector_add(v1: SymVec, v2: SymVec):
"""Add two symbolic vectors."""
if v1 is None:
return v2
if v2 is None:
return v1

v_new = [False for _ in range(len(v1))]
for i in range(len(v1)):
if isinstance(v1[i], bool):
if v1[i]:
v_new[i] = z3.Not(v2[i])
else:
v_new[i] = v2[i]

elif isinstance(v2[i], bool):
if v2[i]:
v_new[i] = z3.Not(v1[i])
else:
v_new[i] = v1[i]

else:
v_new[i] = z3.Xor(v1[i], v2[i])

return v_new


def _odd_overlap(v_sym: SymVec, v_con: npt.NDArray[np.int_]):
"""Return True if the overlap of symbolic vector with constant vector is odd."""
return z3.PbEq([(v_sym[i], 1) for i, c in enumerate(v_con) if c == 1], 1)


def _generate_circ_with_bounded_depth(checks: npt.NDArray, max_depth) -> np.array:
columns = np.array([[[z3.Bool(f'x_{d}_{i}_{j}')
for j in range(checks.shape[1])]
for i in range(checks.shape[0])]
for d in range(max_depth+1)])

additions = np.array([[[z3.Bool(f'add_{d}_{i}_{j}')
for j in range(checks.shape[1])]
for i in range(checks.shape[1])]
for d in range(max_depth)])
s = z3.Solver()

# create initial matrix
columns[0, :, :] = checks.astype(bool)

# encode all possible column additions
for d in range(1, max_depth+1):
for col_1 in range(checks.shape[1]):
for col_2 in range(col_1+1, checks.shape[1]):
col_sum = _symbolic_vector_add(columns[d-1, :, col_1], columns[d-1, :, col_2])

# encode col_2 += col_1
s.add(z3.Implies(additions[d-1, col_1, col_2],
z3.And([columns[d, i, col_2] == col_sum[i]
for i in range(checks.shape[0])] +
[columns[d, i, col_1] == columns[d-1, i, col_1]
for i in range(checks.shape[0])])))
# encode col_1 += col_2
s.add(z3.Implies(additions[d-1, col_2, col_1],
z3.And([columns[d, i, col_1] == col_sum[i]
for i in range(checks.shape[0])] +
[columns[d, i, col_2] == columns[d-1, i, col_2]
for i in range(checks.shape[0])])))

# at most one addition per column
for d in range(max_depth):
for col in range(checks.shape[1]):
s.add(z3.PbLe([(additions[d, col_1, col], 1)
for col_1 in range(checks.shape[1])
if col != col_1] +
[(additions[d, col, col_2], 1)
for col_2 in range(checks.shape[1])
if col != col_2], 1
)
)

# if column is not involved in any addition at certain depth, it is the same as the previous column
for d in range(1, max_depth+1):
for col in range(checks.shape[1]):
s.add(z3.Implies(
z3.Not(z3.Or([additions[d-1, col_1, col]
for col_1 in range(checks.shape[1]) if col != col_1] +
[additions[d-1, col, col_1]
for col_1 in range(checks.shape[1]) if col != col_1])),
z3.And([columns[d, i, col] == columns[d-1, i, col]
for i in range(checks.shape[0])])))

# assert that final check matrix has checks.shape[1]-checks.shape[0] zero columns
s.add(z3.PbEq([(z3.Not(z3.Or([columns[max_depth][i][col]
for i in range(checks.shape[0])])),
1) for col in range(checks.shape[1])],
checks.shape[1]-checks.shape[0]
)
)

if s.check() == z3.sat:
m = s.model()
additions = [(i, j) for d in range(max_depth)
for j in range(checks.shape[1])
for i in range(checks.shape[1]) if m[additions[d, i, j]]]

checks = np.array([[bool(m[columns[max_depth, i, j]]) for j in range(checks.shape[1])]
for i in range(checks.shape[0])])

return additions, checks.astype(int)

return False


def iterative_search_with_timeout(fun, min_param, max_param, min_timeout, max_timeout, param_factor=2, timeout_factor=2):
"""Geometrically increases the parameter and timeout until a result is found or the maximum timeout is reached.
Args:
fun: function to run with increasing parameters and timeouts
min_param: minimum parameter to start with
max_param: maximum parameter to reach
min_timeout: minimum timeout to start with
max_timeout: maximum timeout to reach
"""
curr_timeout = min_timeout
curr_param = min_param
param_type = type(min_param)
found = False
while curr_timeout <= max_timeout:
while curr_param <= max_param:
res = _run_with_timeout(fun, curr_param, timeout=curr_timeout)
if res and res != "timeout":
return res, curr_param
curr_param = param_type(curr_param*param_factor)
curr_timeout *= 2
curr_param = min_param
return None, max_param


def depth_optimal_prep_circuit(code: CSSCode, zero_state: bool=True, min_depth=1, max_depth=10, min_timeout=1, max_timeout=3600) -> QuantumCircuit:
"""Synthesize a state preparation circuit for a CSS code that minimizes the depth of the circuit.
Args:
code: The CSS code to prepare the state for.
zero_state: If True, prepare the +1 eigenstate of the Z basis. If False, prepare the +1 eigenstate of the X basis.
starting_depth: The depth of the circuit to start with.
depth_limit: The maximum depth of the circuit to search for.
max_cnots: The maximum number of CNOT gates to allow in the circuit. If None, no limit is imposed.
"""
# first try to find any circuit by exponentially increasing depth
checks = code.Hx if zero_state else code.Hz

curr_timeout = min_timeout
curr_depth = min_depth
circ = None

# while curr_timeout <= max_timeout:
# while curr_depth <= max_depth:
# res = _run_with_timeout(_generate_circ_with_bounded_depth, checks, curr_depth, timeout=curr_timeout)
# if res and res != "timeout":
# cnots, reduced_checks = res
# circ = _build_circuit_from_list_and_checks(cnots, reduced_checks)
# break
# curr_depth *= 2
# if circ is not None:
# break
# curr_timeout *= 2
# curr_depth = min_depth

# if circ is None:
# return None
res, curr_depth = iterative_search_with_timeout(lambda depth: _generate_circ_with_bounded_depth(checks, depth), min_depth, max_depth, min_timeout, max_timeout)
if res:
cnots, reduced_checks = res
circ = _build_circuit_from_list_and_checks(cnots, reduced_checks)
else:
return None

# Solving a SAT instance is much faster than proving unsat in this case
# so we iterate backwards until we find an unsat instance or hit a timeout
curr_depth -= 1
while True:
res = _run_with_timeout(_generate_circ_with_bounded_depth, checks, curr_depth, timeout=max_timeout)
if res and res != "timeout":
cnots, reduced_checks = res
circ = _build_circuit_from_list_and_checks(cnots, reduced_checks)
else:
break
curr_depth -= 1

return circ

0 comments on commit 222f02f

Please sign in to comment.