Refactored heuristic verification.
pehamTom committed Jul 2, 2024
1 parent f5c29c1 commit 009fa7f
Showing 2 changed files with 277 additions and 93 deletions.
188 changes: 108 additions & 80 deletions src/mqt/qecc/ft_stateprep/state_prep.py
@@ -3,7 +3,7 @@
from __future__ import annotations

import logging
from collections import deque
from collections import defaultdict, deque
from typing import TYPE_CHECKING, Any

import multiprocess
@@ -779,7 +779,7 @@ def heuristic_verification_stabilizers(
if x_errors
else sp_circ.z_fault_sets
)

orthogonal_checks = sp_circ.z_checks if x_errors else sp_circ.x_checks
for num_errors in range(1, max_errors + 1):
logging.info(f"Finding verification stabilizers for {num_errors} errors")
faults = fault_sets[num_errors]
@@ -789,87 +789,115 @@
layers[num_errors - 1] = []
continue

orthogonal_checks = sp_circ.z_checks if x_errors else sp_circ.x_checks
syndromes = orthogonal_checks @ faults.T % 2
candidates = np.where(np.any(syndromes != 0, axis=1))[0]
non_candidates = np.where(np.all(syndromes == 0, axis=1))[0]
candidate_checks = orthogonal_checks[candidates]
non_candidate_checks = orthogonal_checks[non_candidates]

def covers(s: npt.NDArray[np.int8], faults: npt.NDArray[np.int8]) -> frozenset[int]:
return frozenset(np.where(s @ faults.T % 2 != 0)[0])

logging.info("Converting Stabilizer Checks to covering sets")
candidate_sets_ordered = [(covers(s, faults), s, i) for i, s in enumerate(candidate_checks)]
candidate_sets_ordered.sort(key=lambda x: -np.sum(x[1]))
mapping = {
cand: _coset_leader(candidate_checks[i], non_candidate_checks) for cand, _, i in candidate_sets_ordered
}
candidate_sets = {cand for cand, _, _ in candidate_sets_ordered}

def set_cover(
n: int, cands: set[frozenset[int]], mapping: dict[frozenset[int], npt.NDArray[np.int8]]
) -> list[frozenset[int]]:
universe = set(range(n))
cover = []

def sort_key(stab: frozenset[int], universe: set[int] = universe) -> tuple[int, np.int_]:
return (len(stab & universe), -np.sum(mapping[stab])) # noqa: B023

while universe:
best = max(cands, key=sort_key)
cover.append(best)
universe -= best
return cover

improved = True
logging.info("Finding initial set cover")
cover = set_cover(len(faults), candidate_sets, mapping)
logging.info(f"Initial set cover has {len(cover)} sets")
layers[num_errors - 1] = _heuristic_layer(faults, orthogonal_checks, find_coset_leaders, max_covering_sets)

return layers


def _covers(s: npt.NDArray[np.int8], faults: npt.NDArray[np.int8]) -> frozenset[int]:
return frozenset(np.where(s @ faults.T % 2 != 0)[0])
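
As a toy illustration (arrays chosen here for exposition, not taken from the repository): a check covers exactly the faults it anticommutes with, i.e. the faults whose syndrome bit under the check is 1.

import numpy as np

faults = np.array([[1, 1, 0, 0],
                   [0, 0, 1, 1]], dtype=np.int8)  # two hypothetical fault patterns
check = np.array([1, 0, 1, 0], dtype=np.int8)     # a hypothetical stabilizer check
print(_covers(check, faults))                     # frozenset({0, 1}): the check detects both faults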


def _set_cover(
n: int, cands: set[frozenset[int]], mapping: dict[frozenset[int], list[npt.NDArray[np.int8]]]
) -> set[frozenset[int]]:
universe = set(range(n))
cover = set() # type: set[frozenset[int]]

def sort_key(stab: frozenset[int], universe: set[int] = universe) -> tuple[int, np.int_]:
return (len(stab & universe), -np.sum(mapping[stab]))

while universe:
best = max(cands, key=sort_key)
cover.add(best)
universe -= best
return cover
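
A minimal sketch of the greedy rule (toy universe and candidate checks, not from the repository): in each round the candidate covering the most still-uncovered faults is taken, with ties broken towards checks of smaller weight (fewer CNOTs).

import numpy as np

cands_to_checks = {
    frozenset({0, 1}): [np.array([1, 1, 0, 0], dtype=np.int8)],
    frozenset({1, 2, 3}): [np.array([0, 1, 1, 1], dtype=np.int8)],
    frozenset({3}): [np.array([0, 0, 0, 1], dtype=np.int8)],
}
cover = _set_cover(4, set(cands_to_checks), cands_to_checks)
# Picks frozenset({1, 2, 3}) first, then frozenset({0, 1}): two checks cover all four faults.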


def _extend_covering_sets(
candidate_sets: set[frozenset[int]], size_limit: int, mapping: dict[frozenset[int], list[npt.NDArray[np.int8]]]
) -> set[frozenset[int]]:
to_remove = set() # type: set[frozenset[int]]
to_add = set() # type: set[frozenset[int]]
for c1 in candidate_sets:
for c2 in candidate_sets:
if len(to_add) >= size_limit:
break

comb = c1 ^ c2
if c1 == c2 or comb in candidate_sets or comb in to_add or comb in to_remove:
continue

mapping[comb].extend([(s1 + s2) % 2 for s1 in mapping[c1] for s2 in mapping[c2]])

if len(c1 & c2) == 0:
to_remove.add(c1)
to_remove.add(c2)
to_add.add(c1 ^ c2)

return candidate_sets.union(to_add)
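
Why the symmetric difference c1 ^ c2 is the covering set of the combined check: a fault flagged by both checks is invisible to their product, since the two syndrome bits cancel modulo 2. A small consistency check of this identity (single-qubit faults and hypothetical check vectors, for illustration only):

import numpy as np

faults = np.eye(4, dtype=np.int8)            # four single-qubit faults
s1 = np.array([1, 1, 0, 0], dtype=np.int8)   # covers {0, 1}
s2 = np.array([0, 1, 1, 0], dtype=np.int8)   # covers {1, 2}
combined = (s1 + s2) % 2                     # product check measured in place of s1 and s2
assert _covers(combined, faults) == _covers(s1, faults) ^ _covers(s2, faults)  # == frozenset({0, 2})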


def _heuristic_layer(
faults: npt.NDArray[np.int8], checks: npt.NDArray[np.int8], find_coset_leaders: bool, max_covering_sets: int
) -> list[npt.NDArray[np.int8]]:
syndromes = checks @ faults.T % 2
candidates = np.where(np.any(syndromes != 0, axis=1))[0]
non_candidates = np.where(np.all(syndromes == 0, axis=1))[0]
candidate_checks = checks[candidates]
non_candidate_checks = checks[non_candidates]

logging.info("Converting Stabilizer Checks to covering sets")
candidate_sets_ordered = [(_covers(s, faults), s, i) for i, s in enumerate(candidate_checks)]
mapping = defaultdict(list)
for cand, _, i in candidate_sets_ordered:
mapping[cand].append(candidate_checks[i])
candidate_sets = {cand for cand, _, _ in candidate_sets_ordered}

logging.info("Finding initial set cover")
cover = _set_cover(len(faults), candidate_sets, mapping)
logging.info(f"Initial set cover has {len(cover)} sets")

def cost(cover: set[frozenset[int]]) -> tuple[int, int]:
cost1 = len(cover)
cost2 = sum(np.sum(mapping[stab]) for stab in cover)
prev_candidates = candidate_sets.copy()
while improved and len(candidate_sets) < max_covering_sets:
improved = False
# add all symmetric differences to candidates
to_remove = set() # type: set[frozenset[int]]
to_add = set() # type: set[frozenset[int]]
for c1 in candidate_sets:
for c2 in candidate_sets:
if len(to_add) >= max_covering_sets:
break
comb = c1 ^ c2
if c1 == c2 or comb in candidate_sets or comb in to_add or comb in to_remove:
continue

mapping[comb] = (mapping[c1] + mapping[c2]) % 2
if len(c1 & c2) == 0:
to_remove.add(c1)
to_remove.add(c2)
to_add.add(comb)
candidate_sets = candidate_sets.union(to_add)
new_cover = set_cover(len(faults), candidate_sets, mapping)
logging.info(f"New Covering set has {len(new_cover)} sets")
new_cost1 = len(new_cover)
new_cost2 = sum(np.sum(mapping[stab]) for stab in new_cover)
if new_cost1 < cost1 or (new_cost1 == cost1 and new_cost2 < cost2):
cover = new_cover
cost1 = new_cost1
cost2 = new_cost2
improved = True
elif candidate_sets == prev_candidates:
break
prev_candidates = candidate_sets

logging.info(f"Found covering set of size {len(cover)} for {num_errors} errors")
measurements = [mapping[c] for c in cover]
if find_coset_leaders and len(non_candidates) > 0:
logging.info(f"Finding coset leaders for {num_errors} errors")
measurements = [_coset_leader(m, non_candidate_checks) for m in measurements]
logging.info(f"Found {np.sum(measurements)} CNOTS for {num_errors} errors")
layers[num_errors - 1] = measurements
return cost1, cost2

cost1, cost2 = cost(cover)
prev_candidates = candidate_sets.copy()

# find good cover
improved = True
while improved and len(candidate_sets) < max_covering_sets:
improved = False
# add all symmetric differences to candidates
candidate_sets = _extend_covering_sets(candidate_sets, max_covering_sets, mapping)
new_cover = _set_cover(len(faults), candidate_sets, mapping)
logging.info(f"New Covering set has {len(new_cover)} sets")
new_cost1 = len(new_cover)
new_cost2 = sum(np.sum(mapping[stab]) for stab in new_cover)
if new_cost1 < cost1 or (new_cost1 == cost1 and new_cost2 < cost2):
cover = new_cover
cost1 = new_cost1
cost2 = new_cost2
improved = True
elif candidate_sets == prev_candidates:
break
prev_candidates = candidate_sets

# reduce stabilizers in cover
logging.info(f"Found covering set of size {len(cover)}.")
if find_coset_leaders and len(non_candidates) > 0:
logging.info("Finding coset leaders.")
measurements = []
for c in cover:
leaders = [_coset_leader(m, non_candidate_checks) for m in mapping[c]]
leaders.sort(key=np.sum)
measurements.append(leaders[0])
else:
measurements = [min(mapping[c], key=np.sum) for c in cover]

return layers
return measurements
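
An end-to-end sketch of a single call (toy fault set and checks with assumed shapes, not from the repository):

import numpy as np

faults = np.array([[1, 0, 0, 0],
                   [0, 1, 1, 0]], dtype=np.int8)   # error patterns the layer must detect
checks = np.array([[1, 1, 0, 0],
                   [0, 0, 1, 1]], dtype=np.int8)   # orthogonal stabilizer checks available
layer = _heuristic_layer(faults, checks, find_coset_leaders=False, max_covering_sets=10)
# Returns a single check, array([1, 1, 0, 0], dtype=int8): it anticommutes with both faults,
# so one verification measurement suffices for this layer.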


def _measure_ft_x(qc: QuantumCircuit, x_measurements: list[npt.NDArray[np.int8]], flags: bool = False) -> None: