Skip to content

Commit

Permalink
Added logging and fixed bug with logical operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
pehamTom committed Jun 21, 2024
1 parent f6f112c commit 347a849
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 33 deletions.
15 changes: 9 additions & 6 deletions src/mqt/qecc/ft_stateprep/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ def __init__(self, distance: int, Hx: npt.NDArray[np.int_], Hz: npt.NDArray[np.i
self.Hz = Hz
self.n = Hx.shape[1]
self.k = self.n - Hx.shape[0] - Hz.shape[0]

self.Lx = CSSCode._compute_logical(self.Hx, self.Hz)
self.Lz = CSSCode._compute_logical(self.Hz, self.Hx)
self.Lx = CSSCode._compute_logical(self.Hz, self.Hx)
self.Lz = CSSCode._compute_logical(self.Hx, self.Hz)

def __hash__(self) -> int:
"""Compute a hash for the CSS code."""
Expand Down Expand Up @@ -92,9 +91,9 @@ def from_code_name(code_name: str, distance: int=None) -> CSSCode:
"tetrahedral": prefix / "tetrahedral/",
"hamming": prefix / "hamming/",
"shor": prefix / "shor/",
"surface, 3": prefix / "rotated_surface_d3/",
"surface, 5": prefix / "rotated_surface_d5/",
"cc_4_8_8 5": prefix / "cc_4_8_8_d5/",
"surface_3": prefix / "rotated_surface_d3/",
"surface_5": prefix / "rotated_surface_d5/",
"cc_4_8_8": prefix / "cc_4_8_8_d5/",
}

distances = {
Expand All @@ -106,6 +105,10 @@ def from_code_name(code_name: str, distance: int=None) -> CSSCode:
}

code_name = code_name.lower()
if code_name == "surface":
code_name = code_name + "_%d" % distance


if code_name == "cc_6_6_6":
if distance is None:
raise ValueError("Distance is not specified for CC_6_6_6")
Expand Down
62 changes: 35 additions & 27 deletions src/mqt/qecc/ft_stateprep/state_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
from qiskit.dagcircuit import DAGOutNode
import z3
import multiprocessing
import logging

from .code import CSSCode

logger = logging.getLogger(__name__)

if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt
SymOrBool = z3.BoolRef | bool
Expand Down Expand Up @@ -87,7 +90,6 @@ def compute_fault_set(self, n_errors=1, reduce=True) -> npt.NDArray[np.bool_]:
faults[i] = reduced_fault
reduced = True
break

# remove trivial faults
faults = np.array(faults)
faults = faults[np.where(np.sum(faults, axis=1) > (self.code.distance-1) // 2)[0]]
Expand All @@ -105,6 +107,7 @@ def heuristic_prep_circuit(code: CSSCode, optimize_depth: bool = True, zero_stat
code: The CSS code to prepare the state for.
zero_state: If True, prepare the +1 eigenstate of the Z basis. If False, prepare the +1 eigenstate of the X basis.
"""
logging.info("Starting heuristic state preparation.")
checks = code.Hx.copy() if zero_state else code.Hz.copy()
rank = mod2.rank(checks)

Expand All @@ -127,13 +130,13 @@ def is_reduced():
costs_unused = np.ma.array(costs, mask=m)
if np.all(costs_unused >= 0): # no more reductions possible
if used_qubits == []: # local minimum => get out by making matrix triangular
logging.warning("Local minimum reached. Making matrix triangular.")
checks = mod2.reduced_row_echelon(checks)[0]
costs = np.array([[np.sum((checks[:, i] + checks[:, j]) % 2)
for j in range(checks.shape[1])]
for i in range(checks.shape[1])])
costs -= np.sum(checks, axis=0)
np.fill_diagonal(costs, 1)
break
checks = mod2.reduced_row_echelon(checks)[0]
else: # try to move onto the next layer
used_qubits = []
continue
Expand All @@ -147,11 +150,10 @@ def is_reduced():

# update checks
checks[:, j] = (checks[:, i] + checks[:, j]) % 2

# update costs
new_weights = np.sum((checks[:, j][:, np.newaxis] + checks) % 2, axis=0)
costs[:, j] = new_weights - np.sum(checks, axis=0)
costs[j, :] = new_weights - np.sum(checks[:, j])
costs[j, :] = new_weights - np.sum(checks, axis=0)
costs[:, j] = new_weights - np.sum(checks[:, j])
np.fill_diagonal(costs, 1)

circ = _build_circuit_from_list_and_checks(cnots, checks, zero_state)
Expand Down Expand Up @@ -406,41 +408,48 @@ def gate_optimal_verification_stabilizers(sp_circ: StatePrepCircuit, n_errors: i
layers = [None for _ in range(n_errors)]
# Find the optimal circuit for every number of errors in the preparation circuit
for num_errors in range(1, (sp_circ.code.distance-1) // 2 + 1):
logging.info(f"Finding verification stabilizers for {num_errors} errors")
# Start with maximal number of ancillas
# Minimal CNOT solution must be achievable with these
num_anc = sp_circ.max_measurements
min_cnots = np.min(np.sum(sp_circ.orthogonal_checks, axis=1))
max_cnots = np.sum(sp_circ.orthogonal_checks)

logging.info(f"Finding verification stabilizers for {num_errors} errors with {min_cnots} to {max_cnots} CNOTs using {num_anc} ancillas")
measurements, num_cnots = iterative_search_with_timeout(lambda num_cnots: verification_stabilizers(sp_circ, num_anc, num_cnots, num_errors), min_cnots, max_cnots, min_timeout, max_timeout)

if measurements is None or (isinstance(measurements, str) and measurements == "timeout"):
logging.info(f"No verification stabilizers found for {num_errors} errors")
return None # No solution found

logging.info(f"Found verification stabilizers for {num_errors} errors with {num_cnots} CNOTs")
# If any measurements are unused we can reduce the number of ancillas at least by that
num_anc = np.sum([np.any(m) for m in measurements])
measurements = [m for m in measurements if np.any(m)]

# Iterate backwards to find the minimal number of cnots
num_cnots -= 1
while num_cnots > 0:
res = verification_stabilizers(sp_circ, num_anc, num_cnots, num_errors)
if res is None or res == "timeout":
logging.info(f"Finding minimal number of CNOTs for {num_errors} errors")
while num_cnots-1 > 0:
logging.info(f"Trying {num_cnots} CNOTs")
res = verification_stabilizers(sp_circ, num_anc, num_cnots-1, num_errors)
if res is None or isinstance(res, str) and res == "timeout":
break
num_cnots -= 1
measurements = res
logging.info(f"Found minimal number of CNOTs for {num_errors} errors: {num_cnots}")

# If the number of CNOTs is minimal, we can reduce the number of ancillas
num_anc -= 1
while num_anc > 0:
res = verification_stabilizers(sp_circ, num_anc, num_cnots, num_errors)
if res is None or res == "timeout":
logging.info(f"Finding minimal number of ancillas for {num_errors} errors")
while num_anc-1 > 0:
logging.info(f"Trying {num_anc} ancillas")
res = verification_stabilizers(sp_circ, num_anc-1, num_cnots, num_errors)
if res is None or isinstance(res, str) and res == "timeout":
break
num_anc -= 1
measurements = res

logging.info(f"Found minimal number of ancillas for {num_errors} errors: {num_anc}")
layers[num_errors-1] = measurements

return layers


Expand Down Expand Up @@ -478,7 +487,7 @@ def _measure_stabs(circ: QuantumCircuit, measurements: list(npt.NDArray[np.int_]
current_anc += 1
return measured_circ

def verification_stabilizers(self, num_anc, num_cnots, num_errors):
def verification_stabilizers(sp_circ: StatePrepCircuit, num_anc, num_cnots, num_errors):
"""Return a verification circuit for the state preparation circuit.
Args:
Expand All @@ -489,43 +498,42 @@ def verification_stabilizers(self, num_anc, num_cnots, num_errors):
# Measurements are written as sums of generators
# The variables indicate which generators are non-zero in the sum
measurement_vars = [[z3.Bool("m_{}_{}".format(anc, i))
for i in range(self.max_measurements)]
for i in range(sp_circ.max_measurements)]
for anc in range(num_anc)]
solver = z3.Solver()

def vars_to_stab(measurement):
measurement_stab = _symbolic_scalar_mult(self.orthogonal_checks[0], measurement[0])
measurement_stab = _symbolic_scalar_mult(sp_circ.orthogonal_checks[0], measurement[0])
for i, scalar in enumerate(measurement[1:]):
measurement_stab = _symbolic_vector_add(measurement_stab, _symbolic_scalar_mult(self.orthogonal_checks[i+1], scalar))
measurement_stab = _symbolic_vector_add(measurement_stab, _symbolic_scalar_mult(sp_circ.orthogonal_checks[i+1], scalar))
return measurement_stab

measurement_stabs = [vars_to_stab(vars_) for vars_ in measurement_vars]

# assert that each error is detected
errors = self.compute_fault_set(num_errors)
errors = sp_circ.compute_fault_set(num_errors)
solver.add(z3.And([z3.PbGe([(_odd_overlap(measurement, error), 1)
for measurement in measurement_stabs], 1)
for error in errors]))

# assert that not too many CNOTs are used
solver.add(z3.PbLe([(measurement[q], 1)
for measurement in measurement_stabs
for q in range(self.num_qubits)],
for q in range(sp_circ.num_qubits)],
num_cnots))

if solver.check() == z3.sat:
model = solver.model()
# Extract stabilizer measurements from model
actual_measurements = []
for m in measurement_vars:
v = np.zeros(self.num_qubits, dtype=int)
for g in range(self.max_measurements):
v = np.zeros(sp_circ.num_qubits, dtype=int)
for g in range(sp_circ.max_measurements):
if model[m[g]]:
v += self.orthogonal_checks[g]
v += sp_circ.orthogonal_checks[g]
actual_measurements.append(v % 2)

return np.array(actual_measurements)

return None


Expand Down

0 comments on commit 347a849

Please sign in to comment.