Skip to content

Commit

Permalink
Merge pull request #240 from nonhermitian/move-solvers
Browse files Browse the repository at this point in the history
Move solvers to own modules
  • Loading branch information
nonhermitian authored Oct 8, 2024
2 parents dbcd8ee + f7e65aa commit 57b636a
Show file tree
Hide file tree
Showing 10 changed files with 301 additions and 225 deletions.
9 changes: 4 additions & 5 deletions .github/workflows/python-package-conda.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
name: Pytest
name: Testing
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
conda-test:
name: Conda tests on ${{ matrix.os }}
name: ${{ matrix.os }} using Py ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
defaults:
run:
Expand All @@ -29,16 +29,15 @@ jobs:
run: |
conda info
pip install -U -r requirements.txt -c constraints.txt
- name: Lint with pylint and pycodestyle
- name: Black
run: |
pip install -U -r requirements-dev.txt -c constraints.txt
if [ "$RUNNER_OS" == "Linux" ]; then
MTHREE_OPENMP=1 python setup.py build_ext --inplace
else
python setup.py build_ext --inplace
fi
pylint -rn mthree
pycodestyle --max-line-length=100 mthree
black mthree
- name: Run tests with pytest
run: |
pip install pytest
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@


def main():
with open('data/eagle_large_counts.json') as json_file:
with open("data/eagle_large_counts.json") as json_file:
counts = json.load(json_file)

mit = mthree.M3Mitigation()
mit.cals_from_file('data/eagle_large_cals.json')
mit.cals_from_file("data/eagle_large_cals.json")

st = time.perf_counter()
quasi = mit.apply_correction(counts, range(127), distance=3)
fin = time.perf_counter()
print(fin-st)
print(fin - st)


if __name__ == "__main__":
Expand Down
58 changes: 30 additions & 28 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@

