Skip to content

Commit

Permalink
Finished simulation and added codes.
Browse files Browse the repository at this point in the history
  • Loading branch information
pehamTom committed Jun 21, 2024
1 parent 703504e commit f6f112c
Show file tree
Hide file tree
Showing 18 changed files with 124 additions and 39 deletions.
1 change: 1 addition & 0 deletions src/mqt/qecc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._version import version as __version__
from .analog_information_decoding.simulators.analog_tannergraph_decoding import AnalogTannergraphDecoder, AtdSimulator
from .analog_information_decoding.simulators.quasi_single_shot_v2 import QssSimulator

from .pyqecc import (
Code,
Decoder,
Expand Down
68 changes: 66 additions & 2 deletions src/mqt/qecc/ft_stateprep/code.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
"""Class for representing quantum error correction codes."""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

import numpy as np

from ldpc import mod2
from mqt.qecc.cc_decoder.hexagonal_color_code import HexagonalColorCode

try:
from importlib import resources as impresources
except ImportError:
# Try backported to PY<37 `importlib_resources`.
import importlib_resources as impresources

from . import sample_codes # relative-import the *package* containing the templates


if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt
Expand Down Expand Up @@ -59,3 +68,58 @@ def get_z_syndrome(self, error: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]:
def is_self_dual(self) -> bool:
"""Check if the code is self-dual."""
return mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, self.Hz]))

@staticmethod
def from_code_name(code_name: str, distance: int=None) -> CSSCode:
"""Return CSSCode object for a known code.
The following codes are supported:
- [[7, 1, 3]] Steane (\"Steane\")
- [[15, 1, 3]] tetrahedral code (\"Tetrahedral\")
- [[15, 7, 3]] Hamming code (\"Hamming\")
- [[9, 1, 3]] Shore code (\"Shor\")
- [[9, 1, 3]] rotated surface code (\"Surface, 3\")
- [[25, 1, 5]] rotated surface code (\"Surface, 5\")
- [[17, 1, 5]] 4,8,8 color code (\"CC_4_8_8, 5\")
- 6,6,6 color code for arbitrary distances (\"CC_6_6_6, d\")
Args:
code_name: The name of the code.
"""
prefix = impresources.files(sample_codes)
paths = {
"steane": prefix / "steane/",
"tetrahedral": prefix / "tetrahedral/",
"hamming": prefix / "hamming/",
"shor": prefix / "shor/",
"surface, 3": prefix / "rotated_surface_d3/",
"surface, 5": prefix / "rotated_surface_d5/",
"cc_4_8_8 5": prefix / "cc_4_8_8_d5/",
}

distances = {
"steane": 3,
"tetrahedral": 3,
"hamming": 3,
"shor": 3,
"cc_4_8_8 5": 5,
}

code_name = code_name.lower()
if code_name == "cc_6_6_6":
if distance is None:
raise ValueError("Distance is not specified for CC_6_6_6")
cc = HexagonalColorCode(distance)
cc.construct_layout()
return CSSCode(distance, cc.H, cc.H)

elif code_name in paths:
hx = np.load(paths[code_name] / "hx.npy")
hz = np.load(paths[code_name] / "hz.npy")
if code_name in distances:
distance = distances[code_name]
elif distance is None:
raise ValueError(f"Distance is not specified for {code_name}")
return CSSCode(distance, hx, hz)
else:
raise ValueError(f"Unknown code name: {code_name}")
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added src/mqt/qecc/ft_stateprep/sample_codes/shor/hx.npy
Binary file not shown.
Binary file added src/mqt/qecc/ft_stateprep/sample_codes/shor/hz.npy
Binary file not shown.
Binary file added src/mqt/qecc/ft_stateprep/sample_codes/steane/hx.npy
Binary file not shown.
Binary file added src/mqt/qecc/ft_stateprep/sample_codes/steane/hz.npy
Binary file not shown.
Binary file not shown.
Binary file not shown.
77 changes: 47 additions & 30 deletions src/mqt/qecc/ft_stateprep/simulation.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
"""Simulation of Non-deterministic fault tolerant state preparation."""

from __future__ import annotations
from typing import TYPE_CHECKING

import stim
from qiskit import QuantumCircuit
from qiskit.converters import circuit_to_dag, dag_to_circuit
import numpy as np
from collections import defaultdict

from state_prep import NDFTStatePrepCircuit
from code import CSSCode
from .code import CSSCode

if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt

class NoisyNDFTStatePrepSimulator:
"""A noisy state preparation circuit."""

