diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 54e8eb2d..116a6faa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,39 +27,11 @@ repos: - id: requirements-txt-fixer - id: sort-simple-yaml - id: trailing-whitespace - - repo: https://github.com/asottile/pyupgrade - rev: v3.15.2 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.5 hooks: - - id: pyupgrade - args: [--py38-plus] - - repo: https://github.com/asottile/yesqa - rev: v1.5.0 - hooks: - - id: yesqa - - repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort - args: [--profile, black, --float-to-top, --color] - - repo: https://github.com/psf/black - rev: 24.4.2 - hooks: - - id: black - - repo: https://github.com/PyCQA/autoflake - rev: v2.3.1 - hooks: - - id: autoflake - - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - additional_dependencies: [flake8-bugbear] - args: ["--ignore=W503,B015,B028"] - - repo: https://github.com/PyCQA/bandit - rev: 1.7.8 - hooks: - - id: bandit - args: [--skip, "B101"] + - id: ruff + - id: ruff-format - repo: https://github.com/codespell-project/codespell rev: v2.3.0 hooks: @@ -82,17 +54,6 @@ repos: - id: nbqa-check-ast additional_dependencies: [pre-commit-hooks] args: [--nbqa-dont-skip-bad-cells] - - id: nbqa-pyupgrade - additional_dependencies: [pyupgrade==v3.15.0] - args: [--py37-plus] - - id: nbqa-isort - additional_dependencies: [isort==5.13.2] - args: [--profile=black, --float-to-top] - - id: nbqa-black - additional_dependencies: [black==24.1.1] - - id: nbqa-pydocstyle - additional_dependencies: [pydocstyle==6.3.0] - args: ["--ignore=D100,D103"] - - id: nbqa-flake8 - additional_dependencies: [flake8==7.0.0, flake8-bugbear] - args: [--max-line-length=88, "--ignore=E203,E722,F821,W503,B001,B015"] + - id: nbqa-ruff + additional_dependencies: [ruff==v0.4.5] + args: ["--ignore=E722,F821,S110"] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ef7fd565..5c06742d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -7,7 +7,6 @@ Please follow these guidelines when contributing to this project. ## Developer Instructions ```bash -pip install -r requirements.txt pip install -r requirements-dev.txt python setup.py develop @@ -30,6 +29,12 @@ new code, and it can also be run at any time using the following command: pre-commit run --all ``` +or run on any files not yet committed to the repository using + +```bash +pre-commit run --files ... 
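+# for example, to run the hooks on two specific files:
+#     pre-commit run --files becquerel/core/spectrum.py becquerel/core/fitting.py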
+``` + ### Running the tests (Requires `requirements-dev.txt` to be installed) diff --git a/README.md b/README.md index 57373594..d00a4a75 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,11 @@ # becquerel -[![tests](https://github.com/lbl-anp/becquerel/actions/workflows/tests.yaml/badge.svg?branch=)](https://github.com/lbl-anp/becquerel/actions/workflows/tests.yaml) -[![Coverage Status](https://coveralls.io/repos/github/lbl-anp/becquerel/badge.svg?branch=main)](https://coveralls.io/github/lbl-anp/becquerel?branch=main) +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) [![PyPI version](https://img.shields.io/pypi/v/becquerel.svg)](https://pypi.org/project/becquerel) [![PyPI pyversions](https://img.shields.io/pypi/pyversions/becquerel.svg)](https://pypi.org/project/becquerel) -[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +[![PyPI license](https://img.shields.io/pypi/l/becquerel.svg)](https://pypi.python.org/project/becquerel) +[![tests](https://github.com/lbl-anp/becquerel/actions/workflows/tests.yaml/badge.svg?branch=)](https://github.com/lbl-anp/becquerel/actions/workflows/tests.yaml) +[![Coverage Status](https://coveralls.io/repos/github/lbl-anp/becquerel/badge.svg?branch=main)](https://coveralls.io/github/lbl-anp/becquerel?branch=main) Becquerel is a Python package for analyzing nuclear spectroscopic measurements. The core functionalities are reading and writing different @@ -50,12 +51,13 @@ The dependencies `beautifulsoup4`, `lxml` and `html5lib` are necessary for [`pandas`][1]. Developers require additional requirements which are listed in -`requirements-dev.txt`. We use [`pytest`][2] for unit testing, [`black`][3] for -code formatting and are converting to [`numpydoc`][4]. +`requirements-dev.txt`. We use [`pytest`][2] for unit testing, [`ruff`][3] for +code formatting and linting, and are planning to eventually support +[`numpydoc`][4] docstrings. 
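Since `ruff` now fills both roles, the checks that CI applies can be reproduced locally either through the pre-commit hooks or with the `ruff` CLI itself; a minimal sketch, assuming `ruff` is installed from `requirements-dev.txt` (the hook ids are the ones defined in the updated `.pre-commit-config.yaml` above):

```bash
# via pre-commit, running only the new hooks
pre-commit run ruff --all-files
pre-commit run ruff-format --all-files

# or directly with the ruff CLI
ruff check .    # linting (replaces flake8, isort, pyupgrade, ...)
ruff format .   # formatting (replaces black)
```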
[1]: https://pandas.pydata.org/pandas-docs/stable/install.html#dependencies [2]: https://docs.pytest.org/en/latest/ -[3]: https://black.readthedocs.io/en/stable/ +[3]: https://docs.astral.sh/ruff/ [4]: https://numpydoc.readthedocs.io/en/latest/format.html ## Copyright Notice diff --git a/becquerel/__init__.py b/becquerel/__init__.py index 71186260..5d48046c 100644 --- a/becquerel/__init__.py +++ b/becquerel/__init__.py @@ -34,44 +34,44 @@ warnings.simplefilter("default", DeprecationWarning) __all__ = [ - "__description__", - "__url__", - "__version__", - "__license__", - "__copyright__", - "core", - "utils", - "fitting", "AutoCalibrator", "AutoCalibratorError", - "LinearEnergyCal", - "EnergyCalError", "BadInput", "Calibration", "CalibrationError", "CalibrationWarning", + "Element", + "EnergyCalError", "Fitter", + "GaussianPeakFilter", + "Isotope", + "IsotopeQuantity", + "LinearEnergyCal", "PeakFilter", "PeakFilterError", - "GaussianPeakFilter", "PeakFinder", "PeakFinderError", - "SpectrumPlotter", "PlottingError", - "rebin", "RebinError", "RebinWarning", "Spectrum", "SpectrumError", - "UncalibratedError", + "SpectrumPlotter", "SpectrumWarning", + "UncalibratedError", "UncertaintiesError", + "__copyright__", + "__description__", + "__license__", + "__url__", + "__version__", + "core", + "fitting", + "materials", + "nndc", "parsers", + "rebin", "tools", - "nndc", + "utils", "xcom", - "materials", - "Element", - "Isotope", - "IsotopeQuantity", ] diff --git a/becquerel/__metadata__.py b/becquerel/__metadata__.py index b35391c9..03453179 100644 --- a/becquerel/__metadata__.py +++ b/becquerel/__metadata__.py @@ -1,6 +1,5 @@ """becquerel package metadata.""" -__name__ = "becquerel" __author__ = "The Becquerel Development Team" __maintainer__ = __author__ __email__ = "becquerel-dev@lbl.gov" diff --git a/becquerel/core/autocal.py b/becquerel/core/autocal.py index 4d04acb0..4f4e8eaf 100644 --- a/becquerel/core/autocal.py +++ b/becquerel/core/autocal.py @@ -141,7 +141,7 @@ def find_best_gain( snrs = np.array(snrs) n_req = len(required_energies) # make sure the required and optional sets do not overlap - optional = sorted(list(set(optional) - set(required_energies))) + optional = sorted(set(optional) - set(required_energies)) n_opt = len(optional) n_set = n_req + n_opt best_fom = None @@ -187,9 +187,9 @@ def find_best_gain( "Valid calibration found:\n" f"FOM: {fom:15.9f}" f" gain: {gain:6.3f}" - f" ergs: {str(comb_erg):50s}" - f" de: {str(de):50s}" - f" chans: {str(comb_chan):40s}" + f" ergs: {comb_erg!s:50s}" + f" de: {de!s:50s}" + f" chans: {comb_chan!s:40s}" ) if best_fom is None: best_fom = fom + 1.0 @@ -204,15 +204,15 @@ def find_best_gain( "Best calibration so far:\n" f"FOM: {best_fom:15.9f}" f" gain: {best_gain:6.3f}" - f" ergs: {str(best_ergs):50s}" - f" de: {str(de):50s}" - f" chans: {str(best_chans):40s}" + f" ergs: {best_ergs!s:50s}" + f" de: {de!s:50s}" + f" chans: {best_chans!s:40s}" ) n_set -= 1 if best_gain is None: return None else: - print("found best gain: %f keV/channel" % best_gain) + print(f"found best gain: {best_gain:f} keV/channel") return { "gain": best_gain, "channels": best_chans, diff --git a/becquerel/core/calibration.py b/becquerel/core/calibration.py index f6df17d5..d3ee4891 100644 --- a/becquerel/core/calibration.py +++ b/becquerel/core/calibration.py @@ -53,15 +53,15 @@ def _validate_domain_range(domain, rng): # must be length-2 iterables try: len(domain) - except TypeError: - raise CalibrationError(f"Domain must be length-2 iterable: {domain}") + except TypeError as 
exc: + raise CalibrationError(f"Domain must be length-2 iterable: {domain}") from exc domain = np.asarray(domain) if not (len(domain) == 2 and domain.ndim == 1): raise CalibrationError(f"Domain must be length-2 iterable: {domain}") try: len(rng) - except TypeError: - raise CalibrationError(f"Range must be length-2 iterable: {rng}") + except TypeError as exc: + raise CalibrationError(f"Range must be length-2 iterable: {rng}") from exc rng = np.asarray(rng) if not (len(rng) == 2 and rng.ndim == 1): raise CalibrationError(f"Range must contain two values: {rng}") @@ -230,10 +230,10 @@ def _validate_expression( # apply black formatting for consistency and error checking try: expression = black.format_str(expression, mode=black.FileMode()) - except (black.InvalidInput, blib2to3.pgen2.tokenize.TokenError): + except (black.InvalidInput, blib2to3.pgen2.tokenize.TokenError) as exc: raise CalibrationError( f"Error while running black on expression:\n{expression}" - ) + ) from exc # make sure `ind_var` appears in the formula if ind_var not in ["x", "y"]: @@ -252,10 +252,10 @@ def _validate_expression( # make sure each parameter appears at least once try: param_indices = _param_indices(expression) - except ValueError: + except ValueError as exc: raise CalibrationError( f"Unable to extract indices to parameters:\n{expression}" - ) + ) from exc if len(param_indices) > 0: if param_indices.min() != 0: raise CalibrationError( @@ -288,11 +288,11 @@ def _validate_expression( domain=domain, rng=rng, ) - except CalibrationError: + except CalibrationError as exc: raise CalibrationError( f"Cannot evaluate expression for float {ind_var} = {x_val}:\n" f"{expression}\n{safe_eval.symtable['x']}" - ) + ) from exc try: _eval_expression( expression, @@ -303,11 +303,11 @@ def _validate_expression( domain=domain, rng=rng, ) - except CalibrationError: + except CalibrationError as exc: raise CalibrationError( f"Cannot evaluate expression for array {ind_var} = {x_arr}:\n" f"{expression}\n{safe_eval.symtable['x']}" - ) + ) from exc return expression.strip() @@ -628,7 +628,7 @@ def __repr__(self): if len(self.attrs) > 0: for key in self.attrs: result += ", " - result += f"{key}={repr(self.attrs[key])}" + result += f"{key}={self.attrs[key]!r}" result += ")" return result @@ -907,7 +907,7 @@ def read(cls, name): ------- calibration : becquerel.Calibration """ - dsets, attrs, skipped = io.h5.read_h5(name) + dsets, attrs, _ = io.h5.read_h5(name) if "params" not in dsets: raise CalibrationError('Expected dataset "params"') if "expression" not in dsets: @@ -1237,7 +1237,7 @@ def plot(self, ax=None): has_points = self.points_x.size > 0 if ax is None: - fig, ax = plt.subplots(1 + has_points, 1, sharex=True) + _, ax = plt.subplots(1 + has_points, 1, sharex=True) if has_points: assert ax.shape == (2,) diff --git a/becquerel/core/energycal.py b/becquerel/core/energycal.py index f5ac0b07..221ac43f 100644 --- a/becquerel/core/energycal.py +++ b/becquerel/core/energycal.py @@ -1,4 +1,4 @@ -""""Energy calibration classes""" +"""Energy calibration classes""" import warnings from abc import ABCMeta, abstractmethod, abstractproperty @@ -73,8 +73,8 @@ def from_points(cls, chlist, kevlist, include_origin=False): try: cond = len(chlist) != len(kevlist) - except TypeError: - raise BadInput("Inputs must be one dimensional iterables") + except TypeError as exc: + raise BadInput("Inputs must be one dimensional iterables") from exc if cond: raise BadInput("Channels and energies must be same length") @@ -86,8 +86,8 @@ def from_points(cls, chlist, 
kevlist, include_origin=False): for ch, kev in zip(chlist, kevlist): try: cal.new_calpoint(ch, kev) - except (ValueError, TypeError): - raise BadInput("Inputs must be one dimensional iterables") + except (ValueError, TypeError) as exc: + raise BadInput("Inputs must be one dimensional iterables") from exc cal.update_fit() return cal @@ -306,7 +306,7 @@ def plot(self, ax=None): has_points = self.channels.size > 0 if ax is None: - fig, ax = plt.subplots(1 + has_points, 1, sharex=True) + _, ax = plt.subplots(1 + has_points, 1, sharex=True) if has_points: assert ax.shape == (2,) @@ -389,8 +389,10 @@ def slope(self): try: return self._coeffs["b"] - except KeyError: - raise EnergyCalError("Slope coefficient not yet supplied or calculated.") + except KeyError as exc: + raise EnergyCalError( + "Slope coefficient not yet supplied or calculated." + ) from exc @property def offset(self): @@ -398,8 +400,10 @@ def offset(self): try: return self._coeffs["c"] - except KeyError: - raise EnergyCalError("Offset coefficient not yet supplied or calculated.") + except KeyError as exc: + raise EnergyCalError( + "Offset coefficient not yet supplied or calculated." + ) from exc def _ch2kev(self, ch): """Convert scalar OR np.array of channel(s) to energies. diff --git a/becquerel/core/fitting.py b/becquerel/core/fitting.py index 60527709..50a6109c 100644 --- a/becquerel/core/fitting.py +++ b/becquerel/core/fitting.py @@ -1,7 +1,7 @@ import inspect import warnings -import matplotlib +import matplotlib as mpl import matplotlib.pyplot as plt import numdifftools as nd import numpy as np @@ -17,8 +17,7 @@ FWHM_SIG_RATIO = np.sqrt(8 * np.log(2)) # 2.35482 SQRT_TWO = np.sqrt(2) # 1.414213562 COLORS = [ - matplotlib.colors.to_rgb(c) - for c in ["C0", "C2", "C4", "C5", "C6", "C7", "C8", "C9"] + mpl.colors.to_rgb(c) for c in ["C0", "C2", "C4", "C5", "C6", "C7", "C8", "C9"] ] @@ -520,15 +519,15 @@ def __init__(self, model, x=None, y=None, y_unc=None, dx=None, roi=None, mask=No def __str__(self): return ( "bq.Fitter instance\n" - + f" name: {self.name}\n" - + f" model: {self.model}\n" - + f" x: {self.x}\n" - + f" y: {self.y}\n" - + f" y_unc: {self.y_unc}\n" - + f" xmode: {self.xmode}\n" - + f" ymode: {self.ymode}\n" - + f" dx: {self.dx}\n" - + f" roi: {self.roi}" + f" name: {self.name}\n" + f" model: {self.model}\n" + f" x: {self.x}\n" + f" y: {self.y}\n" + f" y_unc: {self.y_unc}\n" + f" xmode: {self.xmode}\n" + f" ymode: {self.ymode}\n" + f" dx: {self.dx}\n" + f" roi: {self.roi}" ) __repr__ = __str__ @@ -554,7 +553,7 @@ def y_unc(self): if self._y_unc is None: warnings.warn( "No y uncertainties (y_unc) provided. The fit will not be " - + "weighted causing in poor results at low counting statistics.", + "weighted, causing poor results at low counting statistics.", FittingWarning, ) return self._y_unc @@ -571,8 +570,8 @@ def y_unc(self, y_unc): min_v = np.min(self._y_unc[self._y_unc > 0.0]) warnings.warn( "Negative or zero uncertainty not supported. Changing " - + f"them to {min_v}. If you have Poisson data, " - + "this should be 1." + f"them to {min_v}. If you have Poisson data, " + "this should be 1."
) self._y_unc[self._y_unc <= 0.0] = min_v else: @@ -721,8 +720,8 @@ def _make_model(self, model): if m_instance.prefix in model_prefixes: raise FittingError( "A model prefix is not unique: " - + f"{m_instance.prefix} " - + f"All models: {model_translated}" + f"{m_instance.prefix} " + f"All models: {model_translated}" ) model_prefixes.add(m_instance.prefix) models.append(m_instance) @@ -745,7 +744,7 @@ def _guess_param_defaults(self, **kwargs): if isinstance(p, Parameters): p = _parameters_to_bq_guess(p) elif p is None: - raise TypeError() + raise TypeError params += p return params @@ -817,8 +816,8 @@ def fit(self, backend="lmfit", guess=None, limits=None): elif self.backend in ["iminuit", "minuit"]: raise NotImplementedError( f"Backend {self.backend} with least-squares loss not yet " - + f"supported. Use {self.backend}-pml for Poisson loss or " - + "lmfit for least-squares." + f"supported. Use {self.backend}-pml for Poisson loss or " + "lmfit for least-squares." ) elif self.backend in ["iminuit-pml", "minuit-pml"]: @@ -926,7 +925,7 @@ def model_loss(*args): self._best_values, self._init_values = {}, {} for i in range(self.result.npar): p = self.result.parameters[i] - self._best_values[p] = self.result.values[i] + self._best_values[p] = self.result.values[i] # noqa: PD011 self._init_values[p] = self.result.init_params[p].value # Arg order sanity checks @@ -1002,7 +1001,7 @@ def calc_area_and_unc(self, component=None, x=None) -> ufloat: def _calc_area(param_vec, **kwargs): """Internal function to compute the area given the fit values.""" - param_dict = {name: val for (name, val) in zip(kwargs["names"], param_vec)} + param_dict = dict(zip(kwargs["names"], param_vec)) return kwargs["model"].eval(x=kwargs["xvals"], **param_dict).sum() # Handle input defaults @@ -1164,14 +1163,13 @@ def param_dataframe(self, sort_by_model: bool = False) -> pd.DataFrame: df.loc[k, "val"] = self.param_val(k) df.loc[k, "unc"] = self.param_unc(k) if sort_by_model: - df.set_index( + df = df.set_index( pd.MultiIndex.from_tuples( # Only split on the first underscore, in case the param name # has an underscore in it [tuple(p.split("_", maxsplit=1)) for p in df.index], names=["model", "param"], ), - inplace=True, ) return df @@ -1365,7 +1363,7 @@ def custom_plot( # Residuals # --------- y_eval = self.eval(self.x_roi, **self.best_values) * dx_roi - res_kwargs = dict(fmt="o", color="k", markersize=5, label="residuals") + res_kwargs = {"fmt": "o", "color": "k", "markersize": 5, "label": "residuals"} # Y-values of the residual plot, depending on residual_type y_plot = self.compute_residuals(residual_type) @@ -1391,10 +1389,13 @@ def custom_plot( # Fit report (txt_ax) # ------------------- if enable_fit_panel: - props = dict( - boxstyle="round", facecolor="white", edgecolor="black", alpha=1 - ) - props = dict(facecolor="white", edgecolor="none", alpha=0) + props = { + "boxstyle": "round", + "facecolor": "white", + "edgecolor": "black", + "alpha": 1, + } + props = {"facecolor": "white", "edgecolor": "none", "alpha": 0} fp = FontProperties(family="monospace", size=8) if "lmfit" in self.backend: best_fit_values = "" @@ -1408,12 +1409,9 @@ def custom_plot( # "stderr calculated.", FittingWarning) else: best_fit_values += ( - "{:15} {: .6e} +/- {:.5e} ({:6.1%})\n".format( - p, - op[p].value, - op[p].stderr, - abs(op[p].stderr / op[p].value), - ) + f"{p:15} " + f"{op[p].value: .6e} +/- {op[p].stderr:.5e} " + f"({abs(op[p].stderr / op[p].value):6.1%})\n" ) best_fit_values += "{:15} {: .6e}\n".format( "Chi Squared:", 
self.result.chisqr @@ -1431,8 +1429,10 @@ def custom_plot( for (_, param_name), param_data in sdf.iterrows(): v = param_data["val"] e = param_data["unc"] - s += " {:24}: {: .6e} +/- {:.5e} ({:6.1%})\n".format( - param_name, v, e, np.abs(e / v) + s += ( + f" {param_name:24}: " + f"{v: .6e} +/- {e:.5e} " + f"({np.abs(e / v):6.1%})\n" ) elif "minuit" in self.backend: s = str(self.result) + "\n" diff --git a/becquerel/core/peakfinder.py b/becquerel/core/peakfinder.py index a64605fd..c3ebf82a 100644 --- a/becquerel/core/peakfinder.py +++ b/becquerel/core/peakfinder.py @@ -197,9 +197,8 @@ def sort_by(self, arr): """Sort peaks by the provided array.""" if len(arr) != len(self.centroids): raise PeakFinderError( - "Sorting array has length {} but must have length {}".format( - len(arr), len(self.centroids) - ) + f"Sorting array has length {len(arr)} " + f"but must have length {len(self.centroids)}" ) self.centroids = np.array(self.centroids) self.snrs = np.array(self.snrs) @@ -331,9 +330,8 @@ def find_peak(self, xpeak, frac_range=(0.8, 1.2), min_snr=2): raise PeakFinderError(f"Minimum SNR {min_snr:.3f} must be > 0") if self.snr.max() < min_snr: raise PeakFinderError( - "SNR threshold is {:.3f} but maximum SNR is {:.3f}".format( - min_snr, self.snr.max() - ) + f"SNR threshold is {min_snr:.3f} " + f"but maximum SNR is {self.snr.max():.3f}" ) x0 = frac_range[0] * xpeak x1 = frac_range[1] * xpeak @@ -393,9 +391,8 @@ def find_peaks(self, xmin=None, xmax=None, min_snr=2, max_num=40, reset=False): raise PeakFinderError(f"Minimum SNR {min_snr:.3f} must be > 0") if self.snr.max() < min_snr: raise PeakFinderError( - "SNR threshold is {:.3f} but maximum SNR is {:.3f}".format( - min_snr, self.snr.max() - ) + f"SNR threshold is {min_snr:.3f} " + f"but maximum SNR is {self.snr.max():.3f}" ) max_num = int(max_num) if max_num < 1: diff --git a/becquerel/core/plotting.py b/becquerel/core/plotting.py index 4b3a1784..df626e99 100644 --- a/becquerel/core/plotting.py +++ b/becquerel/core/plotting.py @@ -98,8 +98,8 @@ def xmode(self, mode): if mode.lower() in ("kev", "energy"): if not self.spec.is_calibrated: raise PlottingError( - "Spectrum is not calibrated, however" - " x axis was requested as energy" + "Spectrum is not calibrated, however " + "x axis was requested as energy" ) self._xmode = "energy" elif mode.lower() in ("channel", "channels", "chn", "chns"): @@ -275,7 +275,7 @@ def plot(self, *fmt, **kwargs): if hasattr(fmt, "__len__") and len(fmt) > 0: self.fmt = fmt - if not hasattr(self.fmt, "__len__") or not len(self.fmt) in [0, 1]: + if not hasattr(self.fmt, "__len__") or len(self.fmt) not in [0, 1]: raise PlottingError("Wrong number of positional argument") xcorners, ycorners = self._prepare_plot(**kwargs) diff --git a/becquerel/core/rebin.py b/becquerel/core/rebin.py index b009faec..dedc5a45 100644 --- a/becquerel/core/rebin.py +++ b/becquerel/core/rebin.py @@ -256,7 +256,11 @@ def _rebin_listmode(in_spectrum, in_edges, out_edges_no_rightmost, out_spectrum) # knock out leftmost bin edge too, because we put all overflows into # first and last bins anyways out_edges = np.concatenate( - (np.array([-np.inf]), out_edges_no_rightmost[1:], np.array([np.inf])) + ( + np.array([-np.inf]), + out_edges_no_rightmost[1:], + np.array([np.inf]), + ) ) energies = np.zeros(np.sum(in_spectrum)) energy_idx_start = 0 @@ -349,8 +353,8 @@ def rebin( else: warnings.warn( "Argument in_spectra contains float value(s) which " - + "will have decimal precision loss when converting to " - + "integers for rebin method listmode.", + "will 
have decimal precision loss when converting to " + "integers for rebin method listmode.", RebinWarning, ) in_spectra = in_spectra_rint @@ -384,7 +388,11 @@ def rebin( "Highest output edge must be finite if not including overflows" ) out_edges_temp = np.concatenate( - (np.array([-np.inf]), out_edges, np.array([np.inf])) + ( + np.array([-np.inf]), + out_edges, + np.array([np.inf]), + ) ) else: out_edges_temp = out_edges diff --git a/becquerel/core/spectrum.py b/becquerel/core/spectrum.py index cb1c33ab..51f343a0 100644 --- a/becquerel/core/spectrum.py +++ b/becquerel/core/spectrum.py @@ -179,9 +179,8 @@ def __init__( if self.livetime is not None: if self.livetime > self.realtime: raise ValueError( - "Livetime ({}) cannot exceed realtime ({})".format( - self.livetime, self.realtime - ) + f"Livetime ({self.livetime}) cannot exceed realtime " + f"({self.realtime})" ) self.start_time = handle_datetime(start_time, "start_time", allow_none=True) @@ -194,14 +193,13 @@ def __init__( ): raise SpectrumError( "Specify no more than 2 out of 3 args: " - + "realtime, stop_time, start_time" + "realtime, stop_time, start_time" ) elif self.start_time is not None and self.stop_time is not None: if self.start_time > self.stop_time: raise ValueError( - "Stop time ({}) must be after start time ({})".format( - self.start_time, self.stop_time - ) + f"Stop time ({self.start_time}) must be after start time " + f"({self.stop_time})" ) self.realtime = (self.stop_time - self.start_time).total_seconds() elif self.start_time is not None and self.realtime is not None: @@ -219,9 +217,16 @@ def __init__( def __str__(self) -> str: lines = ["becquerel.Spectrum"] - ltups = [] - for k in ["start_time", "stop_time", "realtime", "livetime", "is_calibrated"]: - ltups.append((k, getattr(self, k))) + ltups = [ + (k, getattr(self, k)) + for k in [ + "start_time", + "stop_time", + "realtime", + "livetime", + "is_calibrated", + ] + ] ltups.append(("num_bins", len(self.bin_indices))) if self._counts is None: ltups.append(("gross_counts", None)) @@ -235,8 +240,7 @@ def __str__(self) -> str: ltups.append(("filename", self.attrs["infilename"])) else: ltups.append(("filename", None)) - for lt in ltups: - lines.append(" {:15} {}".format(f"{lt[0]}:", lt[1])) + lines += [f" {lt[0] + ':':15s} {lt[1]}" for lt in ltups] return "\n".join(lines) __repr__ = __str__ @@ -260,10 +264,10 @@ def counts(self) -> np.ndarray: else: try: return self.cps * self.livetime - except TypeError: + except TypeError as exc: raise SpectrumError( "Unknown livetime; cannot calculate counts from CPS" - ) + ) from exc @property def counts_vals(self) -> np.ndarray: @@ -304,10 +308,10 @@ def cps(self) -> np.ndarray: else: try: return self.counts / self.livetime - except TypeError: + except TypeError as exc: raise SpectrumError( "Unknown livetime; cannot calculate CPS from counts" - ) + ) from exc @property def cps_vals(self) -> np.ndarray: @@ -795,8 +799,8 @@ def __add__(self, other): if (self._counts is None) ^ (other._counts is None): raise SpectrumError( "Addition of counts-based and CPS-based spectra is " - + "ambiguous, use Spectrum(counts=specA.counts+specB.counts) " - + "or Spectrum(cps=specA.cps+specB.cps) instead." + "ambiguous, use Spectrum(counts=specA.counts+specB.counts) " + "or Spectrum(cps=specA.cps+specB.cps) instead." 
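A large share of the changes in this diff rewrite `raise` statements inside `except` blocks as `raise ... from exc`, the pattern enforced by Ruff's B904 rule; a minimal self-contained sketch of what the chaining buys (generic names, not from the codebase):

```python
def parse_livetime(value):
    try:
        return float(value)
    except ValueError as exc:
        # "from exc" attaches the original ValueError as __cause__, so the
        # traceback shows the root cause instead of a bare secondary error.
        raise RuntimeError(f"Cannot interpret livetime: {value!r}") from exc


try:
    parse_livetime("abc")
except RuntimeError as err:
    assert isinstance(err.__cause__, ValueError)
```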
) if self._counts is not None and other._counts is not None: @@ -806,7 +810,7 @@ else: warnings.warn( "Addition of counts with missing livetimes, " - + "livetime was set to None.", + "livetime was set to None.", SpectrumWarning, ) else: @@ -850,7 +854,7 @@ def __sub__(self, other): if (self._cps is None) or (other._cps is None): warnings.warn( "Subtraction of counts-based specta, spectra " - + "have been converted to CPS", + "have been converted to CPS", SpectrumWarning, ) except SpectrumError: @@ -859,14 +863,14 @@ kwargs["uncs"] = [np.nan] * len(self) warnings.warn( "Subtraction of counts-based spectra, " - + "livetimes have been ignored.", + "livetimes have been ignored.", SpectrumWarning, ) - except SpectrumError: + except SpectrumError as exc: raise SpectrumError( "Subtraction of counts and CPS-based spectra without" - + "livetimes not possible" - ) + " livetimes not possible" + ) from exc if self.is_calibrated and other.is_calibrated: spect_obj = self.__class__(bin_edges_kev=self.bin_edges_kev, **kwargs) @@ -896,14 +900,14 @@ def _add_sub_error_checking(self, other): if self.is_calibrated ^ other.is_calibrated: raise SpectrumError( "Cannot add/subtract uncalibrated spectrum to/from a " - + "calibrated spectrum. If both have the same calibration, " - + 'please use the "calibrate_like" method' + "calibrated spectrum. If both have the same calibration, " + 'please use the "calibrate_like" method' ) if self.is_calibrated and other.is_calibrated: if not np.all(self.bin_edges_kev == other.bin_edges_kev): raise NotImplementedError( "Addition/subtraction for arbitrary calibrated spectra " - + "not implemented" + "not implemented" ) # TODO: if both spectra are calibrated but with different # calibrations, should one be rebinned to match? @@ -911,7 +915,7 @@ if not np.all(self.bin_edges_raw == other.bin_edges_raw): raise NotImplementedError( "Addition/subtraction for arbitrary uncalibrated " - + "spectra not implemented" + "spectra not implemented" ) def __mul__(self, other): @@ -966,8 +970,10 @@ def _mul_div(self, scaling_factor: float, div=False): if not isinstance(scaling_factor, UFloat): try: scaling_factor = float(scaling_factor) - except (TypeError, ValueError): - raise TypeError("Spectrum must be multiplied/divided by a scalar") + except (TypeError, ValueError) as exc: + raise TypeError( + "Spectrum must be multiplied/divided by a scalar" + ) from exc if ( scaling_factor == 0 or np.isinf(scaling_factor) @@ -1176,7 +1182,7 @@ def find_bin_index(self, x: float, use_kev=None) -> int: "Cannot access energy bins with an uncalibrated Spectrum." 
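The `__add__`/`_add_sub_error_checking` logic above uses `^` on booleans (e.g. `(self._counts is None) ^ (other._counts is None)`) to flag the case where exactly one of two inputs is present; a small self-contained illustration of that idiom:

```python
def check_same_basis(a_counts, b_counts):
    """Raise if exactly one of the two inputs is missing (boolean XOR)."""
    if (a_counts is None) ^ (b_counts is None):
        raise ValueError("mixing counts-based and CPS-based data is ambiguous")


check_same_basis([1, 2, 3], [4, 5, 6])  # ok: both counts-based
check_same_basis(None, None)            # ok: both CPS-based
# check_same_basis([1, 2, 3], None)     # would raise ValueError
```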
) - bin_edges, bin_widths, _ = self.get_bin_properties(use_kev) + bin_edges, _, _ = self.get_bin_properties(use_kev) x = np.asarray(x) if np.any(x < bin_edges[0]): @@ -1346,7 +1352,7 @@ def rebin( if (self._counts is None) and (self.livetime is not None): warnings.warn( "Rebinning by listmode method without explicit counts " - + "provided in Spectrum object", + "provided in Spectrum object", SpectrumWarning, ) out_spec = rebin( @@ -1473,9 +1479,7 @@ def plot(self, *fmt, **kwargs): elif emode == "bars" or emode == "bar": plotter.errorbar(color=color, label="_nolegend_") elif emode != "none": - raise SpectrumError( - "Unknown error mode '{}', use 'bars' " "or 'band'".format(emode) - ) + raise SpectrumError(f"Unknown error mode '{emode}', use 'bars' or 'band'") return ax def fill_between(self, **kwargs): @@ -1545,8 +1549,8 @@ def fit( Fitter """ - xedges, xlabel = self.parse_xmode(xmode) - ydata, yuncs, ylabel = self.parse_ymode(ymode) + xedges, _ = self.parse_xmode(xmode) + ydata, yuncs, _ = self.parse_ymode(ymode) xcenters = bin_centers_from_edges(xedges) diff --git a/becquerel/core/utils.py b/becquerel/core/utils.py index 491526db..1c43ebcd 100644 --- a/becquerel/core/utils.py +++ b/becquerel/core/utils.py @@ -70,7 +70,7 @@ def handle_uncs(x_array, x_uncs, default_unc_func): elif ufloats: raise UncertaintiesError( "Specify uncertainties with UFloats or " - + "by separate argument, but not both" + "by separate argument, but not both" ) elif x_uncs is not None: return unumpy.uarray(x_array, x_uncs) diff --git a/becquerel/io/h5.py b/becquerel/io/h5.py index ecbfe868..8c573bf7 100644 --- a/becquerel/io/h5.py +++ b/becquerel/io/h5.py @@ -1,7 +1,8 @@ """Simple tools to perform HDF5 I/O.""" +from __future__ import annotations + import pathlib -from typing import Tuple, Union import h5py @@ -44,7 +45,7 @@ class open_h5: """Context manager to allow I/O given HDF5 filename, File, or Group.""" def __init__( - self, name: Union[str, pathlib.Path, h5py.File, h5py.Group], mode=None, **kwargs + self, name: str | pathlib.Path | h5py.File | h5py.Group, mode=None, **kwargs ): """Initialize the context manager. @@ -86,7 +87,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.file.close() -def write_h5(name: Union[str, h5py.File, h5py.Group], dsets: dict, attrs: dict) -> None: +def write_h5(name: str | h5py.File | h5py.Group, dsets: dict, attrs: dict) -> None: """Write the datasets and attributes to an HDF5 file or group. Parameters @@ -114,7 +115,7 @@ def write_h5(name: Union[str, h5py.File, h5py.Group], dsets: dict, attrs: dict) file.attrs.update(attrs) -def read_h5(name: Union[str, h5py.File, h5py.Group]) -> Tuple[dict, dict, list]: +def read_h5(name: str | h5py.File | h5py.Group) -> tuple[dict, dict, list]: """Read the datasets and attributes from an HDF5 file or group. 
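The `h5.py` hunk above swaps `typing.Union`/`typing.Tuple` for the PEP 604/PEP 585 spellings `str | pathlib.Path` and `tuple[dict, dict, list]`; the added `from __future__ import annotations` keeps those annotations from being evaluated at runtime, so they stay valid on the older interpreters the package supports. A minimal sketch:

```python
from __future__ import annotations  # annotations are kept as strings (PEP 563)

import pathlib


def describe(name: str | pathlib.Path) -> tuple[str, bool]:
    # The union and builtin-generic annotations above are never evaluated,
    # so this module imports fine even on Pythons without native PEP 604.
    path = pathlib.Path(name)
    return str(path), path.exists()


print(describe("spectrum.h5"))
```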
Parameters diff --git a/becquerel/parsers/__init__.py b/becquerel/parsers/__init__.py index a0b4121f..62b13c05 100644 --- a/becquerel/parsers/__init__.py +++ b/becquerel/parsers/__init__.py @@ -6,9 +6,9 @@ __all__ = [ "BecquerelParserError", "BecquerelParserWarning", - "h5", "cnf", + "h5", + "iec1455", "spc", "spe", - "iec1455", ] diff --git a/becquerel/parsers/cnf.py b/becquerel/parsers/cnf.py index ccb44956..6c45dd0c 100644 --- a/becquerel/parsers/cnf.py +++ b/becquerel/parsers/cnf.py @@ -56,7 +56,7 @@ def _from_pdp11(data, index): exb = ((data[index + 1] & 0x7F) << 1) + ((data[index] & 0x80) >> 7) if exb == 0: if sign == -1: - return np.NaN + return np.nan else: return 0.0 h = ( @@ -104,7 +104,7 @@ def read(filename, verbose=False, cal_kwargs=None): # read all of the file into memory file_bytes = [] - with open(filename, "rb") as f: + with Path(filename).open("rb") as f: byte = f.read(1) while byte: byte_int = struct.unpack("1B", byte) @@ -278,7 +278,7 @@ def read(filename, verbose=False, cal_kwargs=None): raise BecquerelParserError("Channel data not found") channels = np.array([], dtype=float) counts = np.array([], dtype=float) - for i in range(0, 2): + for i in range(2): y = _from_little_endian(file_bytes, offset_chan + 512 + 4 * i, 4) if y == int(realtime) or y == int(livetime): y = 0 diff --git a/becquerel/parsers/iec1455.py b/becquerel/parsers/iec1455.py index 683030ee..6d4839ce 100644 --- a/becquerel/parsers/iec1455.py +++ b/becquerel/parsers/iec1455.py @@ -95,7 +95,7 @@ def read(filename, verbose=False, cal_kwargs=None): energy_channel_pairs = [] energy_res_pairs = [] # currently unused energy_eff_pairs = [] # currently unused - with open(filename) as f: + with Path(filename).open() as f: record = 1 # loop over lines for line in f: @@ -154,14 +154,12 @@ def read(filename, verbose=False, cal_kwargs=None): # fix nasty formatting of Interwinner export line = re.sub(r"(e[+-][0-9]{2})", r"\1 ", line.lower()) tok = line.split() - for coeff in tok: - cal_coeff.append(float(coeff)) + cal_coeff += [float(coeff) for coeff in tok] elif record == 5: # fix nasty formatting of Interwinner export line = re.sub(r"(e[+-][0-9]{2})", r"\1 ", line.lower()) tok = line.split() - for coeff in tok: - fwhm_coeff.append(float(coeff)) + fwhm_coeff += [float(coeff) for coeff in tok] elif record >= 6 and record <= 9: line = line.strip() if len(line) > 0: @@ -176,8 +174,7 @@ def read(filename, verbose=False, cal_kwargs=None): elif record >= 35 and record <= 46: energy_eff_pairs += _read_nonzero_number_pairs(tok) elif record >= 59: - for c in tok[1:]: # skip channel index - counts.append(int(c)) + counts += [int(c) for c in tok[1:]] # skip channel index # increment record number record += 1 diff --git a/becquerel/parsers/spc.py b/becquerel/parsers/spc.py index b3894061..c319be9e 100644 --- a/becquerel/parsers/spc.py +++ b/becquerel/parsers/spc.py @@ -217,7 +217,7 @@ def read(filename, verbose=False, cal_kwargs=None): # initialize a dictionary of spectrum data to populate as we parse data = {} - with open(filename, "rb") as f: + with Path(filename).open("rb") as f: # read the file in chunks of 128 bytes data_records = [] binary_data = None @@ -226,8 +226,10 @@ def read(filename, verbose=False, cal_kwargs=None): data_records.append(binary_data) try: binary_data = f.read(128) - except OSError: - raise BecquerelParserError("Unable to read 128 bytes from file") + except OSError as exc: + raise BecquerelParserError( + "Unable to read 128 bytes from file" + ) from exc if len(binary_data) < 128: break if verbose: @@ 
-250,18 +252,18 @@ def read(filename, verbose=False, cal_kwargs=None): for data_format in record_format: fmt += data_format[1] if verbose: - print("") - print("") + print() + print() print("-" * 60) - print("") + print() print(record_format) print(fmt) - print("") + print() binary_data = struct.unpack(fmt, binary_data) if verbose: - print("") + print() print(binary_data) - print("") + print() for j, data_format in enumerate(record_format): if isinstance(binary_data[j], bytes): data[data_format[0]] = binary_data[j].decode("ascii") @@ -298,18 +300,18 @@ def read(filename, verbose=False, cal_kwargs=None): for data_format in record_format: fmt += data_format[1] if verbose: - print("") - print("") + print() + print() print("-" * 60) - print("") + print() print(record_format) print(fmt) - print("") + print() binary_data = struct.unpack(fmt, binary_data) if verbose: - print("") + print() print(binary_data) - print("") + print() for j, data_format in enumerate(record_format): if isinstance(binary_data[j], bytes): data[data_format[0]] = binary_data[j].decode("ascii") @@ -386,8 +388,8 @@ def read(filename, verbose=False, cal_kwargs=None): float(data["Calibration parameter 1"]), float(data["Calibration parameter 2"]), ] - except KeyError: - raise BecquerelParserError("Calibration parameters not found") + except KeyError as exc: + raise BecquerelParserError("Calibration parameters not found") from exc # clean up null characters in any strings for key in data.keys(): diff --git a/becquerel/parsers/spe.py b/becquerel/parsers/spe.py index 6899b988..1f4ebffb 100644 --- a/becquerel/parsers/spe.py +++ b/becquerel/parsers/spe.py @@ -48,9 +48,9 @@ def read(filename, verbose=False, cal_kwargs=None): counts = [] channels = [] cal_coeff = [] - with open(filename) as f: + with Path(filename).open() as f: # read & remove newlines from end of each line - lines = [line.strip() for line in f.readlines()] + lines = [line.strip() for line in f] i = 0 while i < len(lines): # check whether we have reached a keyword and parse accordingly @@ -86,8 +86,8 @@ def read(filename, verbose=False, cal_kwargs=None): i += 1 n_coeff = int(lines[i]) i += 1 - for j in range(n_coeff): - cal_coeff.append(float(lines[i].split(" ")[j])) + tokens = lines[i].split(" ") + cal_coeff += [float(token) for token in tokens[:n_coeff]] if verbose: print(cal_coeff) elif lines[i].startswith("$"): diff --git a/becquerel/tools/__init__.py b/becquerel/tools/__init__.py index bf6ddd8b..2b829296 100644 --- a/becquerel/tools/__init__.py +++ b/becquerel/tools/__init__.py @@ -43,43 +43,43 @@ ) __all__ = [ - "xcom", - "nndc", - "materials", - "element", - "isotope", - "isotope_qty", + "MIXTURE_AIR_DRY", + "MIXTURE_PORTLAND_CEMENT", + "MIXTURE_SEAWATER", + "N_AV", + "UCI_TO_BQ", "Element", "ElementError", - "ElementZError", - "ElementSymbolError", "ElementNameError", + "ElementSymbolError", + "ElementZError", "Isotope", "IsotopeError", "IsotopeQuantity", "IsotopeQuantityError", - "NeutronIrradiation", - "NeutronIrradiationError", - "UCI_TO_BQ", - "N_AV", - "force_load_and_write_materials_csv", - "fetch_materials", - "remove_materials_csv", "MaterialsError", "MaterialsWarning", "NISTMaterialsError", "NISTMaterialsRequestError", - "fetch_wallet_card", - "fetch_decay_radiation", "NNDCError", - "NoDataFound", "NNDCInputError", "NNDCRequestError", - "fetch_xcom_data", + "NeutronIrradiation", + "NeutronIrradiationError", + "NoDataFound", "XCOMError", "XCOMInputError", "XCOMRequestError", - "MIXTURE_AIR_DRY", - "MIXTURE_SEAWATER", - "MIXTURE_PORTLAND_CEMENT", + 
"element", + "fetch_decay_radiation", + "fetch_materials", + "fetch_wallet_card", + "fetch_xcom_data", + "force_load_and_write_materials_csv", + "isotope", + "isotope_qty", + "materials", + "nndc", + "remove_materials_csv", + "xcom", ] diff --git a/becquerel/tools/df_cache.py b/becquerel/tools/df_cache.py index ba0ad6a5..6f1bf703 100644 --- a/becquerel/tools/df_cache.py +++ b/becquerel/tools/df_cache.py @@ -1,6 +1,6 @@ """A simple class for caching a pandas DataFrame.""" -import os +from pathlib import Path import pandas as pd @@ -42,9 +42,9 @@ def __init__(self): """ if self.path is None: - self.path = os.path.split(__file__)[0] + self.path = Path(__file__).parent self.check_path() - self.filename = os.path.join(self.path, "__df_cache__" + self.name + ".csv") + self.filename = self.path / ("__df_cache__" + self.name + ".csv") self.df = None self.loaded = False @@ -55,9 +55,10 @@ def check_path(self): CacheError: if the path does not exist. """ - if not os.path.exists(self.path): + self.path = Path(self.path) + if not self.path.exists(): raise CacheError(f"Cache path does not exist: {self.path}") - if not os.path.isdir(self.path): + if not self.path.is_dir(): raise CacheError(f"Cache path is not a directory: {self.path}") def check_file(self): @@ -67,9 +68,10 @@ def check_file(self): CacheError: if the file does not exist. """ - if not os.path.exists(self.filename): + self.filename = Path(self.filename) + if not self.filename.exists(): raise CacheError(f"Cache filename does not exist: {self.filename}") - if not os.path.isfile(self.filename): + if not self.filename.is_file(): raise CacheError(f"Cache filename is not a file: {self.filename}") def write_file(self): @@ -84,8 +86,8 @@ def write_file(self): raise CacheError("Cache has not been fetched or loaded") try: self.df.to_csv(self.filename, float_format="%.12f") - except Exception: - raise CacheError(f"Problem writing cache to file {self.filename}") + except Exception as exc: + raise CacheError(f"Problem writing cache to file {self.filename}") from exc self.check_file() def read_file(self): @@ -98,8 +100,10 @@ def read_file(self): self.check_file() try: self.df = pd.read_csv(self.filename) - except Exception: - raise CacheError(f"Problem reading cache from file {self.filename}") + except Exception as exc: + raise CacheError( + f"Problem reading cache from file {self.filename}" + ) from exc self.loaded = True def delete_file(self): @@ -111,9 +115,9 @@ def delete_file(self): self.check_file() try: - os.remove(self.filename) - except Exception: - raise CacheError(f"Problem deleting cache file {self.filename}") + self.filename.unlink() + except Exception as exc: + raise CacheError(f"Problem deleting cache file {self.filename}") from exc try: self.check_file() except CacheError: @@ -143,8 +147,8 @@ def load(self): except CacheError: try: self.fetch() - except CacheError: - raise CacheError("Cannot read or download DataFrame") + except CacheError as exc: + raise CacheError("Cannot read or download DataFrame") from exc self.write_file() self.read_file() self.loaded = True diff --git a/becquerel/tools/element.py b/becquerel/tools/element.py index c6c4dda8..a2f30561 100644 --- a/becquerel/tools/element.py +++ b/becquerel/tools/element.py @@ -183,8 +183,8 @@ def validated_z(z): try: int(z) - except ValueError: - raise ElementZError(f'Element Z="{z}" invalid') + except ValueError as exc: + raise ElementZError(f'Element Z="{z}" invalid') from exc if int(z) not in ZS: raise ElementZError(f'Element Z="{z}" not found') return int(z) @@ -205,8 +205,8 @@ 
def validated_symbol(sym): try: sym.lower() - except AttributeError: - raise ElementSymbolError(f'Element symbol "{sym}" invalid') + except AttributeError as exc: + raise ElementSymbolError(f'Element symbol "{sym}" invalid') from exc if sym.lower() not in SYMBOLS_LOWER: raise ElementSymbolError(f'Element symbol "{sym}" not found') return _SYMBOL_FROM_SYMBOL_LOWER[sym.lower()] @@ -227,8 +227,8 @@ def validated_name(nm): try: nm.lower() - except AttributeError: - raise ElementNameError(f'Element name "{nm}" invalid') + except AttributeError as exc: + raise ElementNameError(f'Element name "{nm}" invalid') from exc # special case: Alumin[i]um if nm.lower() == "aluminium": nm = "Aluminum" @@ -259,8 +259,8 @@ def element_z(sym_or_name): pass try: return _Z_FROM_NAME[validated_name(sym_or_name)] - except ElementNameError: - raise ElementZError("Must supply either the element symbol or name") + except ElementNameError as exc: + raise ElementZError("Must supply either the element symbol or name") from exc def element_symbol(name_or_z): @@ -282,8 +282,8 @@ def element_symbol(name_or_z): pass try: return _SYMBOL_FROM_NAME[validated_name(name_or_z)] - except ElementNameError: - raise ElementSymbolError("Must supply either the Z or element name") + except ElementNameError as exc: + raise ElementSymbolError("Must supply either the Z or element name") from exc def element_name(sym_or_z): @@ -305,16 +305,16 @@ def element_name(sym_or_z): pass try: return _NAME_FROM_Z[validated_z(sym_or_z)] - except ElementZError: - raise ElementNameError("Must supply either the element symbol or Z") + except ElementZError as exc: + raise ElementNameError("Must supply either the element symbol or Z") from exc class Element: """Basic properties (symbol, name, Z, and mass) of an element. 
Also provides string formatting: - >>> elem = Element('Ge') - >>> '{:%n(%s) Z=%z}'.format(elem) + >>> elem = Element("Ge") + >>> "{:%n(%s) Z=%z}".format(elem) 'Germanium(Ge) Z=32' Properties: @@ -346,8 +346,8 @@ def __init__(self, arg): except ElementNameError: try: self._init_z(arg) - except ElementZError: - raise ElementError(f"Could not instantiate Element: {arg}") + except ElementZError as exc: + raise ElementError(f"Could not instantiate Element: {arg}") from exc self.atomic_mass = _MASS_FROM_SYMBOL[self.symbol] def _init_sym(self, arg): @@ -397,5 +397,5 @@ def __eq__(self, other): and self.symbol == other.symbol and self.Z == other.Z ) - except Exception: - raise ElementError("Cannot determine equality") + except Exception as exc: + raise ElementError("Cannot determine equality") from exc diff --git a/becquerel/tools/isotope.py b/becquerel/tools/isotope.py index 22076fb3..6ba615a1 100644 --- a/becquerel/tools/isotope.py +++ b/becquerel/tools/isotope.py @@ -84,8 +84,8 @@ def _split_element_mass(arg): # ensure element name or symbol is valid try: element.Element(element_id) - except element.ElementError: - raise IsotopeError(f"Element name or symbol is invalid: {element_id}") + except element.ElementError as exc: + raise IsotopeError(f"Element name or symbol is invalid: {element_id}") from exc return element_id, mass_isomer @@ -115,10 +115,10 @@ def _split_mass_isomer(arg): raise IsotopeError(f"Too many ms in mass number: {arg} {tokens}") try: aa = int(tokens[0]) - except ValueError: + except ValueError as exc: raise IsotopeError( f"Mass number cannot be converted to int: {tokens[0]} {arg}" - ) + ) from exc mm = "m" if len(tokens[1]) > 0: if not tokens[1].isdigit(): @@ -129,8 +129,10 @@ def _split_mass_isomer(arg): else: try: aa = int(arg) - except ValueError: - raise IsotopeError(f"Mass number cannot be converted to int: {arg}") + except ValueError as exc: + raise IsotopeError( + f"Mass number cannot be converted to int: {arg}" + ) from exc return (aa, mm) @@ -157,8 +159,8 @@ class Isotope(element.Element): """Basic properties of a nuclear isotope, including isomers. 
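The `Element` doctest above works because these classes implement `__format__`, expanding codes like `%n`, `%s`, and `%z` inside the format spec; a hypothetical sketch of that mechanism (the class and code table here are illustrative, not the package's actual implementation):

```python
class Nuclide:
    def __init__(self, symbol, name, Z):
        self.symbol, self.name, self.Z = symbol, name, Z

    def __format__(self, spec):
        # str.format passes everything after ":" as the spec; expand each
        # %-code found in it (illustrative only).
        for code, value in (("%s", self.symbol), ("%n", self.name), ("%z", str(self.Z))):
            spec = spec.replace(code, value)
        return spec


print("{:%n(%s) Z=%z}".format(Nuclide("Ge", "Germanium", 32)))
# -> Germanium(Ge) Z=32
```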
Also provides string formatting: - >>> iso = Isotope('178M2HF') - >>> '{:%n(%s)-%a%m Z=%z A=%a}'.format(iso) + >>> iso = Isotope("178M2HF") + >>> "{:%n(%s)-%a%m Z=%z A=%a}".format(iso) 'Hafnium(Hf)-178m2 Z=72 A=178' Properties (read-only): @@ -206,8 +208,8 @@ def __init__(self, *args): elif len(args) == 2 or len(args) == 3: try: super().__init__(args[0]) - except element.ElementError: - raise IsotopeError(f"Unable to create Isotope: {args}") + except element.ElementError as exc: + raise IsotopeError(f"Unable to create Isotope: {args}") from exc self._init_A(args[1]) if len(args) == 3: self._init_m(args[2]) @@ -221,8 +223,10 @@ def _init_A(self, arg): """Initialize with an isotope A.""" try: self.A = int(arg) - except ValueError: - raise IsotopeError(f"Mass number cannot be converted to integer: {arg}") + except ValueError as exc: + raise IsotopeError( + f"Mass number cannot be converted to integer: {arg}" + ) from exc if self.A < 1: raise IsotopeError(f"Mass number must be >= 1: {self.A}") @@ -250,18 +254,15 @@ def _init_m(self, arg): if len(self.m) > 1: if not self.m[1:].isdigit(): raise IsotopeError( - "Metastable level must be numeric: {} {}".format( - self.m[0], self.m[1:] - ) + "Metastable level must be numeric: " + f"{self.m[0]} {self.m[1:]}" ) self.M = int(self.m[1:]) else: self.M = 1 else: raise IsotopeError( - "Metastable level must be integer or string: {} {}".format( - arg, type(arg) - ) + f"Metastable level must be integer or string: {arg} {type(arg)}" ) def __str__(self): diff --git a/becquerel/tools/isotope_qty.py b/becquerel/tools/isotope_qty.py index b51685a4..24be1575 100644 --- a/becquerel/tools/isotope_qty.py +++ b/becquerel/tools/isotope_qty.py @@ -164,23 +164,23 @@ def _quantities_from_kwargs(self, **kwargs): # dictionary with functions that define how to calculate all quantities # in a circular manner if self.is_stable: - conversions = dict( - atoms=lambda: ref_quantities["g"] / self.isotope.A * N_AV, - g=lambda: ref_quantities["atoms"] * self.isotope.A / N_AV, - bq=lambda: 0, - uci=lambda: 0, - ) + conversions = { + "atoms": lambda: ref_quantities["g"] / self.isotope.A * N_AV, + "g": lambda: ref_quantities["atoms"] * self.isotope.A / N_AV, + "bq": lambda: 0, + "uci": lambda: 0, + } else: - conversions = dict( - atoms=lambda: ref_quantities["g"] / self.isotope.A * N_AV, - bq=lambda: ref_quantities["atoms"] * self.decay_const, - uci=lambda: ref_quantities["bq"] / UCI_TO_BQ, - g=lambda: ref_quantities["uci"] + conversions = { + "atoms": lambda: ref_quantities["g"] / self.isotope.A * N_AV, + "bq": lambda: ref_quantities["atoms"] * self.decay_const, + "uci": lambda: ref_quantities["bq"] / UCI_TO_BQ, + "g": lambda: ref_quantities["uci"] * UCI_TO_BQ / self.decay_const / N_AV * self.isotope.A, - ) + } # rotates the order of the list so that the provided kwarg is at [0] order = ["atoms", "bq", "uci", "g"] @@ -238,9 +238,7 @@ def from_decays(cls, isotope, n_decays, start_time, stop_time): duration = (stop_time - obj.ref_date).total_seconds() if duration < 0: raise ValueError( - "Start time must precede stop time: {}, {}".format( - start_time, stop_time - ) + f"Start time must precede stop time: {start_time}, {stop_time}" ) atoms = n_decays / (-np.expm1(-obj.decay_const * duration)) @@ -506,9 +504,7 @@ def __str__(self): if self.isotope.is_stable: s = f"{self.g_at(self.ref_date)} g of {self.isotope}" else: - s = "{} Bq of {} (at {})".format( - self.bq_at(self.ref_date), self.isotope, self.ref_date - ) + s = f"{self.bq_at(self.ref_date)} Bq of {self.isotope} (at 
{self.ref_date})" return s def __mul__(self, other): @@ -605,9 +601,7 @@ def __init__(self, start_time, stop_time, n_cm2=None, n_cm2_s=None): ) if self.stop_time < self.start_time: raise ValueError( - "Timestamps out of order: {}, {}".format( - self.start_time, self.stop_time - ) + f"Timestamps out of order: {self.start_time}, {self.stop_time}" ) self.duration = (self.stop_time - self.start_time).total_seconds() @@ -632,9 +626,7 @@ def __str__(self): if self.duration == 0: return f"{self.n_cm2} neutrons/cm2 at {self.start_time}" else: - return "{} n/cm2/s from {} to {}".format( - self.n_cm2_s, self.start_time, self.stop_time - ) + return f"{self.n_cm2_s} n/cm2/s from {self.start_time} to {self.stop_time}" def activate(self, barns, initial, activated, stability=1e18): """ @@ -688,8 +680,8 @@ def activate(self, barns, initial, activated, stability=1e18): activated, IsotopeQuantity ): raise NeutronIrradiationError( - "Two IsotopeQuantity's in args, nothing left to calculate!" - + f"Args: {initial}, {activated}" + "Two IsotopeQuantity's in args, nothing left to calculate! " + f"Args: {initial}, {activated}" ) elif isinstance(initial, IsotopeQuantity) and isinstance(activated, Isotope): forward = True @@ -698,12 +690,12 @@ def activate(self, barns, initial, activated, stability=1e18): elif isinstance(initial, Isotope) and isinstance(activated, Isotope): raise NeutronIrradiationError( "No IsotopeQuantity specified, not enough data. " - + f"Args: {initial}, {activated}" + f"Args: {initial}, {activated}" ) else: raise TypeError( "Input args should be Isotope or IsotopeQuantity objects: " - + f"{initial}, {activated}" + f"{initial}, {activated}" ) if not initial.half_life > stability: diff --git a/becquerel/tools/materials.py b/becquerel/tools/materials.py index bbd45c2c..430755fa 100644 --- a/becquerel/tools/materials.py +++ b/becquerel/tools/materials.py @@ -1,8 +1,8 @@ """Load material data for use in attenuation calculations with XCOM.""" import csv -import os import warnings +from pathlib import Path import numpy as np @@ -10,7 +10,7 @@ from .materials_error import MaterialsError, MaterialsWarning from .materials_nist import fetch_compound_data, fetch_element_data -FILENAME = os.path.join(os.path.split(__file__)[0], "materials.csv") +FILENAME = Path(__file__).parent / "materials.csv" def _load_and_compile_materials(): @@ -28,13 +28,13 @@ def _load_and_compile_materials(): # perform various checks on the Compendium data for j in range(len(data_comp)): - name = data_comp["Material"].values[j] - rho1 = data_comp["Density"].values[j] + name = data_comp["Material"].to_numpy()[j] + rho1 = data_comp["Density"].to_numpy()[j] rho2 = None - if name in data_elem["Element"].values: - rho2 = data_elem["Density"][data_elem["Element"] == name].values[0] - elif name in data_mat["Material"].values: - rho2 = data_mat["Density"][data_mat["Material"] == name].values[0] + if name in data_elem["Element"].to_numpy(): + rho2 = data_elem["Density"][data_elem["Element"] == name].to_numpy()[0] + elif name in data_mat["Material"].to_numpy(): + rho2 = data_mat["Density"][data_mat["Material"] == name].to_numpy()[0] if rho2: if not np.isclose(rho1, rho2, atol=2e-2): raise MaterialsError( @@ -43,12 +43,12 @@ def _load_and_compile_materials(): ) for j in range(len(data_comp)): - name = data_comp["Material"].values[j] - if name in data_mat["Material"].values: - weight_fracs1 = data_comp["Composition_symbol"].values[j] + name = data_comp["Material"].to_numpy()[j] + if name in data_mat["Material"].to_numpy(): + weight_fracs1 = 
data_comp["Composition_symbol"].to_numpy()[j] weight_fracs2 = data_mat["Composition_symbol"][ data_mat["Material"] == name - ].values[0] + ].to_numpy()[0] if len(weight_fracs1) != len(weight_fracs2): raise MaterialsError( f"Material {name} has different number of weight fractions " @@ -73,36 +73,36 @@ def _load_and_compile_materials(): # make a dictionary of all the materials materials = {} for j in range(len(data_elem)): - name = data_elem["Element"].values[j] - formula = data_elem["Symbol"].values[j] - density = data_elem["Density"].values[j] - weight_fracs = data_elem["Composition_symbol"].values[j] + name = data_elem["Element"].to_numpy()[j] + formula = data_elem["Symbol"].to_numpy()[j] + density = data_elem["Density"].to_numpy()[j] + weight_fracs = data_elem["Composition_symbol"].to_numpy()[j] materials[name] = { "formula": formula, "density": density, "weight_fractions": weight_fracs, - "source": '"NIST (http://physics.nist.gov/PhysRefData/XrayMassCoef/tab1.html)"', # noqa: E501 + "source": "NIST (http://physics.nist.gov/PhysRefData/XrayMassCoef/tab1.html)", } # add duplicate entry under element symbol for backwards compatibility materials[formula] = materials[name] for j in range(len(data_mat)): - name = data_mat["Material"].values[j] + name = data_mat["Material"].to_numpy()[j] formula = "-" - density = data_mat["Density"].values[j] - weight_fracs = data_mat["Composition_symbol"].values[j] + density = data_mat["Density"].to_numpy()[j] + weight_fracs = data_mat["Composition_symbol"].to_numpy()[j] materials[name] = { "formula": formula, "density": density, "weight_fractions": weight_fracs, - "source": '"NIST (http://physics.nist.gov/PhysRefData/XrayMassCoef/tab2.html)"', # noqa: E501 + "source": "NIST (http://physics.nist.gov/PhysRefData/XrayMassCoef/tab2.html)", } for j in range(len(data_comp)): - name = data_comp["Material"].values[j] - formula = data_comp["Formula"].values[j] - density = data_comp["Density"].values[j] - weight_fracs = data_comp["Composition_symbol"].values[j] + name = data_comp["Material"].to_numpy()[j] + formula = data_comp["Formula"].to_numpy()[j] + density = data_comp["Density"].to_numpy()[j] + weight_fracs = data_comp["Composition_symbol"].to_numpy()[j] if name in materials: # replace material formula if compendium has one # otherwise do not overwrite the NIST data @@ -113,12 +113,12 @@ def _load_and_compile_materials(): "density": density, "weight_fractions": weight_fracs, "source": ( - '"Detwiler, Rebecca S., McConn, Ronald J., Grimes, ' + "Detwiler, Rebecca S., McConn, Ronald J., Grimes, " "Thomas F., Upton, Scott A., & Engel, Eric J. Compendium of " "Material Composition Data for Radiation Transport Modeling. " "United States. PNNL-15870 Revision 2., " "https://doi.org/10.2172/1782721 " - '(https://compendium.cwmd.pnnl.gov)"' + "(https://compendium.cwmd.pnnl.gov)" ), } @@ -133,13 +133,13 @@ def _write_materials_csv(materials): materials : dict Dictionary of materials. """ - if os.path.exists(FILENAME): + if FILENAME.exists(): warnings.warn( f"Materials data CSV already exists at {FILENAME} and will be overwritten", MaterialsWarning, ) mat_list = sorted(materials.keys()) - with open(FILENAME, "w") as f: + with FILENAME.open("w") as f: print("%name,formula,density,weight fractions,source", file=f) for name in mat_list: line = "" @@ -158,10 +158,10 @@ def _read_materials_csv(): materials Dictionary keyed by material names containing the material data. 
""" - if not os.path.exists(FILENAME): + if not FILENAME.exists(): raise MaterialsError(f"Materials data CSV does not exist at {FILENAME}") materials = {} - with open(FILENAME) as f: + with FILENAME.open() as f: lines = f.readlines() for tokens in csv.reader( lines, @@ -218,7 +218,7 @@ def fetch_materials(force=False): materials Dictionary keyed by material names containing the material data. """ - if force or not os.path.exists(FILENAME): + if force or not FILENAME.exists(): materials = force_load_and_write_materials_csv() materials = _read_materials_csv() return materials @@ -226,5 +226,5 @@ def fetch_materials(force=False): def remove_materials_csv(): """Remove materials.csv if it exists.""" - if os.path.exists(FILENAME): - os.remove(FILENAME) + if FILENAME.exists(): + FILENAME.unlink() diff --git a/becquerel/tools/materials_compendium.py b/becquerel/tools/materials_compendium.py index 22e6e51b..e0b337f5 100644 --- a/becquerel/tools/materials_compendium.py +++ b/becquerel/tools/materials_compendium.py @@ -14,15 +14,15 @@ """ import json -import os import warnings +from pathlib import Path import numpy as np import pandas as pd from .materials_error import MaterialsError, MaterialsWarning -FNAME = os.path.join(os.path.split(__file__)[0], "MaterialsCompendium.json") +FNAME = Path(__file__).parent / "MaterialsCompendium.json" def json_elements_to_weight_fractions(elements): @@ -47,7 +47,7 @@ def json_elements_to_atom_fractions(elements): def fetch_compendium_data(): """Read material data from the Compendium.""" # read the file - if not os.path.exists(FNAME): + if not FNAME.exists(): warnings.warn( 'Material data from the "Compendium of Material Composition Data for ' 'Radiation Transport Modeling" cannot be found. If these data are ' @@ -58,7 +58,7 @@ def fetch_compendium_data(): ) data = [] else: - with open(FNAME) as f: + with FNAME.open() as f: data = json.load(f) # extract relevant data diff --git a/becquerel/tools/materials_nist.py b/becquerel/tools/materials_nist.py index 21d9ecbb..89fe3a10 100644 --- a/becquerel/tools/materials_nist.py +++ b/becquerel/tools/materials_nist.py @@ -41,9 +41,8 @@ def _get_request(url): req = requests.get(url, timeout=15) if not req.ok or req.reason != "OK" or req.status_code != 200: raise MaterialsError( - "NIST materials request failed: reason={}, status_code={}".format( - req.reason, req.status_code - ) + f"NIST materials request failed: reason={req.reason}, " + f"status_code={req.status_code}" ) return req @@ -94,7 +93,7 @@ def fetch_element_data(): df.columns = ["Z", "Symbol", "Element", "Z_over_A", "I_eV", "Density"] # add composition by Z - df["Composition_Z"] = [[f"{z}: 1.000000"] for z in df["Z"].values] + df["Composition_Z"] = [[f"{z}: 1.000000"] for z in df["Z"].to_numpy()] # add composition by symbol df["Composition_symbol"] = [ convert_composition(comp) for comp in df["Composition_Z"] @@ -126,12 +125,12 @@ def convert_composition(comp): raise MaterialsError(f"Line must be a string type: {line} {type(line)}") try: z, weight = line.split(":") - except ValueError: - raise MaterialsError(f"Unable to split compound line: {line}") + except ValueError as exc: + raise MaterialsError(f"Unable to split compound line: {line}") from exc try: z = int(z) - except ValueError: - raise MaterialsError(f"Unable to convert Z {z} to integer: {line}") + except ValueError as exc: + raise MaterialsError(f"Unable to convert Z {z} to integer: {line}") from exc if z < 1 or z > MAX_Z: raise MaterialsError(f"Z {z} out of range [1, {line}]: {MAX_Z}") 
comp_sym.append(element_symbol(z) + " " + weight.strip()) diff --git a/becquerel/tools/nndc.py b/becquerel/tools/nndc.py index 4867bfd1..d2481589 100644 --- a/becquerel/tools/nndc.py +++ b/becquerel/tools/nndc.py @@ -152,8 +152,8 @@ def _parse_headers(headers): if len(set(headers_new)) != len(headers_new): raise NNDCRequestError( "Duplicate headers after parsing\n" - + f' Original headers: "{headers}"\n' - + f' Parsed headers: "{headers_new}"' + f' Original headers: "{headers}"\n' + f' Parsed headers: "{headers_new}"' ) return headers_new @@ -179,7 +179,7 @@ def _parse_table(text): text = text.split("To save this output")[0] lines = text.split("\n") except Exception as exc: - raise NNDCRequestError(f"Unable to parse text:\n{exc}\n{text}") + raise NNDCRequestError(f"Unable to parse text:\n{exc}\n{text}") from exc table = {} headers = None for line in lines: @@ -196,8 +196,8 @@ def _parse_table(text): if len(tokens) != len(headers): raise NNDCRequestError( "Too few data in table row\n" - + f' Headers: "{headers}"\n' - + f' Row: "{tokens}"' + f' Headers: "{headers}"\n' + f' Row: "{tokens}"' ) for header, token in zip(headers, tokens): table[header].append(token) @@ -208,9 +208,9 @@ def _parse_float_uncertainty(x, dx): """Parse a string and its uncertainty into a float or ufloat. Examples: - >>> _parse_float_uncertainty('257.123', '0.005') + >>> _parse_float_uncertainty("257.123", "0.005") 257.123+/-0.005 - >>> _parse_float_uncertainty('8', '') + >>> _parse_float_uncertainty("8", "") 8.0 Args: @@ -268,8 +268,8 @@ def _parse_float_uncertainty(x, dx): dx = "" try: x2 = float(x) - except ValueError: - raise NNDCRequestError(f'Value cannot be parsed as float: "{x}"') + except ValueError as exc: + raise NNDCRequestError(f'Value cannot be parsed as float: "{x}"') from exc if dx == "": return x2 # handle multiple exponents with some uncertainties, e.g., "7E-4E-5" @@ -281,8 +281,10 @@ def _parse_float_uncertainty(x, dx): factor = 1.0 try: dx2 = float(dx) * factor - except ValueError: - raise NNDCRequestError(f'Uncertainty cannot be parsed as float: "{dx}"') + except ValueError as exc: + raise NNDCRequestError( + f'Uncertainty cannot be parsed as float: "{dx}"' + ) from exc return uncertainties.ufloat(x2, dx2) @@ -303,8 +305,10 @@ def _format_range(x_range): try: x1, x2 = x_range - except (TypeError, ValueError): - raise NNDCInputError(f'Range keyword arg must have two elements: "{x_range}"') + except (TypeError, ValueError) as exc: + raise NNDCInputError( + f'Range keyword arg must have two elements: "{x_range}"' + ) from exc try: if np.isfinite(x1): x1 = f"{x1}" @@ -465,7 +469,10 @@ def update(self, **kwargs): if x in kwargs: self._data["spnuc"] = "zanrange" self._data[x + "min"], self._data[x + "max"] = _format_range( - (kwargs[x], kwargs[x]) + ( + kwargs[x], + kwargs[x], + ) ) # handle *_range, *_any, *_odd, *_even elif x + "_range" in kwargs: @@ -524,7 +531,7 @@ def _add_columns_energy_levels(self): # add string m giving the isomer level name (e.g., '' or 'm' or 'm2') self.df["m"] = [""] * len(self) # loop over each isotope in the dataframe - A_Z = [(a, z) for a, z in zip(self["A"], self["Z"])] + A_Z = list(zip(self["A"], self["Z"])) A_Z = set(A_Z) for a, z in A_Z: isotope = (self["A"] == a) & (self["Z"] == z) @@ -554,18 +561,18 @@ def _add_units_uncertainties(self): self._convert_column( "Energy Level", lambda x: _parse_float_uncertainty(x, "") ) - self.df.rename(columns={"Energy Level": "Energy Level (MeV)"}, inplace=True) + self.df = self.df.rename(columns={"Energy Level": "Energy Level 
(MeV)"}) if "Parent Energy Level" in self.keys(): self._convert_column_uncertainty("Parent Energy Level") - self.df.rename( - columns={"Parent Energy Level": "Energy Level (MeV)"}, inplace=True + self.df = self.df.rename( + columns={"Parent Energy Level": "Energy Level (MeV)"} ) self.df["Energy Level (MeV)"] *= 0.001 if "Mass Excess" in self.keys(): self._convert_column_uncertainty("Mass Excess") - self.df.rename(columns={"Mass Excess": "Mass Excess (MeV)"}, inplace=True) + self.df = self.df.rename(columns={"Mass Excess": "Mass Excess (MeV)"}) self._convert_column("T1/2 (s)", float) @@ -579,14 +586,14 @@ def _add_units_uncertainties(self): if "Radiation Energy" in self.keys(): self._convert_column_uncertainty("Radiation Energy") - self.df.rename( - columns={"Radiation Energy": "Radiation Energy (keV)"}, inplace=True + self.df = self.df.rename( + columns={"Radiation Energy": "Radiation Energy (keV)"} ) if "Endpoint Energy" in self.keys(): self._convert_column_uncertainty("Endpoint Energy") - self.df.rename( - columns={"Endpoint Energy": "Endpoint Energy (keV)"}, inplace=True + self.df = self.df.rename( + columns={"Endpoint Energy": "Endpoint Energy (keV)"} ) if "Radiation Intensity (%)" in self.keys(): @@ -594,7 +601,7 @@ def _add_units_uncertainties(self): if "Dose" in self.keys(): self._convert_column_uncertainty("Dose") - self.df.rename(columns={"Dose": "Dose (MeV / Bq / s)"}, inplace=True) + self.df = self.df.rename(columns={"Dose": "Dose (MeV / Bq / s)"}) def _convert_column(self, col, function): """Convert column from string to another type.""" @@ -634,13 +641,8 @@ def _sort_columns(self): "Radiation Energy (keV)", "Radiation Intensity (%)", ] - new_cols = [] - for col in preferred_order: - if col in self.keys(): - new_cols.append(col) - for col in self.keys(): - if col not in new_cols: - new_cols.append(col) + new_cols = [col for col in preferred_order if col in self.keys()] + new_cols += [col for col in self.keys() if col not in new_cols] self.df = self.df[new_cols] @@ -713,8 +715,8 @@ def update(self, **kwargs): ) warnings.warn( 'query kwarg "decay" may not be working on NNDC, ' - + "and the user is advised to check the " - + '"Decay Mode" column of the resulting DataFrame' + "and the user is advised to check the " + '"Decay Mode" column of the resulting DataFrame' ) self._data["dmed"] = "enabled" self._data["dmn"] = WALLET_DECAY_MODE[kwargs["decay"].lower()] diff --git a/becquerel/tools/xcom.py b/becquerel/tools/xcom.py index eea40cdb..3fa9ea03 100644 --- a/becquerel/tools/xcom.py +++ b/becquerel/tools/xcom.py @@ -216,8 +216,8 @@ def _argument_type(arg): elif isinstance(arg, Iterable): return {"mixture": arg} raise XCOMInputError( - f"Cannot determine if argument {arg}" - + " is a symbol, Z, compound, or mixture" + f"Cannot determine if argument {arg} " + "is a symbol, Z, compound, or mixture" ) @staticmethod @@ -241,25 +241,21 @@ def _check_mixture(formulae): for formula in formulae: try: compound, weight = formula.split() - except AttributeError: + except AttributeError as exc: raise XCOMInputError( - 'Mixture formulae "{}" line "{}" must be a string'.format( - formulae, formula - ) - ) - except ValueError: + f'Mixture formulae "{formulae}" line "{formula}" must be a string' + ) from exc + except ValueError as exc: raise XCOMInputError( - 'Mixture formulae "{}" line "{}" must split into 2'.format( - formulae, formula - ) - ) + f'Mixture formulae "{formulae}" line "{formula}" must split into 2' + ) from exc _XCOMQuery._check_compound(compound) try: float(weight) - except 
(ValueError, TypeError): + except (ValueError, TypeError) as exc: raise XCOMInputError( f'Mixture formulae "{formulae}" has bad weight "{weight}"' - ) + ) from exc def update(self, **kwargs): """Update the search criteria. @@ -389,9 +385,8 @@ def _request(self): self._req = requests.post(self._url + self._method, data=self._data, timeout=15) if not self._req.ok or self._req.reason != "OK" or self._req.status_code != 200: raise XCOMRequestError( - "XCOM Request failed: reason={}, status_code={}".format( - self._req.reason, self._req.status_code - ) + f"XCOM Request failed: reason={self._req.reason}, " + f"status_code={self._req.status_code}" ) if "Error" in self._req.text: raise XCOMRequestError(f"XCOM returned an error:\n{self._req.text}") @@ -407,9 +402,8 @@ def _parse_text(self): self.df = tables[0] if len(self.df.keys()) != 1 + len(COLUMNS_SHORT): raise XCOMRequestError( - "Found {} columns but expected {}".format( - len(self.df.keys()), 1 + len(COLUMNS_SHORT) - ) + f"Found {len(self.df.keys())} columns but expected " + f"{1 + len(COLUMNS_SHORT)}" ) # remove 'edge' column self.df = self.df[self.df.keys()[1:]] diff --git a/examples/autocal.ipynb b/examples/autocal.ipynb index 267d2e49..c7aece2b 100644 --- a/examples/autocal.ipynb +++ b/examples/autocal.ipynb @@ -22,7 +22,7 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", + "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", @@ -217,7 +217,7 @@ ], "source": [ "counts = []\n", - "filename = os.path.join(os.path.dirname(bq.__file__), \"../tests/samples/sim_spec.spe\")\n", + "filename = Path(bq.__file__).parent / \"../tests/samples/sim_spec.spe\"\n", "spec = bq.Spectrum.from_file(filename)\n", "spec = spec.combine_bins(4)\n", "spec.bin_edges_raw *= 4\n", @@ -723,8 +723,9 @@ ], "source": [ "# read raw HPGe data\n", - "filename = os.path.join(\n", - " os.path.dirname(bq.__file__), \"../tests/samples/Mendocino_07-10-13_Acq-10-10-13.Spe\"\n", + "filename = (\n", + " Path(bq.__file__).parent /\n", + " \"../tests/samples/Mendocino_07-10-13_Acq-10-10-13.Spe\"\n", ")\n", "spec = bq.Spectrum.from_file(filename)\n", "plot_spec(spec)" @@ -897,9 +898,7 @@ ], "source": [ "counts = []\n", - "filename = os.path.join(\n", - " os.path.dirname(bq.__file__), \"../tests/samples/nai_detector.spe\"\n", - ")\n", + "filename = Path(bq.__file__).parent / \"../tests/samples/nai_detector.spe\"\n", "spec = bq.Spectrum.from_file(filename)\n", "plot_spec(spec)" ] @@ -1121,7 +1120,7 @@ ], "source": [ "counts = []\n", - "filename = os.path.join(os.path.dirname(bq.__file__), \"../tests/samples/SGM102432.spe\")\n", + "filename = Path(bq.__file__).parent / \"../tests/samples/SGM102432.spe\"\n", "spec = bq.Spectrum.from_file(filename)\n", "plot_spec(spec)" ] diff --git a/examples/misc.ipynb b/examples/misc.ipynb index 12b3e4d4..c91b4655 100644 --- a/examples/misc.ipynb +++ b/examples/misc.ipynb @@ -87,7 +87,7 @@ "xd = bq.xcom.fetch_xcom_data(name, e_range_kev=[50.0, 3000.0])\n", "\n", "# calculate mean free path\n", - "mfp_cm = 1 / (density * xd.total_wo_coh.values)\n", + "mfp_cm = 1 / (density * xd.total_wo_coh.to_numpy())\n", "\n", "plt.figure()\n", "plt.title(\"Mean free path of photons in \" + name)\n", @@ -428,7 +428,7 @@ "display(xd)\n", "\n", "# calculate mean free path\n", - "mfp_cm = 1 / (density * xd.total_wo_coh.values)\n", + "mfp_cm = 1 / (density * xd.total_wo_coh.to_numpy())\n", "mfp_m = mfp_cm / 100.0\n", "\n", "plt.figure()\n", diff --git a/examples/nndc.ipynb b/examples/nndc.ipynb index 
8517b2ee..b79582c8 100644 --- a/examples/nndc.ipynb +++ b/examples/nndc.ipynb @@ -2466,7 +2466,7 @@ " \"Po-210\",\n", "]\n", "gammas = series_radiation(series)\n", - "gammas.sort_values(by=\"Radiation Energy (keV)\", inplace=True)\n", + "gammas = gammas.sort_values(by=\"Radiation Energy (keV)\")\n", "fields = [\n", " \"Z\",\n", " \"Element\",\n", @@ -2730,7 +2730,7 @@ ], "source": [ "alphas = series_radiation(series, rtype=\"Alpha\")\n", - "alphas.sort_values(by=\"Radiation Energy (keV)\", inplace=True)\n", + "alphas = alphas.sort_values(by=\"Radiation Energy (keV)\")\n", "display(alphas[fields])" ] }, @@ -3144,7 +3144,7 @@ "]\n", "branchings = [1, 1, 1, 1, 1, 1, 0.6406, 0.3594]\n", "gammas = series_radiation(series, branchings=branchings)\n", - "gammas.sort_values(by=\"Radiation Intensity (%)\", inplace=True, ascending=False)\n", + "gammas = gammas.sort_values(by=\"Radiation Intensity (%)\", ascending=False)\n", "fields = [\n", " \"Z\",\n", " \"Element\",\n", @@ -3360,7 +3360,7 @@ ], "source": [ "alphas = series_radiation(series, branchings=branchings, rtype=\"Alpha\")\n", - "alphas.sort_values(by=\"Radiation Energy (keV)\", inplace=True)\n", + "alphas = alphas.sort_values(by=\"Radiation Energy (keV)\")\n", "display(alphas[fields])" ] } diff --git a/examples/nndc_chart_of_nuclides.ipynb b/examples/nndc_chart_of_nuclides.ipynb index af9f6676..1787a409 100644 --- a/examples/nndc_chart_of_nuclides.ipynb +++ b/examples/nndc_chart_of_nuclides.ipynb @@ -164,7 +164,7 @@ " if np.isinf(max(hl)):\n", " return \"Stable\"\n", " if len(data) == 1:\n", - " mode = list(data[\"Decay Mode\"])[0]\n", + " mode = next(iter(data[\"Decay Mode\"]))\n", " if mode != \"\":\n", " return mode\n", " else:\n", @@ -174,7 +174,7 @@ " if len(data) == 0:\n", " return \"Unknown\"\n", " if len(data) == 1:\n", - " mode = list(data[\"Decay Mode\"])[0]\n", + " mode = next(iter(data[\"Decay Mode\"]))\n", " if mode != \"\":\n", " return mode\n", " else:\n", @@ -189,7 +189,7 @@ " else:\n", " return \"Unknown\"\n", " if len(data2) == 1:\n", - " mode = list(data2[\"Decay Mode\"])[0]\n", + " mode = next(iter(data2[\"Decay Mode\"]))\n", " if mode != \"\":\n", " return mode\n", " else:\n", @@ -199,7 +199,7 @@ " if len(data3) == 0:\n", " return \"Unknown\"\n", " if len(data3) == 1:\n", - " mode = list(data3[\"Decay Mode\"])[0]\n", + " mode = next(iter(data3[\"Decay Mode\"]))\n", " if mode != \"\":\n", " return mode\n", " else:\n", diff --git a/examples/rebinning.ipynb b/examples/rebinning.ipynb index c879ff53..747ff709 100644 --- a/examples/rebinning.ipynb +++ b/examples/rebinning.ipynb @@ -11,9 +11,8 @@ "import xarray as xr\n", "\n", "import becquerel as bq\n", - "from becquerel import Spectrum\n", + "from becquerel import Spectrum, rebin\n", "from becquerel import SpectrumPlotter as sp\n", - "from becquerel import rebin\n", "\n", "%matplotlib inline\n", "np.random.seed(0)\n", diff --git a/examples/spectrum.ipynb b/examples/spectrum.ipynb index 48e8a12c..7544d4cb 100644 --- a/examples/spectrum.ipynb +++ b/examples/spectrum.ipynb @@ -406,21 +406,18 @@ ], "source": [ "print(\n", - " \"Pottery spectrum: {:7.0f} counts in {:6.0f} s livetime\".format(\n", - " np.sum(spec.counts_vals), spec.livetime\n", - " )\n", + " f\"Pottery spectrum: {np.sum(spec.counts_vals):7.0f} counts in \"\n", + " f\"{spec.livetime:6.0f} s livetime\"\n", ")\n", "print(\n", - " \"Background spectrum: {:7.0f} counts in {:6.0f} s livetime\".format(\n", - " np.sum(bg.counts_vals), bg.livetime\n", - " )\n", + " f\"Background spectrum: {np.sum(bg.counts_vals):7.0f} 
counts in \"\n", + " f\"{bg.livetime:6.0f} s livetime\"\n", ")\n", "\n", "spec_and_bg = spec + bg\n", "print(\n", - " \"Combined spectrum: {:7.0f} counts in {:6.0f} s livetime\".format(\n", - " np.sum(spec_and_bg.counts_vals), spec_and_bg.livetime\n", - " )\n", + " f\"Combined spectrum: {np.sum(spec_and_bg.counts_vals):7.0f} counts in \"\n", + " f\"{spec_and_bg.livetime:6.0f} s livetime\"\n", ")" ] }, diff --git a/pyproject.toml b/pyproject.toml index 36c9f0c0..17a0c2c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,117 @@ -[tool.black] -target-version = ['py38', 'py39', 'py310', 'py311'] -include = '\.pyi?$' -exclude = ''' -/( - \.eggs - | \.git - | \.pytest_cache - | htmlcov - | figures - | build - | dist - # Specific to this package - | becquerel.egg-info -)/ -''' + +# all of ruff's settings +# https://docs.astral.sh/ruff/settings/ +# all of ruff's rules: +# https://docs.astral.sh/ruff/rules/ + +[tool.ruff] +namespace-packages = ["becquerel"] +target-version = "py38" +fix = true +show-fixes = true +preview = false + +# format the same as black +indent-width = 4 +line-length = 88 + +[tool.ruff.format] +indent-style = "space" +quote-style = "double" +line-ending = "auto" +skip-magic-trailing-comma = false +docstring-code-format = true +docstring-code-line-length = "dynamic" + +[tool.ruff.lint] +select = [ + "F", # pyflakes + "E", # pycodestyle errors + "W", # pycodestyle warnings + # "C", # mccabe + "I", # isort + # "N", # pep8-naming + # "D", # pydocstyle + "UP", # pyupgrade + "YTT", # flake8-2020 + # "ANN", # flake8-annotations + "ASYNC", # flake8-async + "S", # flake8-bandit + "BLE", # flake8-blind-except + # "FBT", # flake8-boolean-trap + "B", # flake8-bugbear + "A", # flake8-builtins + # "COM", # flake8-commas + "C4", # flake8-comprehensions + # "DTZ", # flake8-datetimez + "T10", # flake8-debugger + # "EM", # flake8-errmsg + "EXE", # flake8-executable + "FA", # flake8-future-annotations + "ISC", # flake8-implicit-str-concat + "ICN", # flake8-import-conventions + "PIE", # flake8-pie + # "T20", # flake8-print + "PYI", # flake8-pyi + # "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RSE", # flake8-raise + # "RET", # flake8-return + # "SLF", # flake8-self + # "SIM", # flake8-simplify + # "TID", # flake8-tidy-imports + "TCH", # flake8-type-checking + "INT", # flake8-gettext + # "ARG", # flake8-unused-arguments + "PTH", # flake8-use-pathlib + # "TD", # flake8-todos + # "FIX", # flake8-fixme + # "ERA", # eradicate + "PD", # pandas-vet + "PGH", # pygrep-hooks + # "PL", # pylint + "TRY", # tryceratops + "FLY", # flynt + "NPY", # NumPy-specific rules + "PERF", # Perflint + # "FURB", # refurb + "RUF", # Ruff-specific rules +] + +ignore = [ + "S101", # Use of `assert` detected + "B015", # Pointless comparison. Did you mean to assign a value? Otherwise, prepend `assert` or remove it. + "B018", # Found useless expression. Either assign it to a variable or remove it. 
+ "B028", # No explicit `stacklevel` keyword argument found + "PD901", # Avoid using the generic variable name `df` for DataFrames + "TRY003", # Avoid specifying long messages outside the exception class + "NPY002", # Replace legacy `np.random.poisson` call with `np.random.Generator` + "PERF203", # `try`-`except` within a loop incurs performance overhead + "FURB101", # `open` and `read` should be replaced by `Path("CONTRIBUTING.md").read_text()` + "FURB113", # Use `s.extend(...)` instead of repeatedly calling `s.append()` + "RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` +] + +exclude = [] + +unfixable = [] + +[tool.ruff.lint.flake8-quotes] +docstring-quotes = "double" + +[tool.ruff.lint.isort] +lines-after-imports = -1 +known-first-party = ["becquerel"] + +[tool.ruff.lint.pycodestyle] +max-doc-length = 88 + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.pylint] +max-args = 15 +max-branches = 15 +max-locals = 25 +max-nested-blocks = 15 +max-statements = 100 diff --git a/requirements-dev.txt b/requirements-dev.txt index 348d9e68..bc1f8267 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,11 +1,10 @@ -r requirements.txt bump2version coverage -flake8>=4.0.1 pre-commit pytest -pytest-black pytest-cov pytest-rerunfailures pytest-xdist requests +ruff diff --git a/setup.cfg b/setup.cfg index bd6a494f..6901ccf9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ test = pytest [tool:pytest] -addopts = --black --cov=becquerel --cov-report term --cov-report html:htmlcov -m "not plottest" +addopts = --cov=becquerel --cov-report term --cov-report html:htmlcov -m "not plottest" markers = webtest: test requires internet connection plottest: test will produce plot figures @@ -12,7 +12,3 @@ filterwarnings = [coverage:run] relative_files = True - -[flake8] -max-line-length = 88 -exclude = .eggs diff --git a/setup.py b/setup.py index d33aa151..7289fdea 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ import importlib.util import site import sys +from pathlib import Path from setuptools import find_packages, setup @@ -21,19 +22,17 @@ site.ENABLE_USER_SITE = "--user" in sys.argv[1:] # remove package title from description -with open("README.md") as fh: +with Path("README.md").open() as fh: README = "\n".join(fh.readlines()[2:]) -with open("CONTRIBUTING.md") as fh: +with Path("CONTRIBUTING.md").open() as fh: CONTRIBUTING = fh.read() -with open("requirements.txt") as fh: - REQUIREMENTS = [_line for _line in fh.readlines() if _line] +with Path("requirements.txt").open() as fh: + REQUIREMENTS = [_line for _line in fh if _line] -with open("requirements-dev.txt") as fh: - REQUIREMENTS_DEV = [ - line.strip() for line in fh.readlines() if not line.startswith("-r") - ] +with Path("requirements-dev.txt").open() as fh: + REQUIREMENTS_DEV = [line.strip() for line in fh if not line.startswith("-r")] # make long description from README and CONTRIBUTING # but move copyright notice to the end @@ -42,7 +41,7 @@ ) setup( - name=METADATA.__name__, + name="becquerel", version=METADATA.__version__, description=METADATA.__description__, long_description=LONG_DESCRIPTION, diff --git a/tests/autocal_test.py b/tests/autocal_test.py index 63f145c3..d31926d8 100644 --- a/tests/autocal_test.py +++ b/tests/autocal_test.py @@ -1,19 +1,17 @@ """Test PeakFinder and AutoCalibrator classes.""" -import os - import matplotlib.pyplot as plt import numpy as np import pytest +from parsers_test import SAMPLES_PATH import becquerel as bq # read in spectra -SAMPLES_PATH = 
os.path.join(os.path.dirname(__file__), "samples") -filename1 = os.path.join(SAMPLES_PATH, "sim_spec.spe") -filename2 = os.path.join(SAMPLES_PATH, "Mendocino_07-10-13_Acq-10-10-13.Spe") -filename3 = os.path.join(SAMPLES_PATH, "nai_detector.spe") -filename4 = os.path.join(SAMPLES_PATH, "SGM102432.spe") +filename1 = SAMPLES_PATH / "sim_spec.spe" +filename2 = SAMPLES_PATH / "Mendocino_07-10-13_Acq-10-10-13.Spe" +filename3 = SAMPLES_PATH / "nai_detector.spe" +filename4 = SAMPLES_PATH / "SGM102432.spe" spec1 = bq.Spectrum.from_file(filename1) spec2 = bq.Spectrum.from_file(filename2) diff --git a/tests/calibration_test.py b/tests/calibration_test.py index cae6bbd7..110493ba 100644 --- a/tests/calibration_test.py +++ b/tests/calibration_test.py @@ -1,7 +1,5 @@ """Test Calibration class.""" -import os - import matplotlib.pyplot as plt import numpy as np import pytest @@ -312,7 +310,7 @@ def make_calibration(name, args): @pytest.mark.parametrize("name, args", name_args) def test_calibration(name, args): """Test the Calibration class.""" - fname = os.path.join(TEST_OUTPUTS, f"calibration__init__{name}.h5") + fname = TEST_OUTPUTS / f"calibration__init__{name}.h5" # test __init__() cal = make_calibration(name, args) # test protections on setting parameters @@ -338,7 +336,7 @@ def test_calibration(name, args): @pytest.mark.parametrize("name, args", name_args) def test_calibration_set_add_points(name, args): """Test Calibration.set_points and add_points methods.""" - fname = os.path.join(TEST_OUTPUTS, f"calibration__add_points__{name}.h5") + fname = TEST_OUTPUTS / f"calibration__add_points__{name}.h5" cal = make_calibration(name, args) # test set_points cal.set_points() @@ -448,7 +446,7 @@ def test_calibration_fit_from_points(name, args): plt.xlim(0) plt.ylim(0) plt.legend() - plt.savefig(os.path.join(TEST_OUTPUTS, f"calibration__fit__{name}.png")) + plt.savefig(TEST_OUTPUTS / f"calibration__fit__{name}.png") # Test statistics assert len(cal1.fit_y) > 0 @@ -503,7 +501,7 @@ def test_calibration_domain_range(): def test_calibration_inverse(): """Test calibrations with and without inverse expression.""" - fname = os.path.join(TEST_OUTPUTS, "calibration__inverse.h5") + fname = TEST_OUTPUTS / "calibration__inverse.h5" # cal1 has an explicit inverse expression, cal2 does not cal1 = Calibration( @@ -620,7 +618,7 @@ def test_calibration_interpolation(): def test_calibration_read_failures(): """Test miscellaneous HDF5 reading failures.""" - fname = os.path.join(TEST_OUTPUTS, "calibration__read_failures.h5") + fname = TEST_OUTPUTS / "calibration__read_failures.h5" cal = Calibration.from_linear([2.0, 3.0]) cal.add_points([0, 1000, 2000], [0, 1000, 2000]) diff --git a/tests/element_test.py b/tests/element_test.py index fb77cde8..31cc5488 100644 --- a/tests/element_test.py +++ b/tests/element_test.py @@ -130,7 +130,7 @@ def test_element(z, sym, name): args.extend([z, str(z)]) print(args) for arg in args: - print("") + print() print("arg: ", arg) elem = element.Element(arg) print(elem) diff --git a/tests/fitting_test.py b/tests/fitting_test.py index 6d2dd3f4..70f5f495 100644 --- a/tests/fitting_test.py +++ b/tests/fitting_test.py @@ -1,22 +1,19 @@ -import glob -import os from copy import deepcopy import lmfit import numpy as np import pytest +from parsers_test import SAMPLES_PATH import becquerel as bq -SAMPLES_PATH = os.path.join(os.path.dirname(__file__), "samples") - # TODO: use these for fitting actual data SAMPLES = {} for extension in [".spe", ".spc", ".cnf"]: - filenames = glob.glob(os.path.join(SAMPLES_PATH, 
"*.*")) + filenames = SAMPLES_PATH.glob("*.*") filenames_filtered = [] for filename in filenames: - fname, ext = os.path.splitext(filename) + ext = filename.suffix if ext.lower() == extension: filenames_filtered.append(filename) SAMPLES[extension] = filenames_filtered diff --git a/tests/h5_tools_test.py b/tests/h5_tools_test.py index ee63143f..a74c68ba 100644 --- a/tests/h5_tools_test.py +++ b/tests/h5_tools_test.py @@ -1,6 +1,6 @@ """Test HDF5 I/O tools.""" -import os +from pathlib import Path import h5py import numpy as np @@ -8,9 +8,8 @@ from becquerel.io.h5 import ensure_string, is_h5_filename, open_h5, read_h5, write_h5 -TEST_OUTPUTS = os.path.join(os.path.split(__file__)[0], "test_outputs") -if not os.path.exists(TEST_OUTPUTS): - os.mkdir(TEST_OUTPUTS) +TEST_OUTPUTS = Path(__file__).parent / "test_outputs" +TEST_OUTPUTS.mkdir(exist_ok=True) DSETS = { "dset_1d": np.ones(100, dtype=int), @@ -56,7 +55,7 @@ def write_test_open_h5_file(fname): def test_open_h5(): """Test open_h5 for different inputs.""" - fname = os.path.join(TEST_OUTPUTS, "io_h5__test_open_h5.h5") + fname = TEST_OUTPUTS / "io_h5__test_open_h5.h5" # filename cases write_test_open_h5_file(fname) @@ -143,7 +142,7 @@ def check_dsets_attrs(dsets1, attrs1, dsets2, attrs2): @pytest.mark.parametrize("attrs", [ATTRS]) def test_write_h5_filename(dsets, attrs): """Write data to h5 given its filename.""" - fname = os.path.join(TEST_OUTPUTS, "io_h5__test_write_h5_filename.h5") + fname = TEST_OUTPUTS / "io_h5__test_write_h5_filename.h5" write_h5(fname, dsets, attrs) @@ -151,7 +150,7 @@ def test_write_h5_filename(dsets, attrs): @pytest.mark.parametrize("attrs", [ATTRS]) def test_write_h5_file(dsets, attrs): """Write data to h5 given an open h5py.File.""" - fname = os.path.join(TEST_OUTPUTS, "io_h5__test_write_h5_file.h5") + fname = TEST_OUTPUTS / "io_h5__test_write_h5_file.h5" with h5py.File(fname, "w") as file: write_h5(file, dsets, attrs) @@ -160,7 +159,7 @@ def test_write_h5_file(dsets, attrs): @pytest.mark.parametrize("attrs", [ATTRS]) def test_write_h5_group(dsets, attrs): """Write data to h5 given an h5py.Group.""" - fname = os.path.join(TEST_OUTPUTS, "io_h5__test_write_h5_group.h5") + fname = TEST_OUTPUTS / "io_h5__test_write_h5_group.h5" with h5py.File(fname, "w") as file: group = file.create_group("test_group") write_h5(group, dsets, attrs) @@ -170,7 +169,7 @@ def test_write_h5_group(dsets, attrs): @pytest.mark.parametrize("attrs", [ATTRS]) def test_read_h5_filename(dsets, attrs): """Read data from h5 given its filename.""" - fname = os.path.join(TEST_OUTPUTS, "io_h5__test_write_h5_filename.h5") + fname = TEST_OUTPUTS / "io_h5__test_write_h5_filename.h5" dsets2, attrs2, skipped = read_h5(fname) check_dsets_attrs(dsets, attrs, dsets2, attrs2) assert len(skipped) == 0 @@ -180,7 +179,7 @@ def test_read_h5_filename(dsets, attrs): @pytest.mark.parametrize("attrs", [ATTRS]) def test_read_h5_file(dsets, attrs): """Read data from h5 given an open h5py.File.""" - fname = os.path.join(TEST_OUTPUTS, "io_h5__test_write_h5_file.h5") + fname = TEST_OUTPUTS / "io_h5__test_write_h5_file.h5" with h5py.File(fname, "r") as file: dsets2, attrs2, skipped = read_h5(file) check_dsets_attrs(dsets, attrs, dsets2, attrs2) @@ -191,7 +190,7 @@ def test_read_h5_file(dsets, attrs): @pytest.mark.parametrize("attrs", [ATTRS]) def test_read_h5_group(dsets, attrs): """Read data from h5 given an h5py.Group.""" - fname = os.path.join(TEST_OUTPUTS, "io_h5__test_write_h5_group.h5") + fname = TEST_OUTPUTS / "io_h5__test_write_h5_group.h5" with h5py.File(fname, "r") 
as file: group = file["test_group"] dsets2, attrs2, skipped = read_h5(group) @@ -203,7 +202,7 @@ def test_read_h5_group(dsets, attrs): @pytest.mark.parametrize("attrs", [ATTRS]) def test_read_h5_ignore_group(dsets, attrs): """Read data from h5 and ignore a data group within the group.""" - fname = os.path.join(TEST_OUTPUTS, "io_h5__write_h5_ignore_group.h5") + fname = TEST_OUTPUTS / "io_h5__write_h5_ignore_group.h5" # write the file with an extra group with h5py.File(fname, "w") as file: diff --git a/tests/isotope_qty_test.py b/tests/isotope_qty_test.py index 9b2b8d77..17527802 100644 --- a/tests/isotope_qty_test.py +++ b/tests/isotope_qty_test.py @@ -347,7 +347,7 @@ def test_isotopequantity_eq(iq): assert iq == iq3 -@pytest.mark.parametrize("f", [2, 0.5, 3.14]) +@pytest.mark.parametrize("f", [2, 0.5, 3.7]) def test_isotopequantity_mul_div(iq, f): """Test IsotopeQuantity multiplication and division magic methods""" diff --git a/tests/isotope_test.py b/tests/isotope_test.py index 7c2f0a24..b240e5a4 100644 --- a/tests/isotope_test.py +++ b/tests/isotope_test.py @@ -57,7 +57,7 @@ def test_isotope_init_args(iso_str, sym, A, m): if isomer == "" or isomer is None: args_list.append((elem, mass)) for args in args_list: - print("") + print() print(args) i = isotope.Isotope(*args) print(i) @@ -159,7 +159,7 @@ def test_isotope_init_str(iso_str, sym, A, m): ] for iso in iso_tests: for iso2 in [iso, iso.upper(), iso.lower()]: - print("") + print() print(f"{sym}-{mass_number}: {iso2}") i = isotope.Isotope(iso2) print(i) diff --git a/tests/materials_test.py b/tests/materials_test.py index 2484c4c2..e595b14b 100644 --- a/tests/materials_test.py +++ b/tests/materials_test.py @@ -1,8 +1,8 @@ """Test NIST material data queries.""" import json -import os import warnings +from pathlib import Path import pytest from utils import xcom_is_up @@ -64,17 +64,17 @@ def test_z_out_of_range(self): def test_materials(): """Test fetch_materials.""" fetch_materials() - assert os.path.exists(materials.FILENAME) + assert materials.FILENAME.exists() @pytest.mark.webtest @pytest.mark.skipif(not xcom_is_up(), reason="XCOM is down.") def test_materials_force(): """Test fetch_materials with force=True.""" - assert os.path.exists(materials.FILENAME) + assert materials.FILENAME.exists() with pytest.warns(MaterialsWarning) as record: fetch_materials(force=True) - if not os.path.exists(materials_compendium.FNAME): + if not materials_compendium.FNAME.exists(): assert len(record) == 2, ( "Expected two MaterialsWarnings to be raised; " f"got {_get_warning_messages(record)}" @@ -84,22 +84,22 @@ def test_materials_force(): "Expected one MaterialsWarning to be raised; " f"got {_get_warning_messages(record)}" ) - assert os.path.exists(materials.FILENAME) + assert materials.FILENAME.exists() def test_materials_dummy_csv(): """Test fetch_materials with a dummy materials.csv file.""" # point to and generate a dummy CSV file fname_orig = materials.FILENAME - materials.FILENAME = fname_orig[:-4] + "_dummy.csv" - if os.path.exists(materials.FILENAME): - os.remove(materials.FILENAME) - with open(materials.FILENAME, "w") as f: + materials.FILENAME = Path(str(fname_orig)[:-4] + "_dummy.csv") + if materials.FILENAME.exists(): + materials.FILENAME.unlink() + with materials.FILENAME.open("w") as f: print("%name,formula,density,weight fractions,source", file=f) print('Dummy,-,1.0,"H 0.5;O 0.5","dummy entry"', file=f) fetch_materials() # remove the dummy file and point back to original - os.remove(materials.FILENAME) + materials.FILENAME.unlink() 
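The test-suite hunks in this region replace `os`/`os.path` calls with `pathlib.Path` methods, matching the `PTH` (flake8-use-pathlib) rules enabled in `pyproject.toml` above. A standalone sketch of the main equivalences; the file names here are illustrative only:

```python
from pathlib import Path

# os.path.join(os.path.split(__file__)[0], "test_outputs")
TEST_OUTPUTS = Path(__file__).parent / "test_outputs"
# if not os.path.exists(...): os.mkdir(...)
TEST_OUTPUTS.mkdir(exist_ok=True)

fname = TEST_OUTPUTS / "example.csv"
with fname.open("w") as f:       # open(fname, "w")
    print("x,y", file=f)

assert fname.exists()            # os.path.exists(fname)
assert fname.suffix == ".csv"    # os.path.splitext(fname)[1]
assert fname.stem == "example"   # os.path.splitext(os.path.basename(fname))[0]
fname.unlink()                   # os.remove(fname)
```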
materials.FILENAME = fname_orig @@ -112,7 +112,7 @@ def test_materials_dummy_compendium_pre2022(): """ # point to an generate a dummy JSON file fname_orig = materials_compendium.FNAME - materials_compendium.FNAME = fname_orig[:-5] + "_dummy.json" + materials_compendium.FNAME = Path(str(fname_orig)[:-5] + "_dummy.json") data = [ { "Density": 8.4e-5, @@ -139,14 +139,14 @@ def test_materials_dummy_compendium_pre2022(): "Name": "Nitrogen", }, ] - with open(materials_compendium.FNAME, "w") as f: + with materials_compendium.FNAME.open("w") as f: json.dump(data, f, indent=4) # Check that no warning is raised with warnings.catch_warnings(): warnings.simplefilter("error") materials._load_and_compile_materials() # remove the dummy file and point back to original - os.remove(materials_compendium.FNAME) + materials_compendium.FNAME.unlink() materials_compendium.FNAME = fname_orig @@ -159,7 +159,7 @@ def test_materials_dummy_compendium_2022(): """ # point to an generate a dummy JSON file fname_orig = materials_compendium.FNAME - materials_compendium.FNAME = fname_orig[:-5] + "_dummy.json" + materials_compendium.FNAME = Path(str(fname_orig)[:-5] + "_dummy.json") data = { "siteVersion": "0.0.0", "data": [ @@ -189,7 +189,7 @@ def test_materials_dummy_compendium_2022(): }, ], } - with open(materials_compendium.FNAME, "w") as f: + with materials_compendium.FNAME.open("w") as f: json.dump(data, f, indent=4) # Check that no warning is raised with warnings.catch_warnings(): @@ -197,12 +197,12 @@ def test_materials_dummy_compendium_2022(): materials._load_and_compile_materials() # remove siteVersion and make sure there is an error raised del data["siteVersion"] - with open(materials_compendium.FNAME, "w") as f: + with materials_compendium.FNAME.open("w") as f: json.dump(data, f, indent=4) with pytest.raises(MaterialsError): materials._load_and_compile_materials() # remove the dummy file and point back to original - os.remove(materials_compendium.FNAME) + materials_compendium.FNAME.unlink() materials_compendium.FNAME = fname_orig @@ -215,14 +215,14 @@ def test_materials_dummy_compendium_error(): """ # point to an generate a dummy JSON file fname_orig = materials_compendium.FNAME - materials_compendium.FNAME = fname_orig[:-5] + "_dummy.json" + materials_compendium.FNAME = Path(str(fname_orig)[:-5] + "_dummy.json") data = None - with open(materials_compendium.FNAME, "w") as f: + with materials_compendium.FNAME.open("w") as f: json.dump(data, f, indent=4) with pytest.raises(MaterialsError): materials._load_and_compile_materials() # remove the dummy file and point back to original - os.remove(materials_compendium.FNAME) + materials_compendium.FNAME.unlink() materials_compendium.FNAME = fname_orig @@ -232,9 +232,9 @@ def test_materials_no_compendium(): """Test fetch_materials with no Compendium JSON file.""" # point to a dummy JSON file that does not exist fname_orig = materials_compendium.FNAME - materials_compendium.FNAME = fname_orig[:-5] + "_dummy.json" - if os.path.exists(materials_compendium.FNAME): - os.remove(materials_compendium.FNAME) + materials_compendium.FNAME = Path(str(fname_orig)[:-5] + "_dummy.json") + if materials_compendium.FNAME.exists(): + materials_compendium.FNAME.unlink() with pytest.warns(MaterialsWarning) as record: materials_compendium.fetch_compendium_data() assert len(record) == 1, ( @@ -249,15 +249,15 @@ def test_remove_materials_csv(): """Test remove_materials_csv.""" # point to and generate a dummy CSV file fname_orig = materials.FILENAME - materials.FILENAME = fname_orig[:-4] + 
"_dummy.csv" - if os.path.exists(materials.FILENAME): - os.remove(materials.FILENAME) - with open(materials.FILENAME, "w") as f: - print("", file=f) + materials.FILENAME = Path(str(fname_orig)[:-4] + "_dummy.csv") + if materials.FILENAME.exists(): + materials.FILENAME.unlink() + with materials.FILENAME.open("w") as f: + print(file=f) remove_materials_csv() - assert not os.path.exists(materials.FILENAME) + assert not materials.FILENAME.exists() # make sure remove works if the file does not exist remove_materials_csv() - assert not os.path.exists(materials.FILENAME) + assert not materials.FILENAME.exists() # point back to original file materials.FILENAME = fname_orig diff --git a/tests/nndc_test.py b/tests/nndc_test.py index 329facbd..80e4b6a0 100644 --- a/tests/nndc_test.py +++ b/tests/nndc_test.py @@ -633,9 +633,9 @@ def test_wallet_j_10(self): @pytest.mark.skip( reason='query kwarg "decay" seems to not be working ' - + "on NNDC, and as a result too many results " - + "are returned for this test, causing an " - + "NNDCRequestError" + "on NNDC, and as a result too many results " + "are returned for this test, causing an " + "NNDCRequestError" ) def test_wallet_decay_SF(self): """Test fetch_wallet_card: decay='SF'..............................""" diff --git a/tests/parsers_test.py b/tests/parsers_test.py index 4f82d519..543f93c9 100644 --- a/tests/parsers_test.py +++ b/tests/parsers_test.py @@ -1,7 +1,6 @@ """Test becquerel spectrum file parsers.""" -import glob -import os +from pathlib import Path import matplotlib.pyplot as plt import numpy as np @@ -9,13 +8,13 @@ import becquerel as bq -SAMPLES_PATH = os.path.join(os.path.dirname(__file__), "samples") +SAMPLES_PATH = Path(__file__).parent / "samples" SAMPLES = {} for extension in [".spe", ".spc", ".cnf", ".h5", ".iec"]: - filenames = glob.glob(os.path.join(SAMPLES_PATH + "*", "*.*")) + filenames = SAMPLES_PATH.glob("*.*") filenames_filtered = [] for filename in filenames: - fname, ext = os.path.splitext(filename) + fname, ext = filename.with_suffix(""), filename.suffix if ext.lower() == extension: filenames_filtered.append(filename) SAMPLES[extension] = filenames_filtered @@ -29,9 +28,7 @@ def run_parser(self, read_fn, extension): filenames = SAMPLES.get(extension, []) assert len(filenames) >= 1 for filename in filenames: - fname, ext = os.path.splitext(filename) - path, fname = os.path.split(fname) - print("") + print() print(filename) data, cal = read_fn(filename) print(data, cal) @@ -74,16 +71,15 @@ def run_parser(self, read_fn, extension, write=False): """Run the test for the given class and file extension.""" try: plt.figure() - except Exception: + except Exception: # noqa: BLE001 (blind exception) # TclError on CI bc no display. 
skip the test return plt.title(f"Testing {extension}") filenames = SAMPLES.get(extension, []) assert len(filenames) >= 1 for filename in filenames: - fname, ext = os.path.splitext(filename) - path, fname = os.path.split(fname) - print("") + fname, ext = filename.name, filename.suffix + print() print(filename) data, cal = read_fn(filename) spec = bq.Spectrum(**data) @@ -99,9 +95,9 @@ def run_parser(self, read_fn, extension, write=False): plt.ylabel("Counts/keV/sec") plt.xlim(0, 2800) if write: - writename = os.path.join(".", fname + "_copy" + ext) + writename = Path.cwd() / (fname + "_copy" + ext) spec.write(writename) - os.remove(writename) + writename.unlink() plt.legend(prop={"size": 8}) plt.show() diff --git a/tests/plotting_test.py b/tests/plotting_test.py index 23328d56..068c1ba4 100644 --- a/tests/plotting_test.py +++ b/tests/plotting_test.py @@ -1,6 +1,6 @@ """Test core.plotting""" -import matplotlib +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import pytest @@ -514,12 +514,12 @@ def test_errornone(uncal_spec): polys = 0 lines = 0 for i in ax.get_children(): - if type(i) is matplotlib.collections.LineCollection: + if type(i) is mpl.collections.LineCollection: colls = colls + 1 - if type(i) is matplotlib.collections.PolyCollection: + if type(i) is mpl.collections.PolyCollection: polys = polys + 1 - if type(i) is matplotlib.lines.Line2D: + if type(i) is mpl.lines.Line2D: lines = lines + 1 assert colls == 0 assert polys == 0 @@ -535,9 +535,9 @@ def test_errorbars(uncal_spec): colls = 0 lines = 0 for i in ax.get_children(): - if type(i) is matplotlib.collections.LineCollection: + if type(i) is mpl.collections.LineCollection: colls = colls + 1 - if type(i) is matplotlib.lines.Line2D: + if type(i) is mpl.lines.Line2D: lines = lines + 1 assert colls == 1 assert lines >= 1 @@ -552,9 +552,9 @@ def test_errorband(uncal_spec): colls = 0 lines = 0 for i in ax.get_children(): - if type(i) is matplotlib.collections.PolyCollection: + if type(i) is mpl.collections.PolyCollection: colls = colls + 1 - if type(i) is matplotlib.lines.Line2D: + if type(i) is mpl.lines.Line2D: lines = lines + 1 assert colls == 1 assert lines == 1 diff --git a/tests/spectrum_io_test.py b/tests/spectrum_io_test.py index ccbf7d62..126769d4 100644 --- a/tests/spectrum_io_test.py +++ b/tests/spectrum_io_test.py @@ -1,6 +1,5 @@ """Test Spectrum I/O for different file types.""" -import os from pathlib import Path import numpy as np @@ -44,7 +43,7 @@ def test_spectrum_from_file_raises(): def test_write_h5(kind): """Test writing different Spectrums to HDF5 files.""" spec = make_spec(kind, lt=600.0) - fname = os.path.join(TEST_OUTPUTS, "spectrum_io__test_write_h5__" + kind + ".h5") + fname = TEST_OUTPUTS / ("spectrum_io__test_write_h5__" + kind + ".h5") spec.write(fname) @@ -62,7 +61,7 @@ def test_write_h5(kind): ) def test_from_file_h5(kind): """Test Spectrum.from_file works for HDF5 files.""" - fname = os.path.join(TEST_OUTPUTS, "spectrum_io__test_write_h5__" + kind + ".h5") + fname = TEST_OUTPUTS / ("spectrum_io__test_write_h5__" + kind + ".h5") spec = bq.Spectrum.from_file(fname) assert spec.livetime is not None if kind == "applied_energy_cal": @@ -76,10 +75,8 @@ def test_spectrum_samples_write_read_h5(extension): assert len(filenames) >= 1 for filename in filenames: spec = bq.Spectrum.from_file(filename) - fname2 = os.path.splitext(filename)[0] + ".h5" - fname2 = os.path.join( - TEST_OUTPUTS, "spectrum_io__sample_write_h5__" + os.path.split(fname2)[1] - ) + fname2 = 
Path(filename).with_suffix(".h5") + fname2 = TEST_OUTPUTS / ("spectrum_io__sample_write_h5__" + fname2.name) spec.write(fname2) spec = bq.Spectrum.from_file(fname2) assert spec.livetime is not None @@ -87,9 +84,7 @@ def test_spectrum_samples_write_read_h5(extension): def test_from_file_cal_kwargs(): """Test Spectrum.from_file overrides calibration with cal_kwargs.""" - fname = os.path.join( - TEST_OUTPUTS, "spectrum_io__test_write_h5__applied_energy_cal.h5" - ) + fname = TEST_OUTPUTS / "spectrum_io__test_write_h5__applied_energy_cal.h5" domain = [-100, 10000] rng = [-10, 1000] params = [0.6] diff --git a/tests/spectrum_test.py b/tests/spectrum_test.py index 2a84493c..f5d3e90d 100644 --- a/tests/spectrum_test.py +++ b/tests/spectrum_test.py @@ -1209,12 +1209,12 @@ def include_overflows(request): def test_spectrum_rebin_success( rebin_spectrum_success, rebin_new_edges, rebin_method, include_overflows ): - kwargs = dict( - out_edges=rebin_new_edges, - method=rebin_method, - zero_pad_warnings=False, - include_overflows=include_overflows, - ) + kwargs = { + "out_edges": rebin_new_edges, + "method": rebin_method, + "zero_pad_warnings": False, + "include_overflows": include_overflows, + } if (rebin_spectrum_success._counts is None) and (rebin_method == "listmode"): with pytest.warns(bq.SpectrumWarning): spec = rebin_spectrum_success.rebin(**kwargs)
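Finally, two of the smaller idiom changes scattered through these files: the `dict(...)` → `{...}` rewrite in `spectrum_test.py` just above (flake8-comprehensions C408) and the `list(x)[0]` → `next(iter(x))` rewrites in the nuclide-chart notebook earlier (RUF015). A side-by-side sketch with made-up data:

```python
# C408: a dict literal is clearer and cheaper than the dict() builtin
# when all keys are plain strings.
kwargs = {"method": "listmode", "include_overflows": True}     # preferred
kwargs_call = dict(method="listmode", include_overflows=True)  # flagged
assert kwargs == kwargs_call

# RUF015: fetch the first element without materializing a whole list.
decay_modes = ["EC", "B-"]
first = next(iter(decay_modes))      # preferred
first_slice = list(decay_modes)[0]   # flagged
assert first == first_slice == "EC"
```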