rst_prolog = """
.. |version| replace:: {0}
""".format(m3.version.short_version)
""".format(
m3.version.short_version
)

# -- Project information -----------------------------------------------------
project = 'Mthree {}'.format(version)
copyright = '2021, Mthree Team' # pylint: disable=redefined-builtin
author = 'Mthree Development Team'
project = "Mthree {}".format(version)
copyright = "2021, Mthree Team" # pylint: disable=redefined-builtin
author = "Mthree Development Team"
# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
Expand All @@ -58,21 +60,21 @@
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.napoleon',
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.mathjax',
'sphinx.ext.viewcode',
'sphinx.ext.extlinks',
'nbsphinx',
'jupyter_sphinx',
"sphinx.ext.napoleon",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.mathjax",
"sphinx.ext.viewcode",
"sphinx.ext.extlinks",
"nbsphinx",
"jupyter_sphinx",
"qiskit_sphinx_theme",
]
templates_path = ['_templates']
templates_path = ["_templates"]
nbsphinx_timeout = 300
nbsphinx_execute = 'always'
nbsphinx_execute = "always"

exclude_patterns = ['_build', '**.ipynb_checkpoints']
exclude_patterns = ["_build", "**.ipynb_checkpoints"]

jupyter_execute_kwargs = dict(allow_errors=False)

Expand All @@ -85,7 +87,7 @@
# Autodoc
# -----------------------------------------------------------------------------

autoclass_content = 'init'
autoclass_content = "init"

# If true, figures, tables and code-blocks are automatically numbered if they
# have a caption.
Expand All @@ -94,15 +96,13 @@
# A dictionary mapping 'figure', 'table', 'code-block' and 'section' to
# strings that are used for format of figure numbers. As a special character,
# %s will be replaced to figure number.
numfig_format = {
'table': 'Table %s'
}
numfig_format = {"table": "Table %s"}
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = 'en'
language = "en"

# A boolean that decides whether module names are prepended to all object names
# (for object types where a “module” of some kind is defined), e.g. for
Expand All @@ -113,7 +113,7 @@
# (e.g., if this is set to ['foo.'], then foo.bar is shown under B, not F).
# This can be handy if you document a project that consists of a single
# package. Works only for the HTML builder currently.
modindex_common_prefix = ['mthree.']
modindex_common_prefix = ["mthree."]

# -- Configuration for extlinks extension ------------------------------------
# Refer to https://www.sphinx-doc.org/en/master/usage/extensions/extlinks.html
Expand All @@ -128,23 +128,25 @@
html_title = f"{project}"


#html_sidebars = {'**': ['globaltoc.html']}
html_last_updated_fmt = '%Y/%m/%d'
# html_sidebars = {'**': ['globaltoc.html']}
html_last_updated_fmt = "%Y/%m/%d"


def load_tutorials(app):
dest_dir = os.path.join(app.srcdir, 'tutorials')
source_dir = os.path.dirname(app.srcdir)+'/tutorials'
dest_dir = os.path.join(app.srcdir, "tutorials")
source_dir = os.path.dirname(app.srcdir) + "/tutorials"

try:
copy_tree(source_dir, dest_dir)
except FileNotFoundError:
warnings.warn('Copy tutorials failed.', RuntimeWarning)
warnings.warn("Copy tutorials failed.", RuntimeWarning)


def clean_tutorials(app, exc):
tutorials_dir = os.path.join(app.srcdir, 'tutorials')
tutorials_dir = os.path.join(app.srcdir, "tutorials")
shutil.rmtree(tutorials_dir)


def setup(app):
load_tutorials(app)
app.connect('build-finished', clean_tutorials)
app.connect("build-finished", clean_tutorials)
81 changes: 81 additions & 0 deletions mthree/direct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# This code is part of Mthree.
#
# (C) Copyright IBM 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=no-name-in-module, invalid-name
"""Direct solver routines"""
import scipy.linalg as la

from mthree.matrix import _reduced_cal_matrix
from mthree.utils import counts_to_vector, vector_to_quasiprobs
from mthree.norms import ainv_onenorm_est_lu
from mthree.exceptions import M3Error


def reduced_cal_matrix(mitigator, counts, qubits, distance=None):
"""Return the reduced calibration matrix used in the solution.
Parameters:
counts (dict): Input counts dict.
qubits (array_like): Qubits on which measurements applied.
distance (int): Distance to correct for. Default=num_bits
Returns:
ndarray: 2D array of reduced calibrations.
dict: Counts in order they are displayed in matrix.
Raises:
M3Error: If bit-string length does not match passed number
of qubits.
"""
counts = dict(counts)
# If distance is None, then assume max distance.
num_bits = len(qubits)
if distance is None:
distance = num_bits

# check if len of bitstrings does not equal number of qubits passed.
bitstring_len = len(next(iter(counts)))
if bitstring_len != num_bits:
raise M3Error(
"Bitstring length ({}) does not match".format(bitstring_len)
+ " number of qubits ({})".format(num_bits)
)

cals = mitigator._form_cals(qubits)
A, counts, _ = _reduced_cal_matrix(counts, cals, num_bits, distance)
return A, counts


def direct_solver(
mitigator, counts, qubits, distance=None, return_mitigation_overhead=False
):
"""Apply the mitigation using direct LU factorization.
Parameters:
counts (dict): Input counts dict.
qubits (int): Qubits over which to calibrate.
distance (int): Distance to correct for. Default=num_bits
return_mitigation_overhead (bool): Returns the mitigation overhead, default=False.
Returns:
QuasiDistribution: dict of Quasiprobabilites
"""
cals = mitigator._form_cals(qubits)
num_bits = len(qubits)
A, sorted_counts, col_norms = _reduced_cal_matrix(counts, cals, num_bits, distance)
vec = counts_to_vector(sorted_counts)
LU = la.lu_factor(A, check_finite=False)
x = la.lu_solve(LU, vec, check_finite=False)
gamma = None
if return_mitigation_overhead:
gamma = ainv_onenorm_est_lu(A, LU)
out = vector_to_quasiprobs(x, sorted_counts)
return out, col_norms, gamma
91 changes: 91 additions & 0 deletions mthree/iterative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# This code is part of Mthree.
#
# (C) Copyright IBM 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=no-name-in-module, invalid-name
"""Iterative solver routines"""
import numpy as np
import scipy.sparse.linalg as spla

from mthree.norms import ainv_onenorm_est_iter
from mthree.matvec import M3MatVec
from mthree.utils import counts_to_vector, vector_to_quasiprobs
from mthree.exceptions import M3Error


def iterative_solver(
mitigator,
counts,
qubits,
distance,
tol=1e-5,
max_iter=25,
details=0,
callback=None,
return_mitigation_overhead=False,
):
"""Compute solution using GMRES and Jacobi preconditioning.
Parameters:
counts (dict): Input counts dict.
qubits (int): Qubits over which to calibrate.
tol (float): Tolerance to use.
max_iter (int): Maximum number of iterations to perform.
distance (int): Distance to correct for. Default=num_bits
details (bool): Return col norms.
callback (callable): Callback function to record iteration count.
return_mitigation_overhead (bool): Returns the mitigation overhead, default=False.
Returns:
QuasiDistribution: dict of Quasiprobabilites
Raises:
M3Error: Solver did not converge.
"""
cals = mitigator._form_cals(qubits)
M = M3MatVec(dict(counts), cals, distance)
L = spla.LinearOperator(
(M.num_elems, M.num_elems),
matvec=M.matvec,
rmatvec=M.rmatvec,
dtype=np.float32,
)
diags = M.get_diagonal()

def precond_matvec(x):
out = x / diags
return out

P = spla.LinearOperator(
(M.num_elems, M.num_elems), precond_matvec, dtype=np.float32
)
vec = counts_to_vector(M.sorted_counts)

out, error = spla.gmres(
L,
vec,
rtol=tol,
atol=tol,
maxiter=max_iter,
M=P,
callback=callback,
callback_type="legacy",
)
if error:
raise M3Error("GMRES did not converge: {}".format(error))

gamma = None
if return_mitigation_overhead:
gamma = ainv_onenorm_est_iter(M, tol=tol, max_iter=max_iter)

quasi = vector_to_quasiprobs(out, M.sorted_counts)
if details:
return quasi, M.get_col_norms(), gamma
return quasi, gamma
Loading

0 comments on commit 57b636a

Please sign in to comment.