diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index cce53af1..2e471ff9 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -1,4 +1,4 @@ -name: Pytest +name: Testing on: push: branches: [main] @@ -6,7 +6,7 @@ on: branches: [main] jobs: conda-test: - name: Conda tests on ${{ matrix.os }} + name: ${{ matrix.os }} using Py ${{ matrix.python-version }} runs-on: ${{ matrix.os }} defaults: run: @@ -29,7 +29,7 @@ jobs: run: | conda info pip install -U -r requirements.txt -c constraints.txt - - name: Lint with pylint and pycodestyle + - name: Black run: | pip install -U -r requirements-dev.txt -c constraints.txt if [ "$RUNNER_OS" == "Linux" ]; then @@ -37,8 +37,7 @@ jobs: else python setup.py build_ext --inplace fi - pylint -rn mthree - pycodestyle --max-line-length=100 mthree + black mthree - name: Run tests with pytest run: | pip install pytest diff --git a/benchmarks/iterative.py b/benchmarks/iterative.py index bb3f9f38..e1ecd4c6 100644 --- a/benchmarks/iterative.py +++ b/benchmarks/iterative.py @@ -5,16 +5,16 @@ def main(): - with open('data/eagle_large_counts.json') as json_file: + with open("data/eagle_large_counts.json") as json_file: counts = json.load(json_file) mit = mthree.M3Mitigation() - mit.cals_from_file('data/eagle_large_cals.json') + mit.cals_from_file("data/eagle_large_cals.json") st = time.perf_counter() quasi = mit.apply_correction(counts, range(127), distance=3) fin = time.perf_counter() - print(fin-st) + print(fin - st) if __name__ == "__main__": diff --git a/docs/conf.py b/docs/conf.py index 35aaef11..67b12d32 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -42,12 +42,14 @@ rst_prolog = """ .. |version| replace:: {0} -""".format(m3.version.short_version) +""".format( + m3.version.short_version +) # -- Project information ----------------------------------------------------- -project = 'Mthree {}'.format(version) -copyright = '2021, Mthree Team' # pylint: disable=redefined-builtin -author = 'Mthree Development Team' +project = "Mthree {}".format(version) +copyright = "2021, Mthree Team" # pylint: disable=redefined-builtin +author = "Mthree Development Team" # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. @@ -58,21 +60,21 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.napoleon', - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', - 'sphinx.ext.extlinks', - 'nbsphinx', - 'jupyter_sphinx', + "sphinx.ext.napoleon", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinx.ext.extlinks", + "nbsphinx", + "jupyter_sphinx", "qiskit_sphinx_theme", ] -templates_path = ['_templates'] +templates_path = ["_templates"] nbsphinx_timeout = 300 -nbsphinx_execute = 'always' +nbsphinx_execute = "always" -exclude_patterns = ['_build', '**.ipynb_checkpoints'] +exclude_patterns = ["_build", "**.ipynb_checkpoints"] jupyter_execute_kwargs = dict(allow_errors=False) @@ -85,7 +87,7 @@ # Autodoc # ----------------------------------------------------------------------------- -autoclass_content = 'init' +autoclass_content = "init" # If true, figures, tables and code-blocks are automatically numbered if they # have a caption. @@ -94,15 +96,13 @@ # A dictionary mapping 'figure', 'table', 'code-block' and 'section' to # strings that are used for format of figure numbers. As a special character, # %s will be replaced to figure number. -numfig_format = { - 'table': 'Table %s' -} +numfig_format = {"table": "Table %s"} # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = 'en' +language = "en" # A boolean that decides whether module names are prepended to all object names # (for object types where a “module” of some kind is defined), e.g. for @@ -113,7 +113,7 @@ # (e.g., if this is set to ['foo.'], then foo.bar is shown under B, not F). # This can be handy if you document a project that consists of a single # package. Works only for the HTML builder currently. -modindex_common_prefix = ['mthree.'] +modindex_common_prefix = ["mthree."] # -- Configuration for extlinks extension ------------------------------------ # Refer to https://www.sphinx-doc.org/en/master/usage/extensions/extlinks.html @@ -128,23 +128,25 @@ html_title = f"{project}" -#html_sidebars = {'**': ['globaltoc.html']} -html_last_updated_fmt = '%Y/%m/%d' +# html_sidebars = {'**': ['globaltoc.html']} +html_last_updated_fmt = "%Y/%m/%d" def load_tutorials(app): - dest_dir = os.path.join(app.srcdir, 'tutorials') - source_dir = os.path.dirname(app.srcdir)+'/tutorials' + dest_dir = os.path.join(app.srcdir, "tutorials") + source_dir = os.path.dirname(app.srcdir) + "/tutorials" try: copy_tree(source_dir, dest_dir) except FileNotFoundError: - warnings.warn('Copy tutorials failed.', RuntimeWarning) + warnings.warn("Copy tutorials failed.", RuntimeWarning) + def clean_tutorials(app, exc): - tutorials_dir = os.path.join(app.srcdir, 'tutorials') + tutorials_dir = os.path.join(app.srcdir, "tutorials") shutil.rmtree(tutorials_dir) + def setup(app): load_tutorials(app) - app.connect('build-finished', clean_tutorials) \ No newline at end of file + app.connect("build-finished", clean_tutorials) diff --git a/mthree/direct.py b/mthree/direct.py new file mode 100644 index 00000000..e8e59fe3 --- /dev/null +++ b/mthree/direct.py @@ -0,0 +1,81 @@ +# This code is part of Mthree. +# +# (C) Copyright IBM 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +# pylint: disable=no-name-in-module, invalid-name +"""Direct solver routines""" +import scipy.linalg as la + +from mthree.matrix import _reduced_cal_matrix +from mthree.utils import counts_to_vector, vector_to_quasiprobs +from mthree.norms import ainv_onenorm_est_lu +from mthree.exceptions import M3Error + + +def reduced_cal_matrix(mitigator, counts, qubits, distance=None): + """Return the reduced calibration matrix used in the solution. + + Parameters: + counts (dict): Input counts dict. + qubits (array_like): Qubits on which measurements applied. + distance (int): Distance to correct for. Default=num_bits + + Returns: + ndarray: 2D array of reduced calibrations. + dict: Counts in order they are displayed in matrix. + + Raises: + M3Error: If bit-string length does not match passed number + of qubits. + """ + counts = dict(counts) + # If distance is None, then assume max distance. + num_bits = len(qubits) + if distance is None: + distance = num_bits + + # check if len of bitstrings does not equal number of qubits passed. + bitstring_len = len(next(iter(counts))) + if bitstring_len != num_bits: + raise M3Error( + "Bitstring length ({}) does not match".format(bitstring_len) + + " number of qubits ({})".format(num_bits) + ) + + cals = mitigator._form_cals(qubits) + A, counts, _ = _reduced_cal_matrix(counts, cals, num_bits, distance) + return A, counts + + +def direct_solver( + mitigator, counts, qubits, distance=None, return_mitigation_overhead=False +): + """Apply the mitigation using direct LU factorization. + + Parameters: + counts (dict): Input counts dict. + qubits (int): Qubits over which to calibrate. + distance (int): Distance to correct for. Default=num_bits + return_mitigation_overhead (bool): Returns the mitigation overhead, default=False. + + Returns: + QuasiDistribution: dict of Quasiprobabilites + """ + cals = mitigator._form_cals(qubits) + num_bits = len(qubits) + A, sorted_counts, col_norms = _reduced_cal_matrix(counts, cals, num_bits, distance) + vec = counts_to_vector(sorted_counts) + LU = la.lu_factor(A, check_finite=False) + x = la.lu_solve(LU, vec, check_finite=False) + gamma = None + if return_mitigation_overhead: + gamma = ainv_onenorm_est_lu(A, LU) + out = vector_to_quasiprobs(x, sorted_counts) + return out, col_norms, gamma diff --git a/mthree/iterative.py b/mthree/iterative.py new file mode 100644 index 00000000..dde12cf8 --- /dev/null +++ b/mthree/iterative.py @@ -0,0 +1,91 @@ +# This code is part of Mthree. +# +# (C) Copyright IBM 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +# pylint: disable=no-name-in-module, invalid-name +"""Iterative solver routines""" +import numpy as np +import scipy.sparse.linalg as spla + +from mthree.norms import ainv_onenorm_est_iter +from mthree.matvec import M3MatVec +from mthree.utils import counts_to_vector, vector_to_quasiprobs +from mthree.exceptions import M3Error + + +def iterative_solver( + mitigator, + counts, + qubits, + distance, + tol=1e-5, + max_iter=25, + details=0, + callback=None, + return_mitigation_overhead=False, +): + """Compute solution using GMRES and Jacobi preconditioning. + + Parameters: + counts (dict): Input counts dict. + qubits (int): Qubits over which to calibrate. + tol (float): Tolerance to use. + max_iter (int): Maximum number of iterations to perform. + distance (int): Distance to correct for. Default=num_bits + details (bool): Return col norms. + callback (callable): Callback function to record iteration count. + return_mitigation_overhead (bool): Returns the mitigation overhead, default=False. + + Returns: + QuasiDistribution: dict of Quasiprobabilites + + Raises: + M3Error: Solver did not converge. + """ + cals = mitigator._form_cals(qubits) + M = M3MatVec(dict(counts), cals, distance) + L = spla.LinearOperator( + (M.num_elems, M.num_elems), + matvec=M.matvec, + rmatvec=M.rmatvec, + dtype=np.float32, + ) + diags = M.get_diagonal() + + def precond_matvec(x): + out = x / diags + return out + + P = spla.LinearOperator( + (M.num_elems, M.num_elems), precond_matvec, dtype=np.float32 + ) + vec = counts_to_vector(M.sorted_counts) + + out, error = spla.gmres( + L, + vec, + rtol=tol, + atol=tol, + maxiter=max_iter, + M=P, + callback=callback, + callback_type="legacy", + ) + if error: + raise M3Error("GMRES did not converge: {}".format(error)) + + gamma = None + if return_mitigation_overhead: + gamma = ainv_onenorm_est_iter(M, tol=tol, max_iter=max_iter) + + quasi = vector_to_quasiprobs(out, M.sorted_counts) + if details: + return quasi, M.get_col_norms(), gamma + return quasi, gamma diff --git a/mthree/mitigation.py b/mthree/mitigation.py index 09b603b2..cc6191e3 100644 --- a/mthree/mitigation.py +++ b/mthree/mitigation.py @@ -21,8 +21,6 @@ import psutil import numpy as np -import scipy.linalg as la -import scipy.sparse.linalg as spla import orjson from qiskit.providers import BackendV2 from qiskit_ibm_runtime import SamplerV2 @@ -34,10 +32,10 @@ balanced_cal_strings, balanced_cal_circuits, ) -from mthree.matrix import _reduced_cal_matrix -from mthree.utils import counts_to_vector, vector_to_quasiprobs, gmres -from mthree.norms import ainv_onenorm_est_lu, ainv_onenorm_est_iter -from mthree.matvec import M3MatVec +from mthree.direct import direct_solver as direct_solve +from mthree.direct import reduced_cal_matrix as cal_matrix +from mthree.iterative import iterative_solver + from mthree.exceptions import M3Error from mthree.classes import QuasiCollection from ._helpers import system_info @@ -108,7 +106,7 @@ def _form_cals(self, qubits): # Reverse index qubits for easier indexing later for kk, qubit in enumerate(qubits[::-1]): - cals[4 * kk: 4 * kk + 4] = self.single_qubit_cals[qubit].ravel() + cals[4 * kk : 4 * kk + 4] = self.single_qubit_cals[qubit].ravel() return cals def tensored_cals_from_system( @@ -417,9 +415,9 @@ def _grab_additional_cals( # Get the slice length circ_slice = ceil(num_circs / num_jobs) circs_list = [ - trans_qcs[kk * circ_slice: (kk + 1) * circ_slice] + trans_qcs[kk * circ_slice : (kk + 1) * circ_slice] for kk in range(num_jobs - 1) - ] + [trans_qcs[(num_jobs - 1) * circ_slice:]] + ] + [trans_qcs[(num_jobs - 1) * circ_slice :]] # Do job submission here jobs = [] if self.rep_delay: @@ -620,8 +618,8 @@ def _apply_correction( if method == "direct": st = perf_counter() - mit_counts, col_norms, gamma = self._direct_solver( - counts, qubits, distance, return_mitigation_overhead + mit_counts, col_norms, gamma = direct_solve( + self, counts, qubits, distance, return_mitigation_overhead ) dur = perf_counter() - st mit_counts.shots = shots @@ -641,7 +639,8 @@ def callback(_): if details: st = perf_counter() - mit_counts, col_norms, gamma = self._matvec_solver( + mit_counts, col_norms, gamma = iterative_solver( + self, counts, qubits, distance, @@ -660,7 +659,8 @@ def callback(_): info["col_norms"] = col_norms return mit_counts, info # pylint: disable=unbalanced-tuple-unpacking - mit_counts, gamma = self._matvec_solver( + mit_counts, gamma = iterative_solver( + self, counts, qubits, distance, @@ -694,121 +694,7 @@ def reduced_cal_matrix(self, counts, qubits, distance=None): M3Error: If bit-string length does not match passed number of qubits. """ - counts = dict(counts) - # If distance is None, then assume max distance. - num_bits = len(qubits) - if distance is None: - distance = num_bits - - # check if len of bitstrings does not equal number of qubits passed. - bitstring_len = len(next(iter(counts))) - if bitstring_len != num_bits: - raise M3Error( - "Bitstring length ({}) does not match".format(bitstring_len) - + " number of qubits ({})".format(num_bits) - ) - - cals = self._form_cals(qubits) - A, counts, _ = _reduced_cal_matrix(counts, cals, num_bits, distance) - return A, counts - - def _direct_solver( - self, counts, qubits, distance=None, return_mitigation_overhead=False - ): - """Apply the mitigation using direct LU factorization. - - Parameters: - counts (dict): Input counts dict. - qubits (int): Qubits over which to calibrate. - distance (int): Distance to correct for. Default=num_bits - return_mitigation_overhead (bool): Returns the mitigation overhead, default=False. - - Returns: - QuasiDistribution: dict of Quasiprobabilites - """ - cals = self._form_cals(qubits) - num_bits = len(qubits) - A, sorted_counts, col_norms = _reduced_cal_matrix( - counts, cals, num_bits, distance - ) - vec = counts_to_vector(sorted_counts) - LU = la.lu_factor(A, check_finite=False) - x = la.lu_solve(LU, vec, check_finite=False) - gamma = None - if return_mitigation_overhead: - gamma = ainv_onenorm_est_lu(A, LU) - out = vector_to_quasiprobs(x, sorted_counts) - return out, col_norms, gamma - - def _matvec_solver( - self, - counts, - qubits, - distance, - tol=1e-5, - max_iter=25, - details=0, - callback=None, - return_mitigation_overhead=False, - ): - """Compute solution using GMRES and Jacobi preconditioning. - - Parameters: - counts (dict): Input counts dict. - qubits (int): Qubits over which to calibrate. - tol (float): Tolerance to use. - max_iter (int): Maximum number of iterations to perform. - distance (int): Distance to correct for. Default=num_bits - details (bool): Return col norms. - callback (callable): Callback function to record iteration count. - return_mitigation_overhead (bool): Returns the mitigation overhead, default=False. - - Returns: - QuasiDistribution: dict of Quasiprobabilites - - Raises: - M3Error: Solver did not converge. - """ - cals = self._form_cals(qubits) - M = M3MatVec(dict(counts), cals, distance) - L = spla.LinearOperator( - (M.num_elems, M.num_elems), - matvec=M.matvec, - rmatvec=M.rmatvec, - dtype=np.float32, - ) - diags = M.get_diagonal() - - def precond_matvec(x): - out = x / diags - return out - - P = spla.LinearOperator( - (M.num_elems, M.num_elems), precond_matvec, dtype=np.float32 - ) - vec = counts_to_vector(M.sorted_counts) - - out, error = gmres( - L, - vec, - rtol=tol, - atol=tol, - maxiter=max_iter, - M=P, - callback=callback, - callback_type="legacy", - ) - if error: - raise M3Error("GMRES did not converge: {}".format(error)) - - gamma = None - if return_mitigation_overhead: - gamma = ainv_onenorm_est_iter(M, tol=tol, max_iter=max_iter) - - quasi = vector_to_quasiprobs(out, M.sorted_counts) - if details: - return quasi, M.get_col_norms(), gamma - return quasi, gamma + return cal_matrix(self, counts, qubits, distance) def readout_fidelity(self, qubits=None): """Compute readout fidelity for calibrated qubits. diff --git a/mthree/norms.py b/mthree/norms.py index 06289c5f..23f2be40 100644 --- a/mthree/norms.py +++ b/mthree/norms.py @@ -17,7 +17,6 @@ import scipy.sparse.linalg as spla from mthree.exceptions import M3Error -from mthree.utils import gmres def ainv_onenorm_est_lu(A, LU=None): @@ -140,12 +139,12 @@ def precond_matvec(x): v = (1.0 / dims) * np.ones(dims, dtype=np.float32) # Initial solve - v, error = gmres(L, v, rtol=tol, atol=tol, maxiter=max_iter, M=P) + v, error = spla.gmres(L, v, rtol=tol, atol=tol, maxiter=max_iter, M=P) if error: raise M3Error("Iterative solver error {}".format(error)) gamma = la.norm(v, 1) eta = np.sign(v) - x, error = gmres(LT, eta, rtol=tol, atol=tol, maxiter=max_iter, M=P) + x, error = spla.gmres(LT, eta, rtol=tol, atol=tol, maxiter=max_iter, M=P) if error: raise M3Error("Iterative solver error {}".format(error)) # loop over reasonable number of trials @@ -155,7 +154,7 @@ def precond_matvec(x): idx = np.where(np.abs(x) == x_nrm)[0][0] v = np.zeros(dims, dtype=np.float32) v[idx] = 1 - v, error = gmres(L, v, rtol=tol, atol=tol, maxiter=max_iter, M=P) + v, error = spla.gmres(L, v, rtol=tol, atol=tol, maxiter=max_iter, M=P) if error: raise M3Error("Iterative solver error {}".format(error)) gamma_prime = gamma @@ -165,7 +164,7 @@ def precond_matvec(x): break eta = np.sign(v) - x, error = gmres(LT, eta, rtol=tol, atol=tol, maxiter=max_iter, M=P) + x, error = spla.gmres(LT, eta, rtol=tol, atol=tol, maxiter=max_iter, M=P) if error: raise M3Error("Iterative solver error {}".format(error)) if la.norm(x, np.inf) == x[idx]: @@ -176,7 +175,7 @@ def precond_matvec(x): x = np.arange(1, dims + 1, dtype=np.float32) x = (-1) ** (x + 1) * (1 + (x - 1) / (dims - 1)) - x, error = gmres(L, x, rtol=tol, atol=tol, maxiter=max_iter, M=P) + x, error = spla.gmres(L, x, rtol=tol, atol=tol, maxiter=max_iter, M=P) if error: raise M3Error("Iterative solver error {}".format(error)) diff --git a/mthree/utils.py b/mthree/utils.py index fafd9a45..71e131ab 100644 --- a/mthree/utils.py +++ b/mthree/utils.py @@ -25,7 +25,6 @@ """ import numpy as np -import scipy.sparse.linalg as spla from qiskit.result import marginal_distribution as marg_dist from mthree.exceptions import M3Error @@ -36,8 +35,6 @@ ProbCollection, ) -gmres = spla.gmres - def final_measurement_mapping(circuit): """Return the final measurement mapping for the circuit. diff --git a/requirements-dev.txt b/requirements-dev.txt index 38c5f5f3..ab32b267 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,6 @@ qiskit-aer>=0.13 pytest -pylint~=3.0.2 -pycodestyle +black Sphinx>=7.0,<=8 qiskit_sphinx_theme~=1.16.0,<2 matplotlib diff --git a/setup.py b/setup.py index 5fc3dea4..e356dd16 100644 --- a/setup.py +++ b/setup.py @@ -27,26 +27,28 @@ MICRO = 0 ISRELEASED = True -VERSION = '%d.%d.%d' % (MAJOR, MINOR, MICRO) +VERSION = "%d.%d.%d" % (MAJOR, MINOR, MICRO) with open("requirements.txt") as f: REQUIREMENTS = f.read().splitlines() PACKAGES = setuptools.find_packages() -PACKAGE_DATA = {'mthree': ['*.pxd'], +PACKAGE_DATA = { + "mthree": ["*.pxd"], } -DOCLINES = __doc__.split('\n') +DOCLINES = __doc__.split("\n") DESCRIPTION = DOCLINES[0] this_dir = os.path.abspath(os.path.dirname(__file__)) -with open(os.path.join(this_dir, 'README.md'), encoding='utf-8') as readme: +with open(os.path.join(this_dir, "README.md"), encoding="utf-8") as readme: LONG_DESCRIPTION = readme.read() -CYTHON_EXTS = ['converters', 'hamming', 'matrix', 'probability', 'matvec'] + \ - ['expval', 'column_testing', 'converters_testing'] -CYTHON_MODULES = ['mthree']*6 + \ - ['mthree.test']*2 -CYTHON_SOURCE_DIRS = ['mthree']*6 + \ - ['mthree/test']*2 +CYTHON_EXTS = ["converters", "hamming", "matrix", "probability", "matvec"] + [ + "expval", + "column_testing", + "converters_testing", +] +CYTHON_MODULES = ["mthree"] * 6 + ["mthree.test"] * 2 +CYTHON_SOURCE_DIRS = ["mthree"] * 6 + ["mthree/test"] * 2 # Add openmp flags OPTIONAL_FLAGS = [] @@ -56,40 +58,47 @@ if _arg == "--openmp" or _arg == "--with-openmp": WITH_OMP = True if _arg == "--with-openmp": - warnings.warn("Using '--with-openmp' to set openmp is deprecated.", - DeprecationWarning) + warnings.warn( + "Using '--with-openmp' to set openmp is deprecated.", DeprecationWarning + ) sys.argv.remove(_arg) break if WITH_OMP or os.getenv("MTHREE_OPENMP", False): WITH_OMP = True - if sys.platform == 'win32': - OPTIONAL_FLAGS = ['/openmp'] + if sys.platform == "win32": + OPTIONAL_FLAGS = ["/openmp"] else: - OPTIONAL_FLAGS = ['-fopenmp'] + OPTIONAL_FLAGS = ["-fopenmp"] OPTIONAL_ARGS = OPTIONAL_FLAGS if os.getenv("MTHREE_ARCH", False): - OPTIONAL_FLAGS.append('-march='+os.getenv("MTHREE_ARCH")) + OPTIONAL_FLAGS.append("-march=" + os.getenv("MTHREE_ARCH")) INCLUDE_DIRS = [np.get_include()] # Extra link args LINK_FLAGS = [] # If on Win and not in MSYS2 (i.e. Visual studio compile) -if (sys.platform == 'win32' and os.environ.get('MSYSTEM', None) is None): - COMPILER_FLAGS = ['/O3'] +if sys.platform == "win32" and os.environ.get("MSYSTEM", None) is None: + COMPILER_FLAGS = ["/O3"] # Everything else else: - COMPILER_FLAGS = ['-O3', '-std=c++17', '-DNPY_NO_DEPRECATED_API=NPY_1_23_API_VERSION'] + COMPILER_FLAGS = [ + "-O3", + "-std=c++17", + "-DNPY_NO_DEPRECATED_API=NPY_1_23_API_VERSION", + ] EXT_MODULES = [] # Add Cython Extensions for idx, ext in enumerate(CYTHON_EXTS): - mod = setuptools.Extension(CYTHON_MODULES[idx] + '.' + ext, - sources=[CYTHON_SOURCE_DIRS[idx] + '/' + ext + '.pyx'], - include_dirs=INCLUDE_DIRS, - extra_compile_args=COMPILER_FLAGS+OPTIONAL_FLAGS, - extra_link_args=LINK_FLAGS+OPTIONAL_ARGS, - language='c++') + mod = setuptools.Extension( + CYTHON_MODULES[idx] + "." + ext, + sources=[CYTHON_SOURCE_DIRS[idx] + "/" + ext + ".pyx"], + include_dirs=INCLUDE_DIRS, + extra_compile_args=COMPILER_FLAGS + OPTIONAL_FLAGS, + extra_link_args=LINK_FLAGS + OPTIONAL_ARGS, + language="c++", + ) EXT_MODULES.append(mod) @@ -99,15 +108,17 @@ def git_short_hash(): except: # pylint: disable=bare-except git_str = "" else: - if git_str == '+': #fixes setuptools PEP issues with versioning - git_str = '' + if git_str == "+": # fixes setuptools PEP issues with versioning + git_str = "" return git_str + FULLVERSION = VERSION if not ISRELEASED: - FULLVERSION += '.dev'+str(MICRO)+git_short_hash() + FULLVERSION += ".dev" + str(MICRO) + git_short_hash() + -def write_version_py(filename='mthree/version.py'): +def write_version_py(filename="mthree/version.py"): cnt = """\ # THIS FILE IS GENERATED FROM MTHREE SETUP.PY # pylint: disable=missing-module-docstring @@ -115,21 +126,28 @@ def write_version_py(filename='mthree/version.py'): version = '%(fullversion)s' openmp = %(with_omp)s """ - a = open(filename, 'w') + a = open(filename, "w") try: - a.write(cnt % {'version': VERSION, 'fullversion':FULLVERSION, - 'with_omp': str(WITH_OMP)}) + a.write( + cnt + % { + "version": VERSION, + "fullversion": FULLVERSION, + "with_omp": str(WITH_OMP), + } + ) finally: a.close() + local_path = os.path.dirname(os.path.abspath(sys.argv[0])) os.chdir(local_path) sys.path.insert(0, local_path) -sys.path.insert(0, os.path.join(local_path, 'mthree')) # to retrive _version +sys.path.insert(0, os.path.join(local_path, "mthree")) # to retrive _version # always rewrite _version -if os.path.exists('mthree/version.py'): - os.remove('mthree/version.py') +if os.path.exists("mthree/version.py"): + os.remove("mthree/version.py") write_version_py() @@ -137,38 +155,43 @@ def write_version_py(filename='mthree/version.py'): # Add command for running pylint from setup.py class PylintCommand(setuptools.Command): """Run Pylint on all mthree Python source files.""" - description = 'Run Pylint on mthree Python source files' + + description = "Run Pylint on mthree Python source files" user_options = [ # The format is (long option, short option, description). - ('pylint-rcfile=', None, 'path to Pylint config file')] + ("pylint-rcfile=", None, "path to Pylint config file") + ] def initialize_options(self): """Set default values for options.""" # Each user option must be listed here with their default value. - self.pylint_rcfile = '' # pylint: disable=attribute-defined-outside-init + self.pylint_rcfile = "" # pylint: disable=attribute-defined-outside-init def finalize_options(self): """Post-process options.""" if self.pylint_rcfile: assert os.path.exists(self.pylint_rcfile), ( - 'Pylint config file %s does not exist.' % self.pylint_rcfile) + "Pylint config file %s does not exist." % self.pylint_rcfile + ) def run(self): """Run command.""" - command = ['pylint'] + command = ["pylint"] if self.pylint_rcfile: - command.append('--rcfile=%s' % self.pylint_rcfile) - command.append(os.getcwd()+"/mthree") + command.append("--rcfile=%s" % self.pylint_rcfile) + command.append(os.getcwd() + "/mthree") subprocess.run(command, stderr=subprocess.STDOUT, check=False) # Add command for running PEP8 tests from setup.py class StyleCommand(setuptools.Command): """Run pep8 from setup.""" - description = 'Run style from setup' + + description = "Run style from setup" user_options = [ # The format is (long option, short option, description). - ('abc', None, 'abc')] + ("abc", None, "abc") + ] def initialize_options(self): pass @@ -178,18 +201,18 @@ def finalize_options(self): def run(self): """Run command.""" - command = 'pycodestyle --max-line-length=100 mthree' + command = "pycodestyle --max-line-length=100 mthree" subprocess.run(command, shell=True, check=False, stderr=subprocess.STDOUT) setuptools.setup( - name='mthree', + name="mthree", version=VERSION, python_requires=">=3.9", packages=PACKAGES, description=DESCRIPTION, long_description=LONG_DESCRIPTION, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", url="", author="Paul Nation", author_email="paul.nation@ibm.com", @@ -207,11 +230,10 @@ def run(self): "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering", ], - cmdclass={'lint': PylintCommand, - 'style': StyleCommand}, + cmdclass={"lint": PylintCommand, "style": StyleCommand}, install_requires=REQUIREMENTS, package_data=PACKAGE_DATA, ext_modules=cythonize(EXT_MODULES, language_level=3), include_package_data=True, - zip_safe=False + zip_safe=False, )