def __init__(self, state_prep_circ: QuantumCircuit, code: CSSCode, p: float, zero_state: bool = False):
def __init__(self, state_prep_circ: QuantumCircuit, code: CSSCode, p: float, zero_state: bool = True):
self.circ = state_prep_circ
self.num_qubits = state_prep_circ.num_qubits
self.code = code
Expand All @@ -27,8 +29,8 @@ def __init__(self, state_prep_circ: QuantumCircuit, code: CSSCode, p: float, zer
self.data_measurements = []
self.n_measurements = 0
self.stim_circ = None
self.set_p(p)
self.decoder = LUTDecoder(code)
self.set_p(p)

def set_p(self, p: float) -> None:
"""Set the error rate."""
Expand All @@ -39,7 +41,12 @@ def set_p(self, p: float) -> None:
self.n_measurements = 0
self.p = p
self.stim_circ = self.to_stim_circ()
self.num_qubits = self.stim_circ.num_qubits - len(self.verification_measurements)
self.measure_stabilizers()
if self.zero_state:
self.measure_z()
else:
self.measure_x()

def to_stim_circ(self) -> stim.Circuit:
"""Convert a QuantumCircuit to a noisy STIM circuit.
Expand Down Expand Up @@ -114,17 +121,17 @@ def measure_stabilizers(self) -> stim.Circuit:
"""
for check in self.code.Hx:
supp = _support(check)
anc = self.stim_circ.num_qubits()
anc = self.stim_circ.num_qubits
self.stim_circ.append_operation("H", [anc])
for q in supp:
self.stim_circ.append_operation("CX", [anc, q])
self.stim_circ.append_operation("MRX", [anc])
self.x_measurements.append(self.n_measurements)
self.n_measurements += 1

for check in self.code.z_checks:
for check in self.code.Hz:
supp = _support(check)
anc = self.stim_circ.num_qubits()
anc = self.stim_circ.num_qubits
for q in supp:
self.stim_circ.append_operation("CX", [q, anc])
self.stim_circ.append_operation("MRZ", [anc])
Expand All @@ -135,75 +142,85 @@ def measure_z(self) -> None:
"""Measure all data qubits in the Z basis."""
self.data_measurements = [self.n_measurements + i for i in range(self.num_qubits)]
self.n_measurements += self.num_qubits
self.circuit.append_operation("MRZ", [q for q in range(self.num_qubits)])
self.stim_circ.append_operation("MRZ", [q for q in range(self.num_qubits)])

def measure_x(self) -> None:
"""Measure all data qubits in the X basis."""
self.data_measurements = [self.n_measurements + i for i in range(self.num_qubits)]
self.n_measurements += self.num_qubits
self.circuit.append_operation("MRX", [q for q in range(self.num_qubits)])
self.stim_circ.append_operation("MRX", [q for q in range(self.num_qubits)])

def logical_error_rate(self, shots=100000, shots_per_batch=100000, at_least_min_errors=True, min_errors=500):
"""Estimate the logical error rate of the code.
Args:
shots: The number of shots to use.
shots_per_batch: The number of shots per batch.
at_least_min_errors: Whether to continue simulating until at least min_errors are found.
min_errors: The minimum number of errors to find before stopping.
"""
batch = min(shots_per_batch, shots)
p_l = 0
r_a = 0

num_logical_errors = 0

if self.zero_state:
self.decoder.generate_x_lut()
self.decoder.generate_x_LUT()
else:
self.decoder.generate_z_lut()
self.decoder.generate_z_LUT()

i = 1
while i <= int(np.ceil(shots/batch)) or at_least_min_errors:
num_logical_errors_batch, discarded_batch = self._simulate_batch(batch)

p_l_batch = num_logical_errors-batch/(batch-discarded_batch)
p_l_batch = num_logical_errors_batch/(batch-discarded_batch)
r_a_batch = 1-discarded_batch/batch

# Update statistics
num_logical_errors += num_logical_errors-batch
num_logical_errors += num_logical_errors_batch
p_l = ((i-1)*p_l + p_l_batch) / i
r_a = ((i-1)*r_a + r_a_batch)/i

if at_least_min_errors and num_logical_errors >= min_errors:
break
i += 1

return p_l, r_a, num_logical_errors, i*batch

def _simulate_batch(self, shots=1024):
sampler = self.stim_circuit.compile_sampler()
sampler = self.stim_circ.compile_sampler()
detection_events = sampler.sample(shots)

# Filter events where the verification circuit flagged
index_array = np.where(np.all(np.logical_not(detection_events[:, self.verification_measurements]), axis=1))[0]
filtered_events = detection_events[index_array][:, self.x_measurements + self.z_measurements].astype(int)
filtered_events = detection_events[index_array].astype(np.int8)

if len(filtered_events) == 0: # All events were discarded
return 0, shots

state = filtered_events[:, self.data_measurements]

if self.zero_state:
checks = filtered_events[:, self.x_measurements]
checks = filtered_events[:, self.z_measurements]
observables = self.code.Lz
estimates = self.lut.batch_decode_x(checks)
estimates = self.decoder.batch_decode_x(checks)
else:
checks = filtered_events[:, self.z_measurements]
checks = filtered_events[:, self.x_measurements]
observables = self.code.Lx
estimates = self.lut.batch_decode_z(checks)
estimates = self.decoder.batch_decode_z(checks)

corrected = state + estimates
# print(np.sum(np.any(corrected @ observables.T % 2!=0, axis=1)))


num_discarded = detection_events.shape[0]-filtered_events.shape[0]
num_logical_errors = np.sum(np.any(corrected @ observables.T % 2!=0, axis=1)) # number of non-commuting corrected states
return num_logical_errors, num_discarded


class LUTDecoder:
"""Lookup table decoder for a CSSState"""
"""Lookup table decoder for a CSSState."""

def __init__(self, code: CSSCode, init_LUTs: bool = True):
self.code = code
Expand Down Expand Up @@ -238,16 +255,16 @@ def generate_x_LUT(self) -> None:
if self.x_LUT is not None:
return

self.x_LUT = self._generate_LUT(self.code.x_checks)
if self.code.is_self_dual:
self.x_LUT = self._generate_LUT(self.code.Hx)
if self.code.is_self_dual():
self.z_LUT = self.x_LUT

def generate_z_LUT(self) -> None:
"""Generate the lookup table for the Z errors."""
if self.z_LUT is not None:
return
self.z_LUT = self._generate_LUT(self.code.z_checks)
if self.code.is_self_dual:
self.z_LUT = self._generate_LUT(self.code.Hz)
if self.code.is_self_dual():
self.z_LUT = self.x_LUT

def _generate_LUT(self, checks: np.array) -> dict:
Expand All @@ -261,7 +278,7 @@ def _generate_LUT(self, checks: np.array) -> dict:
for i in range(0, 2**n_qubits):
state = np.array(list(np.binary_repr(i).zfill(n_qubits))).astype(np.int8)
syndrome = checks @ state % 2
lut[syndrome.tobytes()].append(state)
lut[syndrome.astype(np.int8).tobytes()].append(state)

# Sort according to weight
for key, v in lut.items():
Expand Down
17 changes: 10 additions & 7 deletions src/mqt/qecc/ft_stateprep/state_prep.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Synthesizing state preparation circuits for CSS codes."""

from __future__ import annotations
from typing import TYPE_CHECKING

from ldpc import mod2
from code import CSSCode
import numpy as np
from qiskit import AncillaRegister, ClassicalRegister, QuantumCircuit, QuantumRegister
from qiskit.quantum_info import PauliList
Expand All @@ -13,6 +13,8 @@
import z3
import multiprocessing

from .code import CSSCode

if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt
SymOrBool = z3.BoolRef | bool
Expand Down Expand Up @@ -448,11 +450,11 @@ def gate_optimal_verification_circuit(sp_circ: StatePrepCircuit, n_errors: int =
if layers is None:
return None

measured_circ = _measure_stabs(sp_circ.circ, [measurement for layer in layers for measurement in layer])
measured_circ = _measure_stabs(sp_circ.circ, [measurement for layer in layers for measurement in layer], sp_circ.zero_state)
return measured_circ


def _measure_stabs(circ: QuantumCircuit, measurements: list(npt.NDArray[np.int_])) -> QuantumCircuit:
def _measure_stabs(circ: QuantumCircuit, measurements: list(npt.NDArray[np.int_]), z_measurements=True) -> QuantumCircuit:
# Create the verification circuit
num_anc = len(measurements)
q = QuantumRegister(circ.num_qubits, "q")
Expand All @@ -463,15 +465,16 @@ def _measure_stabs(circ: QuantumCircuit, measurements: list(npt.NDArray[np.int_]
measured_circ.compose(circ, inplace=True)
current_anc = 0
for measurement in measurements:
if not self.zero_state:
if not z_measurements:
measured_circ.h(q)
for qubit in np.where(measurement == 1)[0]:
if self.zero_state:
if z_measurements:
measured_circ.cx(q[qubit], anc[current_anc])
else:
measured_circ.cx(anc[current_anc], q[qubit])
if not self.zero_state:
if not z_measurements:
measured_circ.h(q)
measured_circ.measure(anc[current_anc], c[current_anc])
current_anc += 1
return measured_circ

Expand Down Expand Up @@ -499,7 +502,7 @@ def vars_to_stab(measurement):
measurement_stabs = [vars_to_stab(vars_) for vars_ in measurement_vars]

# assert that each error is detected
errors = self._compute_fault_set(num_errors)
errors = self.compute_fault_set(num_errors)
solver.add(z3.And([z3.PbGe([(_odd_overlap(measurement, error), 1)
for measurement in measurement_stabs], 1)
for error in errors]))
Expand Down

0 comments on commit f6f112c

Please sign in to comment.