diff --git a/src/mqt/qecc/ft_stateprep/state_prep.py b/src/mqt/qecc/ft_stateprep/state_prep.py index 576dd6db..94de4899 100644 --- a/src/mqt/qecc/ft_stateprep/state_prep.py +++ b/src/mqt/qecc/ft_stateprep/state_prep.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from collections import deque +from collections import defaultdict, deque from typing import TYPE_CHECKING, Any import multiprocess @@ -779,7 +779,7 @@ def heuristic_verification_stabilizers( if x_errors else sp_circ.z_fault_sets ) - + orthogonal_checks = sp_circ.z_checks if x_errors else sp_circ.x_checks for num_errors in range(1, max_errors + 1): logging.info(f"Finding verification stabilizers for {num_errors} errors") faults = fault_sets[num_errors] @@ -789,87 +789,115 @@ def heuristic_verification_stabilizers( layers[num_errors - 1] = [] continue - orthogonal_checks = sp_circ.z_checks if x_errors else sp_circ.x_checks - syndromes = orthogonal_checks @ faults.T % 2 - candidates = np.where(np.any(syndromes != 0, axis=1))[0] - non_candidates = np.where(np.all(syndromes == 0, axis=1))[0] - candidate_checks = orthogonal_checks[candidates] - non_candidate_checks = orthogonal_checks[non_candidates] - - def covers(s: npt.NDArray[np.int8], faults: npt.NDArray[np.int8]) -> frozenset[int]: - return frozenset(np.where(s @ faults.T % 2 != 0)[0]) - - logging.info("Converting Stabilizer Checks to covering sets") - candidate_sets_ordered = [(covers(s, faults), s, i) for i, s in enumerate(candidate_checks)] - candidate_sets_ordered.sort(key=lambda x: -np.sum(x[1])) - mapping = { - cand: _coset_leader(candidate_checks[i], non_candidate_checks) for cand, _, i in candidate_sets_ordered - } - candidate_sets = {cand for cand, _, _ in candidate_sets_ordered} - - def set_cover( - n: int, cands: set[frozenset[int]], mapping: dict[frozenset[int], npt.NDArray[np.int8]] - ) -> list[frozenset[int]]: - universe = set(range(n)) - cover = [] - - def sort_key(stab: frozenset[int], universe: set[int] = universe) -> tuple[int, np.int_]: - return (len(stab & universe), -np.sum(mapping[stab])) # noqa: B023 - - while universe: - best = max(cands, key=sort_key) - cover.append(best) - universe -= best - return cover - - improved = True - logging.info("Finding initial set cover") - cover = set_cover(len(faults), candidate_sets, mapping) - logging.info(f"Initial set cover has {len(cover)} sets") + layers[num_errors - 1] = _heuristic_layer(faults, orthogonal_checks, find_coset_leaders, max_covering_sets) + + return layers + + +def _covers(s: npt.NDArray[np.int8], faults: npt.NDArray[np.int8]) -> frozenset[int]: + return frozenset(np.where(s @ faults.T % 2 != 0)[0]) + + +def _set_cover( + n: int, cands: set[frozenset[int]], mapping: dict[frozenset[int], list[npt.NDArray[np.int8]]] +) -> set[frozenset[int]]: + universe = set(range(n)) + cover = set() # type: set[frozenset[int]] + + def sort_key(stab: frozenset[int], universe: set[int] = universe) -> tuple[int, np.int_]: + return (len(stab & universe), -np.sum(mapping[stab])) + + while universe: + best = max(cands, key=sort_key) + cover.add(best) + universe -= best + return cover + + +def _extend_covering_sets( + candidate_sets: set[frozenset[int]], size_limit: int, mapping: dict[frozenset[int], list[npt.NDArray[np.int8]]] +) -> set[frozenset[int]]: + to_remove = set() # type: set[frozenset[int]] + to_add = set() # type: set[frozenset[int]] + for c1 in candidate_sets: + for c2 in candidate_sets: + if len(to_add) >= size_limit: + break + + comb = c1 ^ c2 + if c1 == c2 or comb in candidate_sets or comb in to_add or comb in to_remove: + continue + + mapping[comb].extend([(s1 + s2) % 2 for s1 in mapping[c1] for s2 in mapping[c2]]) + + if len(c1 & c2) == 0: + to_remove.add(c1) + to_remove.add(c2) + to_add.add(c1 ^ c2) + + return candidate_sets.union(to_add) + + +def _heuristic_layer( + faults: npt.NDArray[np.int8], checks: npt.NDArray[np.int8], find_coset_leaders: bool, max_covering_sets: int +) -> list[npt.NDArray[np.int8]]: + syndromes = checks @ faults.T % 2 + candidates = np.where(np.any(syndromes != 0, axis=1))[0] + non_candidates = np.where(np.all(syndromes == 0, axis=1))[0] + candidate_checks = checks[candidates] + non_candidate_checks = checks[non_candidates] + + logging.info("Converting Stabilizer Checks to covering sets") + candidate_sets_ordered = [(_covers(s, faults), s, i) for i, s in enumerate(candidate_checks)] + mapping = defaultdict(list) + for cand, _, i in candidate_sets_ordered: + mapping[cand].append(candidate_checks[i]) + candidate_sets = {cand for cand, _, _ in candidate_sets_ordered} + + logging.info("Finding initial set cover") + cover = _set_cover(len(faults), candidate_sets, mapping) + logging.info(f"Initial set cover has {len(cover)} sets") + + def cost(cover: set[frozenset[int]]) -> tuple[int, int]: cost1 = len(cover) cost2 = sum(np.sum(mapping[stab]) for stab in cover) - prev_candidates = candidate_sets.copy() - while improved and len(candidate_sets) < max_covering_sets: - improved = False - # add all symmetric differences to candidates - to_remove = set() # type: set[frozenset[int]] - to_add = set() # type: set[frozenset[int]] - for c1 in candidate_sets: - for c2 in candidate_sets: - if len(to_add) >= max_covering_sets: - break - comb = c1 ^ c2 - if c1 == c2 or comb in candidate_sets or comb in to_add or comb in to_remove: - continue - - mapping[comb] = (mapping[c1] + mapping[c2]) % 2 - if len(c1 & c2) == 0: - to_remove.add(c1) - to_remove.add(c2) - to_add.add(comb) - candidate_sets = candidate_sets.union(to_add) - new_cover = set_cover(len(faults), candidate_sets, mapping) - logging.info(f"New Covering set has {len(new_cover)} sets") - new_cost1 = len(new_cover) - new_cost2 = sum(np.sum(mapping[stab]) for stab in new_cover) - if new_cost1 < cost1 or (new_cost1 == cost1 and new_cost2 < cost2): - cover = new_cover - cost1 = new_cost1 - cost2 = new_cost2 - improved = True - elif candidate_sets == prev_candidates: - break - prev_candidates = candidate_sets - - logging.info(f"Found covering set of size {len(cover)} for {num_errors} errors") - measurements = [mapping[c] for c in cover] - if find_coset_leaders and len(non_candidates) > 0: - logging.info(f"Finding coset leaders for {num_errors} errors") - measurements = [_coset_leader(m, non_candidate_checks) for m in measurements] - logging.info(f"Found {np.sum(measurements)} CNOTS for {num_errors} errors") - layers[num_errors - 1] = measurements + return cost1, cost2 + + cost1, cost2 = cost(cover) + prev_candidates = candidate_sets.copy() + + # find good cover + improved = True + while improved and len(candidate_sets) < max_covering_sets: + improved = False + # add all symmetric differences to candidates + candidate_sets = _extend_covering_sets(candidate_sets, max_covering_sets, mapping) + new_cover = _set_cover(len(faults), candidate_sets, mapping) + logging.info(f"New Covering set has {len(new_cover)} sets") + new_cost1 = len(new_cover) + new_cost2 = sum(np.sum(mapping[stab]) for stab in new_cover) + if new_cost1 < cost1 or (new_cost1 == cost1 and new_cost2 < cost2): + cover = new_cover + cost1 = new_cost1 + cost2 = new_cost2 + improved = True + elif candidate_sets == prev_candidates: + break + prev_candidates = candidate_sets + + # reduce stabilizers in cover + logging.info(f"Found covering set of size {len(cover)}.") + if find_coset_leaders and len(non_candidates) > 0: + logging.info("Finding coset leaders.") + measurements = [] + for c in cover: + leaders = [_coset_leader(m, non_candidate_checks) for m in mapping[c]] + leaders.sort(key=np.sum) + measurements.append(leaders[0]) + else: + measurements = [min(mapping[c], key=np.sum) for c in cover] - return layers + return measurements def _measure_ft_x(qc: QuantumCircuit, x_measurements: list[npt.NDArray[np.int8]], flags: bool = False) -> None: diff --git a/test/python/ft_stateprep/test_stateprep.py b/test/python/ft_stateprep/test_stateprep.py index 12f3872d..fa727b8e 100644 --- a/test/python/ft_stateprep/test_stateprep.py +++ b/test/python/ft_stateprep/test_stateprep.py @@ -27,19 +27,49 @@ @pytest.fixture() -def steane_code_sp() -> StatePrepCircuit: +def steane_code() -> CSSCode: + """Return the Steane code.""" + return CSSCode.from_code_name("Steane") + + +@pytest.fixture() +def surface_code() -> CSSCode: + """Return the distance 3 rotated Surface Code.""" + return CSSCode.from_code_name("surface", 3) + + +@pytest.fixture() +def tetrahedral_code() -> CSSCode: + """Return the tetrahedral code.""" + return CSSCode.from_code_name("tetrahedral") + + +@pytest.fixture() +def cc_4_8_8_code() -> CSSCode: + """Return the d=5 4,8,8 color code.""" + return CSSCode.from_code_name("cc_4_8_8") + + +@pytest.fixture() +def steane_code_sp(steane_code: CSSCode) -> StatePrepCircuit: """Return a non-ft state preparation circuit for the Steane code.""" - code = CSSCode.from_code_name("Steane") - sp_circ = heuristic_prep_circuit(code) + sp_circ = heuristic_prep_circuit(steane_code) + sp_circ.compute_fault_sets() + return sp_circ + + +@pytest.fixture() +def tetrahedral_code_sp(tetrahedral_code: CSSCode) -> StatePrepCircuit: + """Return a non-ft state preparation circuit for the tetrahedral code.""" + sp_circ = heuristic_prep_circuit(tetrahedral_code) sp_circ.compute_fault_sets() return sp_circ @pytest.fixture() -def color_code_d5_sp() -> StatePrepCircuit: +def color_code_d5_sp(cc_4_8_8_code: CSSCode) -> StatePrepCircuit: """Return a non-ft state preparation circuit for the d=5 4,8,8 color code.""" - code = CSSCode.from_code_name("cc_4_8_8") - sp_circ = heuristic_prep_circuit(code) + sp_circ = heuristic_prep_circuit(cc_4_8_8_code) sp_circ.compute_fault_sets() return sp_circ @@ -81,13 +111,14 @@ def test_heuristic_prep_consistent(code_name: str) -> None: assert eq_span(np.vstack((code.Hz, code.Lz)), z) # type: ignore[arg-type] -@pytest.mark.parametrize("code_name", ["steane", "surface"]) -def test_gate_optimal_prep_consistent(code_name: str) -> None: +@pytest.mark.parametrize("code", ["steane_code", "surface_code"]) +def test_gate_optimal_prep_consistent(code: CSSCode, request) -> None: # type: ignore[no-untyped-def] """Check that gate_optimal_prep_circuit returns a valid circuit with the correct stabilizers.""" - code = CSSCode.from_code_name(code_name) - + code = request.getfixturevalue(code) sp_circ = gate_optimal_prep_circuit(code, max_timeout=2) assert sp_circ is not None + assert sp_circ.zero_state + circ = sp_circ.circ max_cnots = np.sum(code.Hx) + np.sum(code.Hz) # type: ignore[arg-type] @@ -99,10 +130,10 @@ def test_gate_optimal_prep_consistent(code_name: str) -> None: assert eq_span(np.vstack((code.Hz, code.Lz)), z) # type: ignore[arg-type] -@pytest.mark.parametrize("code_name", ["steane", "surface"]) -def test_depth_optimal_prep_consistent(code_name: str) -> None: +@pytest.mark.parametrize("code", ["steane", "surface"]) +def test_depth_optimal_prep_consistent(code: CSSCode, request) -> None: # type: ignore[no-untyped-def] """Check that depth_optimal_prep_circuit returns a valid circuit with the correct stabilizers.""" - code = CSSCode.from_code_name(code_name) + code = request.getfixturevalue(code) sp_circ = gate_optimal_prep_circuit(code, max_timeout=2) assert sp_circ is not None @@ -117,6 +148,71 @@ def test_depth_optimal_prep_consistent(code_name: str) -> None: assert eq_span(np.vstack((code.Hz, code.Lz)), z) # type: ignore[arg-type] +@pytest.mark.parametrize("code", ["steane_code", "surface_code"]) +def test_plus_state_gate_optimal(code: CSSCode, request) -> None: # type: ignore[no-untyped-def] + """Test synthesis of the plus state.""" + code = request.getfixturevalue(code) + sp_circ_plus = gate_optimal_prep_circuit(code, max_timeout=2, zero_state=False) + + assert sp_circ_plus is not None + assert not sp_circ_plus.zero_state + + circ_plus = sp_circ_plus.circ + max_cnots = np.sum(code.Hx) + np.sum(code.Hz) # type: ignore[arg-type] + + assert circ_plus.num_qubits == code.n + assert circ_plus.num_nonlocal_gates() <= max_cnots + + x, z = get_stabs(circ_plus) + assert eq_span(code.Hz, z) # type: ignore[arg-type] + assert eq_span(np.vstack((code.Hx, code.Lx)), x) # type: ignore[arg-type] + + sp_circ_zero = gate_optimal_prep_circuit(code, max_timeout=2, zero_state=True) + + assert sp_circ_zero is not None + + circ_zero = sp_circ_zero.circ + x_zero, z_zero = get_stabs(circ_zero) + + if code.is_self_dual(): + assert np.array_equal(x, z_zero) + assert np.array_equal(z, x_zero) + else: + assert not np.array_equal(x, z_zero) + assert np.array_equal(z, x_zero) + + +@pytest.mark.parametrize("code", ["steane_code", "surface_code", "tetrahedral_code"]) +def test_plus_state_heuristic(code: CSSCode, request) -> None: # type: ignore[no-untyped-def] + """Test synthesis of the plus state.""" + code = request.getfixturevalue(code) + sp_circ_plus = heuristic_prep_circuit(code, zero_state=False) + + assert sp_circ_plus is not None + assert not sp_circ_plus.zero_state + + circ_plus = sp_circ_plus.circ + max_cnots = np.sum(code.Hx) + np.sum(code.Hz) # type: ignore[arg-type] + + assert circ_plus.num_qubits == code.n + assert circ_plus.num_nonlocal_gates() <= max_cnots + + x, z = get_stabs(circ_plus) + assert eq_span(code.Hz, z) # type: ignore[arg-type] + assert eq_span(np.vstack((code.Hx, code.Lx)), x) # type: ignore[arg-type] + + sp_circ_zero = heuristic_prep_circuit(code, zero_state=True) + circ_zero = sp_circ_zero.circ + x_zero, z_zero = get_stabs(circ_zero) + + if code.is_self_dual(): + assert np.array_equal(x, z_zero) + assert np.array_equal(z, x_zero) + else: + assert not np.array_equal(x, z_zero) + assert not np.array_equal(z, x_zero) + + def test_optimal_steane_verification_circuit(steane_code_sp: StatePrepCircuit) -> None: """Test that the optimal verification circuit for the Steane code is correct.""" circ = steane_code_sp @@ -170,6 +266,66 @@ def test_heuristic_steane_verification_circuit(steane_code_sp: StatePrepCircuit) assert circ_ver.depth() == np.sum(ver_stabs) + circ.circ.depth() + 1 # 1 for the measurement +def test_optimal_tetrahedral_verification_circuit(tetrahedral_code_sp: StatePrepCircuit) -> None: + """Test the optimal verification circuit for the tetrahedral code is correct. + + The tetrahedral code has an x-distance of 7. We expect that the verification only checks for a single propagated error since the tetrahedral code has a distance of 3. + """ + circ = tetrahedral_code_sp + + ver_stabs_layers = gate_optimal_verification_stabilizers(circ, x_errors=True, max_ancillas=1, max_timeout=2) + + assert len(ver_stabs_layers) == 1 # 1 layer of verification measurements + + ver_stabs = ver_stabs_layers[0] + assert len(ver_stabs) == 1 # 1 Ancilla measurement + assert np.sum(ver_stabs[0]) == 3 # 3 CNOTs + z_gens = circ.z_checks + + for stab in ver_stabs: + assert in_span(z_gens, stab) + + errors = circ.compute_fault_set(1) + non_detected = np.where(np.all(ver_stabs @ errors.T % 2 == 0, axis=1))[0] + assert len(non_detected) == 0 + + # Check that circuit is correct + circ_ver = gate_optimal_verification_circuit(circ, max_ancillas=1, max_timeout=2) + assert circ_ver.num_qubits == circ.num_qubits + 1 + assert circ_ver.num_nonlocal_gates() == np.sum(ver_stabs) + circ.circ.num_nonlocal_gates() + assert circ_ver.depth() == np.sum(ver_stabs) + circ.circ.depth() + 1 # 1 for the measurement + + +def test_heuristic_tetrahedral_verification_circuit(tetrahedral_code_sp: StatePrepCircuit) -> None: + """Test the optimal verification circuit for the tetrahedral code is correct. + + The tetrahedral code has an x-distance of 7. We expect that the verification only checks for a single propagated error since the tetrahedral code has a distance of 3. + """ + circ = tetrahedral_code_sp + + ver_stabs_layers = heuristic_verification_stabilizers(circ, x_errors=True) + + assert len(ver_stabs_layers) == 1 # 1 layer of verification measurements + + ver_stabs = ver_stabs_layers[0] + assert len(ver_stabs) == 1 # 1 Ancilla measurement + assert np.sum(ver_stabs[0]) == 3 # 3 CNOTs + z_gens = circ.z_checks + + for stab in ver_stabs: + assert in_span(z_gens, stab) + + errors = circ.compute_fault_set(1) + non_detected = np.where(np.all(ver_stabs @ errors.T % 2 == 0, axis=1))[0] + assert len(non_detected) == 0 + + # Check that circuit is correct + circ_ver = heuristic_verification_circuit(circ) + assert circ_ver.num_qubits == circ.num_qubits + 1 + assert circ_ver.num_nonlocal_gates() == np.sum(ver_stabs) + circ.circ.num_nonlocal_gates() + assert circ_ver.depth() == np.sum(ver_stabs) + circ.circ.depth() + 1 # 1 for the measurement + + def test_not_full_ft_opt_cc5(color_code_d5_sp: StatePrepCircuit) -> None: """Test that the optimal verification is also correct for higher distance.