Skip to content

Commit

Permalink
Appease the mighty mypy.
Browse files Browse the repository at this point in the history
  • Loading branch information
pehamTom committed Jul 1, 2024
1 parent 7a4113b commit eaa2eee
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 52 deletions.
4 changes: 2 additions & 2 deletions src/mqt/qecc/ft_stateprep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
heuristic_prep_circuit,
heuristic_verification_circuit,
heuristic_verification_stabilizers,
naive_verification_circuit
naive_verification_circuit,
)

__all__ = [
Expand All @@ -26,5 +26,5 @@
"heuristic_prep_circuit",
"heuristic_verification_circuit",
"heuristic_verification_stabilizers",
"naive_verification_circuit"
"naive_verification_circuit",
]
14 changes: 13 additions & 1 deletion src/mqt/qecc/ft_stateprep/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import stim
from qiskit.converters import circuit_to_dag, dag_to_circuit

from ..code import InvalidCSSCodeError

if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt
from qiskit import QuantumCircuit
Expand All @@ -29,6 +31,10 @@ def __init__(self, state_prep_circ: QuantumCircuit, code: CSSCode, p: float, zer
p: The error rate.
zero_state: Whether thezero state is prepared or nor.
"""
if code.Hx is None or code.Hz is None:
msg = "The code must have both X and Z checks."
raise InvalidCSSCodeError(msg)

self.circ = state_prep_circ
self.num_qubits = state_prep_circ.num_qubits
self.code = code
Expand Down Expand Up @@ -102,7 +108,7 @@ def idle_error(used_qubits: list[int]) -> None:
used_qubits = [] # type: list[int]

targets = set()
measured = defaultdict(int)
measured = defaultdict(int) # type: defaultdict[int, int]
for layer in layers:
layer_circ = dag_to_circuit(layer["graph"])

Expand Down Expand Up @@ -161,6 +167,9 @@ def measure_stabilizers(self) -> stim.Circuit:
An ancilla is used for each measurement.
"""
assert self.code.Hx is not None
assert self.code.Hz is not None

for check in self.code.Hx:
supp = _support(check)
anc = self.stim_circ.num_qubits
Expand Down Expand Up @@ -315,6 +324,7 @@ def generate_x_lut(self) -> None:
if len(self.x_lut) != 0:
return

assert self.code.Hz is not None, "The code does not have a Z stabilizer matrix."
self.x_lut = LutDecoder._generate_lut(self.code.Hz)
if self.code.is_self_dual():
self.z_lut = self.x_lut
Expand All @@ -323,6 +333,8 @@ def generate_z_lut(self) -> None:
"""Generate the lookup table for the Z errors."""
if len(self.z_lut) != 0:
return

assert self.code.Hx is not None, "The code does not have an X stabilizer matrix."
self.z_lut = LutDecoder._generate_lut(self.code.Hx)
if self.code.is_self_dual():
self.z_lut = self.x_lut
Expand Down
116 changes: 82 additions & 34 deletions src/mqt/qecc/ft_stateprep/state_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
from qiskit.converters import circuit_to_dag
from qiskit.dagcircuit import DAGOutNode

from ..code import InvalidCSSCodeError

logger = logging.getLogger(__name__)


if TYPE_CHECKING: # pragma: no cover
from collections.abc import Callable

Expand All @@ -40,8 +43,14 @@ def __init__(self, circ: QuantumCircuit, code: CSSCode, zero_state: bool = True)
self.circ = circ
self.code = code
self.zero_state = zero_state

if code.Hx is None or code.Hz is None:
msg = "The CSS code must have both X and Z checks."
raise InvalidCSSCodeError(msg)

self.x_checks = code.Hx.copy() if zero_state else np.vstack((code.Lx.copy(), code.Hx.copy()))
self.z_checks = code.Hz.copy() if not zero_state else np.vstack((code.Lz.copy(), code.Hz.copy()))

self.num_qubits = circ.num_qubits
self.max_errors = (code.distance - 1) // 2
self.x_fault_sets = [None for _ in range(self.max_errors + 1)] # type: list[npt.NDArray[np.int8] | None]
Expand Down Expand Up @@ -92,7 +101,7 @@ def compute_fault_set(
non_propagated_single_errors = np.eye(self.num_qubits, dtype=np.int8) # type: npt.NDArray[np.int8]
self.x_fault_sets_unreduced[1] = np.vstack((faults, non_propagated_single_errors))
elif not x_errors and self.z_fault_sets[1] is None:
non_propagated_single_errors = np.eye(self.num_qubits, dtype=np.int8) # type: npt.NDArray[np.int8]
non_propagated_single_errors = np.eye(self.num_qubits, dtype=np.int8)
self.z_fault_sets_unreduced[1] = np.vstack((faults, non_propagated_single_errors))
else:
logging.info(f"Computing fault set for {num_errors} errors.")
Expand Down Expand Up @@ -129,8 +138,10 @@ def compute_fault_set(
self.z_fault_sets[num_errors] = faults
return faults

def combine_faults(self, additional_faults: np.ndarray, x_errors: bool = True) -> npt.NDArray[np.int8]:
"""Combine fault sets of circuit with additional indpendent faults.
def combine_faults(
self, additional_faults: npt.NDArray[np.int8], x_errors: bool = True
) -> list[npt.NDArray[np.int8] | None]:
"""Combine fault sets of circuit with additional independent faults.
Args:
additional_faults: The additional faults to combine with the fault set of the circuit.
Expand All @@ -139,18 +150,22 @@ def combine_faults(self, additional_faults: np.ndarray, x_errors: bool = True) -
self.compute_fault_sets()

fault_sets_unreduced = self.x_fault_sets_unreduced.copy() if x_errors else self.z_fault_sets_unreduced.copy()
assert fault_sets_unreduced[1] is not None
fault_sets_unreduced[1] = np.vstack((fault_sets_unreduced[1], additional_faults))

for i in range(1, self.max_errors):
uncombined = fault_sets_unreduced[i]
assert uncombined is not None
combined = (uncombined[:, np.newaxis, :] + additional_faults).reshape(-1, self.num_qubits) % 2
fault_sets_unreduced[i + 1] = np.vstack((fault_sets_unreduced[i + 1], combined))
next_faults = fault_sets_unreduced[i + 1]
assert next_faults is not None
fault_sets_unreduced[i + 1] = np.vstack((next_faults, combined))
fault_sets = [None for _ in range(self.max_errors + 1)] # type: list[None | npt.NDArray[np.int8]]
stabs = self.x_checks if x_errors else self.z_checks
for num_errors in range(1, self.max_errors + 1):
fault_sets[num_errors] = _remove_trivial_faults(
fault_sets_unreduced[num_errors], stabs, self.code, x_errors, num_errors
)
fs = fault_sets_unreduced[num_errors]
assert fs is not None
fault_sets[num_errors] = _remove_trivial_faults(fs, stabs, self.code, x_errors, num_errors)
return fault_sets


Expand All @@ -163,6 +178,9 @@ def heuristic_prep_circuit(code: CSSCode, optimize_depth: bool = True, zero_stat
zero_state: If True, prepare the +1 eigenstate of the Z basis. If False, prepare the +1 eigenstate of the X basis.
"""
logging.info("Starting heuristic state preparation.")
if code.Hx is None or code.Hz is None:
msg = "The code must have both X and Z stabilizers defined."
raise InvalidCSSCodeError(msg)
checks = code.Hx.copy() if zero_state else code.Hz.copy()
rank = mod2.rank(checks)

Expand Down Expand Up @@ -357,6 +375,9 @@ def _optimal_circuit(
min_timeout: minimum timeout to start with
max_timeout: maximum timeout to reach
"""
if code.Hx is None or code.Hz is None:
msg = "Code must have both X and Z stabilizers defined."
raise ValueError(msg)
checks = code.Hx if zero_state else code.Hz

def fun(param: int) -> QuantumCircuit | None:
Expand Down Expand Up @@ -541,6 +562,7 @@ def gate_optimal_verification_stabilizers(
min_timeout: The minimum time to allow each search to run for.
max_timeout: The maximum time to allow each search to run for.
max_ancillas: The maximum number of ancillas to allow in each layer verification circuit.
additional_faults: Faults to verify in addition to the faults propagating in the state preparation circuit.
Returns:
A list of stabilizers to verify the state preparation circuit.
Expand All @@ -563,6 +585,7 @@ def gate_optimal_verification_stabilizers(
for num_errors in range(1, max_errors + 1):
logging.info(f"Finding verification stabilizers for {num_errors} errors")
faults = fault_sets[num_errors]
assert faults is not None

if len(faults) == 0:
logging.info(f"No non-trivial faults for {num_errors} errors")
Expand Down Expand Up @@ -648,10 +671,12 @@ def search_anc(num_anc: int) -> list[npt.NDArray[np.int8]] | None:

def _verification_circuit(
sp_circ: StatePrepCircuit,
verification_stabs_fun: Callable[[StatePrepCircuit, bool, npt.NDArray[np.int8] | None], list[npt.NDArray[np.int8]]],
verification_stabs_fun: Callable[
[StatePrepCircuit, bool, npt.NDArray[np.int8] | None], list[list[npt.NDArray[np.int8]]]
],
) -> QuantumCircuit:
logging.info("Finding verification stabilizers for the state preparation circuit")
layers_1 = verification_stabs_fun(sp_circ, sp_circ.zero_state)
layers_1 = verification_stabs_fun(sp_circ, sp_circ.zero_state) # type: ignore[call-arg]
measurements_1 = [measurement for layer in layers_1 for measurement in layer]
additional_errors = _hook_errors(measurements_1)
layers_2 = verification_stabs_fun(sp_circ, not sp_circ.zero_state, additional_errors)
Expand All @@ -677,20 +702,14 @@ def gate_optimal_verification_circuit(
max_ancillas: The maximum number of ancillas to allow in each layer verification circuit.
"""

def verification_stabs_fun(sp_circ, zero_state, additional_errors=None):
def verification_stabs_fun(
sp_circ: StatePrepCircuit, zero_state: bool, additional_errors: npt.NDArray[np.int8] | None = None
) -> list[list[npt.NDArray[np.int8]]]:
return gate_optimal_verification_stabilizers(
sp_circ, zero_state, min_timeout, max_timeout, max_ancillas, additional_errors
)

# logging.info("Finding optimal verification stabilizers for X errors")
# x_layers = gate_optimal_verification_stabilizers(sp_circ, True, min_timeout, max_timeout, max_ancillas)

# z_layers = gate_optimal_verification_stabilizers(sp_circ, False, min_timeout, max_timeout, max_ancillas)

# z_measurements = [measurement for layer in x_layers for measurement in layer]
# x_measurements = [measurement for layer in z_layers for measurement in layer]
return _verification_circuit(sp_circ, verification_stabs_fun)
# return _measure_ft_stabs(sp_circ, x_measurements, z_measurements)


def heuristic_verification_circuit(
Expand All @@ -705,16 +724,15 @@ def heuristic_verification_circuit(
max_covering_sets: The maximum number of covering sets to consider.
find_coset_leaders: Whether to find coset leaders for the found measurements. This is done using SAT solvers so it can be slow.
"""
# x_layers = heuristic_verification_stabilizers(sp_circ, True, max_covering_sets, find_coset_leaders)
# z_layers = heuristic_verification_stabilizers(sp_circ, False, max_covering_sets, find_coset_leaders)

# z_measurements = [measurement for layer in x_layers for measurement in layer]
# x_measurements = [measurement for layer in z_layers for measurement in layer]
def verification_stabs_fun(sp_circ, zero_state, additional_errors=None):
return heuristic_verification_stabilizers(sp_circ, zero_state, max_covering_sets, find_coset_leaders)
def verification_stabs_fun(
sp_circ: StatePrepCircuit, zero_state: bool, additional_errors: npt.NDArray[np.int8] | None = None
) -> list[list[npt.NDArray[np.int8]]]:
return heuristic_verification_stabilizers(
sp_circ, zero_state, max_covering_sets, find_coset_leaders, additional_errors
)

return _verification_circuit(sp_circ, verification_stabs_fun)
# return _measure_ft_stabs(sp_circ, x_measurements, z_measurements)


def heuristic_verification_stabilizers(
Expand All @@ -731,6 +749,7 @@ def heuristic_verification_stabilizers(
x_errors: Whether to find verification stabilizers for X errors. If False, find for Z errors.
max_covering_sets: The maximum number of covering sets to consider.
find_coset_leaders: Whether to find coset leaders for the found measurements. This is done using SAT solvers so it can be slow.
additional_faults: Faults to verify in addition to the faults propagating in the state preparation circuit.
"""
logging.info("Finding verification stabilizers using heuristic method")
max_errors = (sp_circ.code.distance - 1) // 2
Expand All @@ -747,6 +766,7 @@ def heuristic_verification_stabilizers(
for num_errors in range(1, max_errors + 1):
logging.info(f"Finding verification stabilizers for {num_errors} errors")
faults = fault_sets[num_errors]
assert faults is not None
logging.info(f"There are {len(faults)} faults")
if len(faults) == 0:
layers[num_errors - 1] = []
Expand Down Expand Up @@ -902,7 +922,7 @@ def _vars_to_stab(

def verification_stabilizers(
sp_circ: StatePrepCircuit,
fault_set: list[npt.NDArray[np.int8]],
fault_set: npt.NDArray[np.int8],
num_anc: int,
num_cnots: int,
x_errors: bool = True,
Expand All @@ -911,6 +931,7 @@ def verification_stabilizers(
Args:
sp_circ: The state preparation circuit.
fault_set: The set of errors to verify.
num_anc: The maximum number of ancilla qubits to use.
num_cnots: The maximumg number of CNOT gates to use.
num_errors: The number of errors occur in the state prep circuit.
Expand Down Expand Up @@ -1107,7 +1128,7 @@ def _propagate_error(dag: DagCircuit, node: DAGNode, x_errors: bool = True) -> P

def _remove_trivial_faults(
faults: npt.NDArray[np.int8], stabs: npt.NDArray[np.int8], code: CSSCode, x_errors: bool, num_errors: int
):
) -> npt.NDArray[np.int8]:
# remove trivial faults
faults = faults.copy()
logging.info("Removing trivial faults.")
Expand Down Expand Up @@ -1156,8 +1177,12 @@ def _remove_stabilizer_equivalent_faults(
return faults[indices]


def naive_verification_circuit(sp_circ: StatePrepCircuit):
def naive_verification_circuit(sp_circ: StatePrepCircuit) -> QuantumCircuit:
"""Naive verification circuit for a state preparation circuit."""
if sp_circ.code.Hx is None or sp_circ.code.Hz is None:
msg = "Code must have stabilizers defined."
raise ValueError(msg)

z_measurements = list(sp_circ.code.Hx)
x_measurements = list(sp_circ.code.Hz)
reps = (sp_circ.code.distance - 1) // 2
Expand Down Expand Up @@ -1206,7 +1231,11 @@ def _flag_init(qc: QuantumCircuit, flag: AncillaQubit, z_measurement: bool) -> N


def _measure_stab_unflagged(
qc: QuantumCircuit, stab: list[Qubit], ancilla: AncillaQubit, measurement_bit: ClBit, z_measurement=True
qc: QuantumCircuit,
stab: list[Qubit] | npt.NDArray[np.int_],
ancilla: AncillaQubit,
measurement_bit: ClBit,
z_measurement: bool = True,
) -> None:
if not z_measurement:
qc.h(ancilla)
Expand All @@ -1218,7 +1247,11 @@ def _measure_stab_unflagged(


def measure_flagged(
qc: QuantumCircuit, stab: list[Qubit], ancilla: AncillaQubit, measurement_bit: ClBit, z_measurement=True
qc: QuantumCircuit,
stab: list[Qubit] | npt.NDArray[np.int_],
ancilla: AncillaQubit,
measurement_bit: ClBit,
z_measurement: bool = True,
) -> None:
"""Measure a w-flagged stabilizer with the general scheme.
Expand All @@ -1228,6 +1261,7 @@ def measure_flagged(
qc: The quantum circuit to add the measurement to.
stab: The qubits to measure.
ancilla: The ancilla qubit to use for the measurement.
measurement_bit: The classical bit to store the measurement result of the ancilla.
z_measurement: Whether to measure an X (False) or Z (True) stabilizer.
"""
w = len(stab)
Expand Down Expand Up @@ -1290,7 +1324,11 @@ def measure_flagged(


def measure_flagged_4(
qc: QuantumCircuit, stab: list[Qubit], ancilla: AncillaQubit, measurement_bit: ClBit, z_measurement=True
qc: QuantumCircuit,
stab: list[Qubit] | npt.NDArray[np.int_],
ancilla: AncillaQubit,
measurement_bit: ClBit,
z_measurement: bool = True,
) -> None:
"""Measure a 4-flagged stabilizer using an optimized scheme."""
assert len(stab) == 4
Expand Down Expand Up @@ -1323,8 +1361,13 @@ def measure_flagged_4(


def measure_flagged_6(
qc: QuantumCircuit, stab: list[Qubit], ancilla: AncillaQubit, measurement_bit: ClBit, z_measurement=True
qc: QuantumCircuit,
stab: list[Qubit] | npt.NDArray[np.int_],
ancilla: AncillaQubit,
measurement_bit: ClBit,
z_measurement: bool = True,
) -> None:
"""Measure a 6-flagged stabilizer using an optimized scheme."""
assert len(stab) == 6
flag = AncillaRegister(2)
meas = ClassicalRegister(2)
Expand Down Expand Up @@ -1364,8 +1407,13 @@ def measure_flagged_6(


def measure_flagged_8(
qc: QuantumCircuit, stab: list[Qubit], ancilla: AncillaQubit, measurement_bit: ClBit, z_measurement=True
qc: QuantumCircuit,
stab: list[Qubit] | npt.NDArray[np.int_],
ancilla: AncillaQubit,
measurement_bit: ClBit,
z_measurement: bool = True,
) -> None:
"""Measure an 8-flagged stabilizer using an optimized scheme."""
assert len(stab) == 8
flag = AncillaRegister(3)
meas = ClassicalRegister(3)
Expand Down Expand Up @@ -1412,7 +1460,7 @@ def measure_flagged_8(
qc.measure(ancilla, measurement_bit)


def _hook_errors(stabs: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
def _hook_errors(stabs: list[npt.NDArray[np.int8]]) -> npt.NDArray[np.int8]:
"""Assuming CNOTs are executed in ascending order of qubit index, this function gives all the hook errors of the given stabilizer measurements."""
errors = []
for stab in stabs:
Expand Down
Loading

0 comments on commit eaa2eee

Please sign in to comment.