From eaa2eee67ebdcda878539ca20b474f57d52ce784 Mon Sep 17 00:00:00 2001 From: Tom Peham Date: Mon, 1 Jul 2024 14:36:41 +0200 Subject: [PATCH] Appease the mighty mypy. --- src/mqt/qecc/ft_stateprep/__init__.py | 4 +- src/mqt/qecc/ft_stateprep/simulation.py | 14 ++- src/mqt/qecc/ft_stateprep/state_prep.py | 116 ++++++++++++++------ test/python/ft_stateprep/test_simulation.py | 39 ++++--- test/python/ft_stateprep/test_stateprep.py | 20 ++++ 5 files changed, 141 insertions(+), 52 deletions(-) diff --git a/src/mqt/qecc/ft_stateprep/__init__.py b/src/mqt/qecc/ft_stateprep/__init__.py index 7ffe9a81..3e597881 100644 --- a/src/mqt/qecc/ft_stateprep/__init__.py +++ b/src/mqt/qecc/ft_stateprep/__init__.py @@ -12,7 +12,7 @@ heuristic_prep_circuit, heuristic_verification_circuit, heuristic_verification_stabilizers, - naive_verification_circuit + naive_verification_circuit, ) __all__ = [ @@ -26,5 +26,5 @@ "heuristic_prep_circuit", "heuristic_verification_circuit", "heuristic_verification_stabilizers", - "naive_verification_circuit" + "naive_verification_circuit", ] diff --git a/src/mqt/qecc/ft_stateprep/simulation.py b/src/mqt/qecc/ft_stateprep/simulation.py index c8250314..fc25ecf9 100644 --- a/src/mqt/qecc/ft_stateprep/simulation.py +++ b/src/mqt/qecc/ft_stateprep/simulation.py @@ -10,6 +10,8 @@ import stim from qiskit.converters import circuit_to_dag, dag_to_circuit +from ..code import InvalidCSSCodeError + if TYPE_CHECKING: # pragma: no cover import numpy.typing as npt from qiskit import QuantumCircuit @@ -29,6 +31,10 @@ def __init__(self, state_prep_circ: QuantumCircuit, code: CSSCode, p: float, zer p: The error rate. zero_state: Whether thezero state is prepared or nor. """ + if code.Hx is None or code.Hz is None: + msg = "The code must have both X and Z checks." + raise InvalidCSSCodeError(msg) + self.circ = state_prep_circ self.num_qubits = state_prep_circ.num_qubits self.code = code @@ -102,7 +108,7 @@ def idle_error(used_qubits: list[int]) -> None: used_qubits = [] # type: list[int] targets = set() - measured = defaultdict(int) + measured = defaultdict(int) # type: defaultdict[int, int] for layer in layers: layer_circ = dag_to_circuit(layer["graph"]) @@ -161,6 +167,9 @@ def measure_stabilizers(self) -> stim.Circuit: An ancilla is used for each measurement. """ + assert self.code.Hx is not None + assert self.code.Hz is not None + for check in self.code.Hx: supp = _support(check) anc = self.stim_circ.num_qubits @@ -315,6 +324,7 @@ def generate_x_lut(self) -> None: if len(self.x_lut) != 0: return + assert self.code.Hz is not None, "The code does not have a Z stabilizer matrix." self.x_lut = LutDecoder._generate_lut(self.code.Hz) if self.code.is_self_dual(): self.z_lut = self.x_lut @@ -323,6 +333,8 @@ def generate_z_lut(self) -> None: """Generate the lookup table for the Z errors.""" if len(self.z_lut) != 0: return + + assert self.code.Hx is not None, "The code does not have an X stabilizer matrix." self.z_lut = LutDecoder._generate_lut(self.code.Hx) if self.code.is_self_dual(): self.z_lut = self.x_lut diff --git a/src/mqt/qecc/ft_stateprep/state_prep.py b/src/mqt/qecc/ft_stateprep/state_prep.py index a96a8999..a37df684 100644 --- a/src/mqt/qecc/ft_stateprep/state_prep.py +++ b/src/mqt/qecc/ft_stateprep/state_prep.py @@ -14,8 +14,11 @@ from qiskit.converters import circuit_to_dag from qiskit.dagcircuit import DAGOutNode +from ..code import InvalidCSSCodeError + logger = logging.getLogger(__name__) + if TYPE_CHECKING: # pragma: no cover from collections.abc import Callable @@ -40,8 +43,14 @@ def __init__(self, circ: QuantumCircuit, code: CSSCode, zero_state: bool = True) self.circ = circ self.code = code self.zero_state = zero_state + + if code.Hx is None or code.Hz is None: + msg = "The CSS code must have both X and Z checks." + raise InvalidCSSCodeError(msg) + self.x_checks = code.Hx.copy() if zero_state else np.vstack((code.Lx.copy(), code.Hx.copy())) self.z_checks = code.Hz.copy() if not zero_state else np.vstack((code.Lz.copy(), code.Hz.copy())) + self.num_qubits = circ.num_qubits self.max_errors = (code.distance - 1) // 2 self.x_fault_sets = [None for _ in range(self.max_errors + 1)] # type: list[npt.NDArray[np.int8] | None] @@ -92,7 +101,7 @@ def compute_fault_set( non_propagated_single_errors = np.eye(self.num_qubits, dtype=np.int8) # type: npt.NDArray[np.int8] self.x_fault_sets_unreduced[1] = np.vstack((faults, non_propagated_single_errors)) elif not x_errors and self.z_fault_sets[1] is None: - non_propagated_single_errors = np.eye(self.num_qubits, dtype=np.int8) # type: npt.NDArray[np.int8] + non_propagated_single_errors = np.eye(self.num_qubits, dtype=np.int8) self.z_fault_sets_unreduced[1] = np.vstack((faults, non_propagated_single_errors)) else: logging.info(f"Computing fault set for {num_errors} errors.") @@ -129,8 +138,10 @@ def compute_fault_set( self.z_fault_sets[num_errors] = faults return faults - def combine_faults(self, additional_faults: np.ndarray, x_errors: bool = True) -> npt.NDArray[np.int8]: - """Combine fault sets of circuit with additional indpendent faults. + def combine_faults( + self, additional_faults: npt.NDArray[np.int8], x_errors: bool = True + ) -> list[npt.NDArray[np.int8] | None]: + """Combine fault sets of circuit with additional independent faults. Args: additional_faults: The additional faults to combine with the fault set of the circuit. @@ -139,18 +150,22 @@ def combine_faults(self, additional_faults: np.ndarray, x_errors: bool = True) - self.compute_fault_sets() fault_sets_unreduced = self.x_fault_sets_unreduced.copy() if x_errors else self.z_fault_sets_unreduced.copy() + assert fault_sets_unreduced[1] is not None fault_sets_unreduced[1] = np.vstack((fault_sets_unreduced[1], additional_faults)) for i in range(1, self.max_errors): uncombined = fault_sets_unreduced[i] + assert uncombined is not None combined = (uncombined[:, np.newaxis, :] + additional_faults).reshape(-1, self.num_qubits) % 2 - fault_sets_unreduced[i + 1] = np.vstack((fault_sets_unreduced[i + 1], combined)) + next_faults = fault_sets_unreduced[i + 1] + assert next_faults is not None + fault_sets_unreduced[i + 1] = np.vstack((next_faults, combined)) fault_sets = [None for _ in range(self.max_errors + 1)] # type: list[None | npt.NDArray[np.int8]] stabs = self.x_checks if x_errors else self.z_checks for num_errors in range(1, self.max_errors + 1): - fault_sets[num_errors] = _remove_trivial_faults( - fault_sets_unreduced[num_errors], stabs, self.code, x_errors, num_errors - ) + fs = fault_sets_unreduced[num_errors] + assert fs is not None + fault_sets[num_errors] = _remove_trivial_faults(fs, stabs, self.code, x_errors, num_errors) return fault_sets @@ -163,6 +178,9 @@ def heuristic_prep_circuit(code: CSSCode, optimize_depth: bool = True, zero_stat zero_state: If True, prepare the +1 eigenstate of the Z basis. If False, prepare the +1 eigenstate of the X basis. """ logging.info("Starting heuristic state preparation.") + if code.Hx is None or code.Hz is None: + msg = "The code must have both X and Z stabilizers defined." + raise InvalidCSSCodeError(msg) checks = code.Hx.copy() if zero_state else code.Hz.copy() rank = mod2.rank(checks) @@ -357,6 +375,9 @@ def _optimal_circuit( min_timeout: minimum timeout to start with max_timeout: maximum timeout to reach """ + if code.Hx is None or code.Hz is None: + msg = "Code must have both X and Z stabilizers defined." + raise ValueError(msg) checks = code.Hx if zero_state else code.Hz def fun(param: int) -> QuantumCircuit | None: @@ -541,6 +562,7 @@ def gate_optimal_verification_stabilizers( min_timeout: The minimum time to allow each search to run for. max_timeout: The maximum time to allow each search to run for. max_ancillas: The maximum number of ancillas to allow in each layer verification circuit. + additional_faults: Faults to verify in addition to the faults propagating in the state preparation circuit. Returns: A list of stabilizers to verify the state preparation circuit. @@ -563,6 +585,7 @@ def gate_optimal_verification_stabilizers( for num_errors in range(1, max_errors + 1): logging.info(f"Finding verification stabilizers for {num_errors} errors") faults = fault_sets[num_errors] + assert faults is not None if len(faults) == 0: logging.info(f"No non-trivial faults for {num_errors} errors") @@ -648,10 +671,12 @@ def search_anc(num_anc: int) -> list[npt.NDArray[np.int8]] | None: def _verification_circuit( sp_circ: StatePrepCircuit, - verification_stabs_fun: Callable[[StatePrepCircuit, bool, npt.NDArray[np.int8] | None], list[npt.NDArray[np.int8]]], + verification_stabs_fun: Callable[ + [StatePrepCircuit, bool, npt.NDArray[np.int8] | None], list[list[npt.NDArray[np.int8]]] + ], ) -> QuantumCircuit: logging.info("Finding verification stabilizers for the state preparation circuit") - layers_1 = verification_stabs_fun(sp_circ, sp_circ.zero_state) + layers_1 = verification_stabs_fun(sp_circ, sp_circ.zero_state) # type: ignore[call-arg] measurements_1 = [measurement for layer in layers_1 for measurement in layer] additional_errors = _hook_errors(measurements_1) layers_2 = verification_stabs_fun(sp_circ, not sp_circ.zero_state, additional_errors) @@ -677,20 +702,14 @@ def gate_optimal_verification_circuit( max_ancillas: The maximum number of ancillas to allow in each layer verification circuit. """ - def verification_stabs_fun(sp_circ, zero_state, additional_errors=None): + def verification_stabs_fun( + sp_circ: StatePrepCircuit, zero_state: bool, additional_errors: npt.NDArray[np.int8] | None = None + ) -> list[list[npt.NDArray[np.int8]]]: return gate_optimal_verification_stabilizers( sp_circ, zero_state, min_timeout, max_timeout, max_ancillas, additional_errors ) - # logging.info("Finding optimal verification stabilizers for X errors") - # x_layers = gate_optimal_verification_stabilizers(sp_circ, True, min_timeout, max_timeout, max_ancillas) - - # z_layers = gate_optimal_verification_stabilizers(sp_circ, False, min_timeout, max_timeout, max_ancillas) - - # z_measurements = [measurement for layer in x_layers for measurement in layer] - # x_measurements = [measurement for layer in z_layers for measurement in layer] return _verification_circuit(sp_circ, verification_stabs_fun) - # return _measure_ft_stabs(sp_circ, x_measurements, z_measurements) def heuristic_verification_circuit( @@ -705,16 +724,15 @@ def heuristic_verification_circuit( max_covering_sets: The maximum number of covering sets to consider. find_coset_leaders: Whether to find coset leaders for the found measurements. This is done using SAT solvers so it can be slow. """ - # x_layers = heuristic_verification_stabilizers(sp_circ, True, max_covering_sets, find_coset_leaders) - # z_layers = heuristic_verification_stabilizers(sp_circ, False, max_covering_sets, find_coset_leaders) - # z_measurements = [measurement for layer in x_layers for measurement in layer] - # x_measurements = [measurement for layer in z_layers for measurement in layer] - def verification_stabs_fun(sp_circ, zero_state, additional_errors=None): - return heuristic_verification_stabilizers(sp_circ, zero_state, max_covering_sets, find_coset_leaders) + def verification_stabs_fun( + sp_circ: StatePrepCircuit, zero_state: bool, additional_errors: npt.NDArray[np.int8] | None = None + ) -> list[list[npt.NDArray[np.int8]]]: + return heuristic_verification_stabilizers( + sp_circ, zero_state, max_covering_sets, find_coset_leaders, additional_errors + ) return _verification_circuit(sp_circ, verification_stabs_fun) - # return _measure_ft_stabs(sp_circ, x_measurements, z_measurements) def heuristic_verification_stabilizers( @@ -731,6 +749,7 @@ def heuristic_verification_stabilizers( x_errors: Whether to find verification stabilizers for X errors. If False, find for Z errors. max_covering_sets: The maximum number of covering sets to consider. find_coset_leaders: Whether to find coset leaders for the found measurements. This is done using SAT solvers so it can be slow. + additional_faults: Faults to verify in addition to the faults propagating in the state preparation circuit. """ logging.info("Finding verification stabilizers using heuristic method") max_errors = (sp_circ.code.distance - 1) // 2 @@ -747,6 +766,7 @@ def heuristic_verification_stabilizers( for num_errors in range(1, max_errors + 1): logging.info(f"Finding verification stabilizers for {num_errors} errors") faults = fault_sets[num_errors] + assert faults is not None logging.info(f"There are {len(faults)} faults") if len(faults) == 0: layers[num_errors - 1] = [] @@ -902,7 +922,7 @@ def _vars_to_stab( def verification_stabilizers( sp_circ: StatePrepCircuit, - fault_set: list[npt.NDArray[np.int8]], + fault_set: npt.NDArray[np.int8], num_anc: int, num_cnots: int, x_errors: bool = True, @@ -911,6 +931,7 @@ def verification_stabilizers( Args: sp_circ: The state preparation circuit. + fault_set: The set of errors to verify. num_anc: The maximum number of ancilla qubits to use. num_cnots: The maximumg number of CNOT gates to use. num_errors: The number of errors occur in the state prep circuit. @@ -1107,7 +1128,7 @@ def _propagate_error(dag: DagCircuit, node: DAGNode, x_errors: bool = True) -> P def _remove_trivial_faults( faults: npt.NDArray[np.int8], stabs: npt.NDArray[np.int8], code: CSSCode, x_errors: bool, num_errors: int -): +) -> npt.NDArray[np.int8]: # remove trivial faults faults = faults.copy() logging.info("Removing trivial faults.") @@ -1156,8 +1177,12 @@ def _remove_stabilizer_equivalent_faults( return faults[indices] -def naive_verification_circuit(sp_circ: StatePrepCircuit): +def naive_verification_circuit(sp_circ: StatePrepCircuit) -> QuantumCircuit: """Naive verification circuit for a state preparation circuit.""" + if sp_circ.code.Hx is None or sp_circ.code.Hz is None: + msg = "Code must have stabilizers defined." + raise ValueError(msg) + z_measurements = list(sp_circ.code.Hx) x_measurements = list(sp_circ.code.Hz) reps = (sp_circ.code.distance - 1) // 2 @@ -1206,7 +1231,11 @@ def _flag_init(qc: QuantumCircuit, flag: AncillaQubit, z_measurement: bool) -> N def _measure_stab_unflagged( - qc: QuantumCircuit, stab: list[Qubit], ancilla: AncillaQubit, measurement_bit: ClBit, z_measurement=True + qc: QuantumCircuit, + stab: list[Qubit] | npt.NDArray[np.int_], + ancilla: AncillaQubit, + measurement_bit: ClBit, + z_measurement: bool = True, ) -> None: if not z_measurement: qc.h(ancilla) @@ -1218,7 +1247,11 @@ def _measure_stab_unflagged( def measure_flagged( - qc: QuantumCircuit, stab: list[Qubit], ancilla: AncillaQubit, measurement_bit: ClBit, z_measurement=True + qc: QuantumCircuit, + stab: list[Qubit] | npt.NDArray[np.int_], + ancilla: AncillaQubit, + measurement_bit: ClBit, + z_measurement: bool = True, ) -> None: """Measure a w-flagged stabilizer with the general scheme. @@ -1228,6 +1261,7 @@ def measure_flagged( qc: The quantum circuit to add the measurement to. stab: The qubits to measure. ancilla: The ancilla qubit to use for the measurement. + measurement_bit: The classical bit to store the measurement result of the ancilla. z_measurement: Whether to measure an X (False) or Z (True) stabilizer. """ w = len(stab) @@ -1290,7 +1324,11 @@ def measure_flagged( def measure_flagged_4( - qc: QuantumCircuit, stab: list[Qubit], ancilla: AncillaQubit, measurement_bit: ClBit, z_measurement=True + qc: QuantumCircuit, + stab: list[Qubit] | npt.NDArray[np.int_], + ancilla: AncillaQubit, + measurement_bit: ClBit, + z_measurement: bool = True, ) -> None: """Measure a 4-flagged stabilizer using an optimized scheme.""" assert len(stab) == 4 @@ -1323,8 +1361,13 @@ def measure_flagged_4( def measure_flagged_6( - qc: QuantumCircuit, stab: list[Qubit], ancilla: AncillaQubit, measurement_bit: ClBit, z_measurement=True + qc: QuantumCircuit, + stab: list[Qubit] | npt.NDArray[np.int_], + ancilla: AncillaQubit, + measurement_bit: ClBit, + z_measurement: bool = True, ) -> None: + """Measure a 6-flagged stabilizer using an optimized scheme.""" assert len(stab) == 6 flag = AncillaRegister(2) meas = ClassicalRegister(2) @@ -1364,8 +1407,13 @@ def measure_flagged_6( def measure_flagged_8( - qc: QuantumCircuit, stab: list[Qubit], ancilla: AncillaQubit, measurement_bit: ClBit, z_measurement=True + qc: QuantumCircuit, + stab: list[Qubit] | npt.NDArray[np.int_], + ancilla: AncillaQubit, + measurement_bit: ClBit, + z_measurement: bool = True, ) -> None: + """Measure an 8-flagged stabilizer using an optimized scheme.""" assert len(stab) == 8 flag = AncillaRegister(3) meas = ClassicalRegister(3) @@ -1412,7 +1460,7 @@ def measure_flagged_8( qc.measure(ancilla, measurement_bit) -def _hook_errors(stabs: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]: +def _hook_errors(stabs: list[npt.NDArray[np.int8]]) -> npt.NDArray[np.int8]: """Assuming CNOTs are executed in ascending order of qubit index, this function gives all the hook errors of the given stabilizer measurements.""" errors = [] for stab in stabs: diff --git a/test/python/ft_stateprep/test_simulation.py b/test/python/ft_stateprep/test_simulation.py index 4da13168..723f26b0 100644 --- a/test/python/ft_stateprep/test_simulation.py +++ b/test/python/ft_stateprep/test_simulation.py @@ -1,16 +1,18 @@ +"""Test the simulation of fault-tolerant state preparation circuits.""" + from __future__ import annotations from typing import TYPE_CHECKING -import pytest import numpy as np +import pytest from mqt.qecc import CSSCode from mqt.qecc.ft_stateprep import ( + LutDecoder, NoisyNDFTStatePrepSimulator, heuristic_prep_circuit, heuristic_verification_circuit, - LutDecoder, ) if TYPE_CHECKING: # pragma: no cover @@ -19,22 +21,28 @@ @pytest.fixture() def steane_code() -> CSSCode: + """Return the Steane code.""" return CSSCode.from_code_name("steane") @pytest.fixture() def non_ft_steane_zero(steane_code: CSSCode) -> QuantumCircuit: + """Return a non fault-tolerant Steane code state preparation circuit.""" return heuristic_prep_circuit(steane_code).circ @pytest.fixture() def ft_steane_circ(steane_code: CSSCode) -> QuantumCircuit: + """Return a fault-tolerant Steane code state preparation circuit.""" circ = heuristic_prep_circuit(steane_code) return heuristic_verification_circuit(circ) -def test_lut(steane_code: CSSCode): +def test_lut(steane_code: CSSCode) -> None: """Test the LutDecoder class.""" + assert steane_code.Hx is not None, "Steane code does not have X stabilizers." + assert steane_code.Hz is not None, "Steane code does not have Z stabilizers." + lut = LutDecoder(steane_code, init_luts=False) assert len(lut.x_lut) == 0 @@ -42,20 +50,20 @@ def test_lut(steane_code: CSSCode): lut.generate_x_lut() lut.generate_z_lut() - + assert len(lut.x_lut) != 0 assert lut.x_lut is lut.z_lut # Code is self dual so luts should be the same - error_1 = np.zeros(steane_code.n, dtype=np.int8) + error_1 = np.zeros(steane_code.n, dtype=np.int8) # type: ignore[var-annotated] error_1[0] = 1 - + error_w1 = (steane_code.Hx[0] + error_1) % 2 syndrome_1 = steane_code.get_x_syndrome(error_w1) estimate_1 = lut.decode_x(syndrome_1.astype(np.int8)) assert steane_code.stabilizer_eq_x_error(estimate_1, error_1) assert steane_code.stabilizer_eq_z_error(estimate_1, error_1) - - error_2 = np.zeros(steane_code.n, dtype=np.int8) + + error_2 = np.zeros(steane_code.n, dtype=np.int8) # type: ignore[var-annotated] error_2[0] = 1 error_2[1] = 1 error_w2 = (steane_code.Hx[0] + error_2) % 2 @@ -65,18 +73,19 @@ def test_lut(steane_code: CSSCode): # Weight 2 error should have be estimated to be weight 1 assert not steane_code.stabilizer_eq_x_error(estimate_2, error_2) assert np.sum(estimate_2) == 1 - - error_3 = np.ones((steane_code.n), dtype=np.int8) + + error_3 = np.ones((steane_code.n), dtype=np.int8) # type: ignore[var-annotated] error_w3 = (steane_code.Hx[0] + error_3) % 2 syndrome_3 = steane_code.get_x_syndrome(error_w3) estimate_3 = lut.decode_x(syndrome_3.astype(np.int8)) # Weight 3 error should have be estimated to be weight 0 assert not steane_code.stabilizer_eq_x_error(estimate_3, error_3) - assert steane_code.stabilizer_eq_x_error(estimate_3, np.zeros(steane_code.n)) + assert steane_code.stabilizer_eq_x_error(estimate_3, np.zeros(steane_code.n, dtype=np.int8)) assert np.sum(estimate_3) == 0 - -def test_non_ft_sim(steane_code: CSSCode, non_ft_steane_zero: QuantumCircuit): + +def test_non_ft_sim(steane_code: CSSCode, non_ft_steane_zero: QuantumCircuit) -> None: + """Test the simulation of a non fault-tolerant state preparation circuit.""" tol = 5e-4 p = 1e-3 lower = 1e-4 @@ -86,7 +95,8 @@ def test_non_ft_sim(steane_code: CSSCode, non_ft_steane_zero: QuantumCircuit): assert p_l - tol > lower -def test_ft_sim(steane_code: CSSCode, ft_steane_circ: QuantumCircuit): +def test_ft_sim(steane_code: CSSCode, ft_steane_circ: QuantumCircuit) -> None: + """Test the simulation of a fault-tolerant state preparation circuit.""" tol = 5e-4 p = 1e-3 lower = 1e-4 @@ -94,4 +104,3 @@ def test_ft_sim(steane_code: CSSCode, ft_steane_circ: QuantumCircuit): p_l, _, _, _ = simulator.logical_error_rate() assert p_l - tol < lower - diff --git a/test/python/ft_stateprep/test_stateprep.py b/test/python/ft_stateprep/test_stateprep.py index df428759..a9bb2640 100644 --- a/test/python/ft_stateprep/test_stateprep.py +++ b/test/python/ft_stateprep/test_stateprep.py @@ -60,6 +60,10 @@ def get_stabs(qc: QuantumCircuit) -> tuple[npt.NDArray[np.int_], npt.NDArray[np. def test_heuristic_prep_consistent(code_name: str) -> None: """Check that heuristic_prep_circuit returns a valid circuit with the correct stabilizers.""" code = CSSCode.from_code_name(code_name) + + assert code.Hx is not None, f"Code {code_name} does not have X stabilizers." + assert code.Hz is not None, f"Code {code_name} does not have Z stabilizers." + sp_circ = heuristic_prep_circuit(code) circ = sp_circ.circ max_cnots = np.sum(code.Hx) + np.sum(code.Hz) @@ -76,6 +80,10 @@ def test_heuristic_prep_consistent(code_name: str) -> None: def test_gate_optimal_prep_consistent(code_name: str) -> None: """Check that gate_optimal_prep_circuit returns a valid circuit with the correct stabilizers.""" code = CSSCode.from_code_name(code_name) + + assert code.Hx is not None, f"Code {code_name} does not have X stabilizers." + assert code.Hz is not None, f"Code {code_name} does not have Z stabilizers." + sp_circ = gate_optimal_prep_circuit(code, max_timeout=2) assert sp_circ is not None circ = sp_circ.circ @@ -93,6 +101,10 @@ def test_gate_optimal_prep_consistent(code_name: str) -> None: def test_depth_optimal_prep_consistent(code_name: str) -> None: """Check that depth_optimal_prep_circuit returns a valid circuit with the correct stabilizers.""" code = CSSCode.from_code_name(code_name) + + assert code.Hx is not None, f"Code {code_name} does not have X stabilizers." + assert code.Hz is not None, f"Code {code_name} does not have Z stabilizers." + sp_circ = gate_optimal_prep_circuit(code, max_timeout=2) assert sp_circ is not None circ = sp_circ.circ @@ -109,6 +121,10 @@ def test_depth_optimal_prep_consistent(code_name: str) -> None: def test_optimal_steane_verification_circuit(steane_code: CSSCode) -> None: """Test that the optimal verification circuit for the Steane code is correct.""" circ = heuristic_prep_circuit(steane_code) + + assert steane_code.Hx is not None, "Steane code does not have X stabilizers." + assert steane_code.Hz is not None, "Steane code does not have Z stabilizers." + ver_stabs_layers = gate_optimal_verification_stabilizers(circ, x_errors=True, max_timeout=2) assert len(ver_stabs_layers) == 1 # 1 Ancilla measurement @@ -135,6 +151,10 @@ def test_optimal_steane_verification_circuit(steane_code: CSSCode) -> None: def test_heuristic_steane_verification_circuit(steane_code: CSSCode) -> None: """Test that the optimal verification circuit for the Steane code is correct.""" circ = heuristic_prep_circuit(steane_code) + + assert steane_code.Hx is not None, "Steane code does not have X stabilizers." + assert steane_code.Hz is not None, "Steane code does not have Z stabilizers." + ver_stabs_layers = heuristic_verification_stabilizers(circ, x_errors=True) assert len(ver_stabs_layers) == 1 # 1 Ancilla measurement