diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b8bbfa..338203e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,45 +16,19 @@ repos: - id: mixed-line-ending args: ['--fix=no'] -- repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - additional_dependencies: - - flake8-comprehensions - - flake8-logging-format - - flake8-builtins - - flake8-eradicate - - pep8-naming - - flake8-pytest - - flake8-docstrings - - flake8-rst-docstrings - - flake8-rst - - flake8-copyright -# - flake8-ownership - - flake8-markdown - - flake8-bugbear - - flake8-comprehensions - - flake8-print - - -- repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.4.2 - hooks: - - id: black - -- repo: https://github.com/PyCQA/isort - rev: 5.13.2 - hooks: - - id: isort - repo: https://github.com/pre-commit/pygrep-hooks rev: v1.10.0 hooks: - id: rst-backticks -- repo: https://github.com/asottile/pyupgrade - rev: v3.16.0 +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.4.10 hooks: - - id: pyupgrade - args: [--py38-plus] + # Run the linter. + - id: ruff + args: [--fix] + + # Run the formatter. + - id: ruff-format diff --git a/docs/conf.py b/docs/conf.py index 6b22837..fd3b02e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,8 +1,8 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals +"""Configuration for docs.""" import os -from datetime import datetime +from datetime import datetime, timezone + from py21cmsense import __version__ extensions = [ @@ -35,9 +35,9 @@ source_suffix = ".rst" master_doc = "index" project = "21cmSense" -year = str(datetime.now().year) +year = str(datetime.now(tz=timezone.utc).year) author = "Jonathan Pober and Steven Murray" -copyright = "{0}, {1}".format(year, author) +copyright = f"{year}, {author}" version = release = __version__ templates_path = ["templates"] @@ -68,9 +68,7 @@ napoleon_use_rtype = False napoleon_use_param = False -mathjax_path = ( - "https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js" -) +mathjax_path = "https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js" exclude_patterns = [ "_build", diff --git a/pyproject.toml b/pyproject.toml index 05be82d..33a5eaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dynamic = ['version'] # Add here dependencies of your project (semicolon/line-separated), e.g. dependencies = [ - "numpy", + "numpy<2.0", # Restriction can be lifted once pyuvdata is good with numpy 2 "scipy", "future", "click", @@ -64,6 +64,7 @@ dev = [ "21cmSense[docs,test]", "pre-commit", "commitizen", + "ruff" ] [project.scripts] @@ -72,6 +73,61 @@ sense = "py21cmsense.cli:main" [tool.setuptools_scm] +[tool.ruff] +line-length = 100 +target-version = "py39" + +[tool.ruff.lint] +extend-select = [ + "UP", # pyupgrade + "E", # pycodestyle + "W", # pycodestyle warning + "C90", # mccabe complexity + "I", # isort + "N", # pep8-naming + "D", # docstyle + "B", # bugbear + "A", # builtins + "C4", # comprehensions + "DTZ", # datetime + "FA", # future annotations + "PIE", # flake8-pie + "T", # print statements + "PT", # pytest-style + "Q", # quotes + "SIM", # simplify + # "PTH", # use Pathlib + "ERA", # kill commented code + "NPY", # numpy-specific rules + "PERF", # performance + "RUF", # ruff-specific rules +] +ignore = [ + "DTZ007", # use %z in strptime + "A003", # class attribute shadows python builtin + "B008", # function call in argument defaults + "N802", # TODO: remove this (function name should be lower-case) + "B019", # no using lru_cache on methods + "G004", # logging uses f-string + "D107", # no docstring in __init__ +] +[tool.ruff.lint.per-file-ignores] +"tests/*.py" = [ + "D103", # ignore missing docstring in tests + "DTZ", # ignore datetime in tests + "T", # print statements +] +"docs/conf.py" = [ + "A", # conf.py can shadow builtins + "ERA", +] + +[tool.ruff.lint.pydocstyle] +convention = 'numpy' +property-decorators = ['property', 'functools.cached_property'] + +[tool.ruff.lint.mccabe] +max-complexity = 15 [tool.pytest.ini_options] # Options for py.test: diff --git a/src/py21cmsense/__init__.py b/src/py21cmsense/__init__.py index 6fde05b..7173c7d 100644 --- a/src/py21cmsense/__init__.py +++ b/src/py21cmsense/__init__.py @@ -9,6 +9,17 @@ finally: del version, PackageNotFoundError +__all__ = [ + "data", + "theory", + "yaml", + "hera", + "BaselineRange", + "GaussianBeam", + "Observation", + "Observatory", + "PowerSpectrum", +] from . import data, theory, yaml from .antpos import hera from .baseline_filters import BaselineRange diff --git a/src/py21cmsense/_utils.py b/src/py21cmsense/_utils.py index 02bad3c..d218bb6 100644 --- a/src/py21cmsense/_utils.py +++ b/src/py21cmsense/_utils.py @@ -6,8 +6,6 @@ from astropy.time import Time from pyuvdata import utils as uvutils -from . import config - def between(xmin, xmax): """Return an attrs validation function that checks a number is within bounds.""" @@ -19,12 +17,18 @@ def validator(instance, att, val): def positive(instance, att, x): - """An attrs validator that checks a value is positive.""" + """Check that a value is positive. + + This is an attrs validator. + """ assert x > 0, "must be positive" def nonnegative(instance, att, x): - """An attrs validator that checks a value is non-negative.""" + """Check that a value is non-negative. + + This is an attrs validator. + """ assert x >= 0, "must be non-negative" diff --git a/src/py21cmsense/antpos.py b/src/py21cmsense/antpos.py index dbed0ab..ec4c025 100644 --- a/src/py21cmsense/antpos.py +++ b/src/py21cmsense/antpos.py @@ -9,7 +9,6 @@ import numpy as np from astropy import units as un -from typing import Sequence from . import units as tp from . import yaml @@ -52,10 +51,7 @@ def hera( sep = separation.to_value("m") - if row_separation is None: - row_sep = sep * np.sqrt(3) / 2 - else: - row_sep = row_separation.to_value("m") + row_sep = sep * np.sqrt(3) / 2 if row_separation is None else row_separation.to_value("m") # construct the main hexagon positions = [] @@ -97,11 +93,7 @@ def hera( exterior_hex_num = outriggers + 2 for row in range(exterior_hex_num - 1, -exterior_hex_num, -1): for col in range(2 * exterior_hex_num - abs(row) - 1): - x_pos = ( - ((2 - (2 * exterior_hex_num - abs(row))) / 2 + col) - * sep - * (hex_num - 1) - ) + x_pos = ((2 - (2 * exterior_hex_num - abs(row))) / 2 + col) * sep * (hex_num - 1) y_pos = row * (hex_num - 1) * row_sep theta = np.arctan2(y_pos, x_pos) if np.sqrt(x_pos**2 + y_pos**2) > sep * (hex_num + 1): diff --git a/src/py21cmsense/baseline_filters.py b/src/py21cmsense/baseline_filters.py index 209755c..f7b1b8b 100644 --- a/src/py21cmsense/baseline_filters.py +++ b/src/py21cmsense/baseline_filters.py @@ -7,11 +7,10 @@ """ import abc + import attr import numpy as np -import warnings from astropy import units as un -from pathlib import Path from . import units as tp @@ -44,29 +43,22 @@ def __call__(self, bl: tp.Length) -> bool: bool True if the baseline should be included. """ - pass # pragma: no cover + # pragma: no cover @attr.define class BaselineRange(BaselineFilter): """Theory model from EOS2021 (https://arxiv.org/abs/2110.13919).""" - bl_min: tp.Length = attr.field( - default=0 * un.m, validator=tp.vld_physical_type("length") - ) - bl_max: tp.Length = attr.field( - default=np.inf * un.m, validator=tp.vld_physical_type("length") - ) - direction: str = attr.field( - default="mag", validator=attr.validators.in_(("ew", "ns", "mag")) - ) + bl_min: tp.Length = attr.field(default=0 * un.m, validator=tp.vld_physical_type("length")) + bl_max: tp.Length = attr.field(default=np.inf * un.m, validator=tp.vld_physical_type("length")) + direction: str = attr.field(default="mag", validator=attr.validators.in_(("ew", "ns", "mag"))) @bl_max.validator def _bl_max_vld(self, att, val): if val <= self.bl_min: raise ValueError( - "bl_max must be greater than bl_min, got " - f"bl_min={self.bl_min} and bl_max={val}" + "bl_max must be greater than bl_min, got " f"bl_min={self.bl_min} and bl_max={val}" ) def __call__(self, bl: tp.Length) -> bool: diff --git a/src/py21cmsense/beam.py b/src/py21cmsense/beam.py index 15fbddc..dbe9673 100644 --- a/src/py21cmsense/beam.py +++ b/src/py21cmsense/beam.py @@ -2,8 +2,9 @@ from __future__ import annotations +from abc import ABCMeta, abstractmethod + import attr -from abc import ABCMeta, abstractmethod, abstractproperty from astropy import constants as cnst from astropy import units as un from hickleable import hickleable @@ -37,25 +38,25 @@ def at(self, frequency: tp.Frequency) -> PrimaryBeam: """Get a copy of the object at a new frequency.""" return attr.evolve(self, frequency=frequency) - @abstractproperty + @property + @abstractmethod def area(self) -> un.Quantity[un.steradian]: """Beam area [units: sr].""" - pass - @abstractproperty + @property + @abstractmethod def width(self) -> un.Quantity[un.radians]: """Beam width [units: rad].""" - pass - @abstractproperty + @property + @abstractmethod def first_null(self) -> un.Quantity[un.radians]: """An approximation of the first null of the beam.""" - pass - @abstractproperty + @property + @abstractmethod def sq_area(self) -> un.Quantity[un.steradian]: """The area of the beam^2.""" - pass @property def b_eff(self) -> un.Quantity[un.steradian]: @@ -65,10 +66,10 @@ def b_eff(self) -> un.Quantity[un.steradian]: """ return self.area**2 / self.sq_area - @abstractproperty + @property + @abstractmethod def uv_resolution(self) -> un.Quantity[1 / un.radians]: """The UV footprint of the beam.""" - pass @classmethod def from_uvbeam(cls) -> PrimaryBeam: @@ -91,9 +92,7 @@ class GaussianBeam(PrimaryBeam): otherwise defined. This generates the beam size. """ - dish_size: tp.Length = attr.ib( - validator=(tp.vld_physical_type("length"), ut.positive) - ) + dish_size: tp.Length = attr.ib(validator=(tp.vld_physical_type("length"), ut.positive)) @property def wavelength(self) -> un.Quantity[un.m]: diff --git a/src/py21cmsense/cli.py b/src/py21cmsense/cli.py index 50cc56b..d524757 100644 --- a/src/py21cmsense/cli.py +++ b/src/py21cmsense/cli.py @@ -1,14 +1,13 @@ """CLI routines for 21cmSense.""" -import click import logging -import os -import pickle import tempfile -from astropy.io.misc import yaml -from hickle import hickle from os import path from pathlib import Path + +import click +from astropy.io.misc import yaml +from hickle import hickle from rich.logging import RichHandler from . import observation @@ -24,9 +23,7 @@ main = click.Group() FORMAT = "%(message)s" -logging.basicConfig( - level=logging.INFO, format=FORMAT, datefmt="[%X]", handlers=[RichHandler()] -) +logging.basicConfig(level=logging.INFO, format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]) logger = logging.getLogger("py21cmsense") @@ -78,12 +75,8 @@ def grid_baselines(configfile, direc, outfile): help="directory to save output file", default=".", ) -@click.option( - "--fname", default=None, type=click.Path(), help="filename to save output file" -) -@click.option( - "--thermal/--no-thermal", default=True, help="whether to include thermal noise" -) +@click.option("--fname", default=None, type=click.Path(), help="filename to save output file") +@click.option("--thermal/--no-thermal", default=True, help="whether to include thermal noise") @click.option( "--samplevar/--no-samplevar", default=True, @@ -100,12 +93,8 @@ def grid_baselines(configfile, direc, outfile): default=True, help="whether to plot the 1D power spectrum uncertainty", ) -@click.option( - "--plot-title", default=None, type=str, help="title for the output 1D plot" -) -@click.option( - "--prefix", default="", type=str, help="string prefix for all output files" -) +@click.option("--plot-title", default=None, type=str, help="title for the output 1D plot") +@click.option("--prefix", default="", type=str, help="string prefix for all output files") def calc_sense( configfile, array_file, @@ -146,9 +135,7 @@ def calc_sense( f"Used {len(sensitivity.k1d)} bins between " f"{sensitivity.k1d.min()} and {sensitivity.k1d.max()}" ) - sensitivity.write( - filename=fname, thermal=thermal, sample=samplevar, direc=direc, prefix=prefix - ) + sensitivity.write(filename=fname, thermal=thermal, sample=samplevar, direc=direc, prefix=prefix) if write_significance: sig = sensitivity.calculate_significance(thermal=thermal, sample=samplevar) @@ -158,7 +145,7 @@ def calc_sense( fig = sensitivity.plot_sense_1d(thermal=thermal, sample=samplevar) if plot_title: plt.title(plot_title) - f"{prefix}_" if prefix else "" + prefix = f"{prefix}_" if prefix else "" fig.savefig( f"{direc}/{prefix}{sensitivity.foreground_model}_" f"{sensitivity.observation.frequency:.3f}.png" diff --git a/src/py21cmsense/conversions.py b/src/py21cmsense/conversions.py index 21e5b86..4afc95a 100644 --- a/src/py21cmsense/conversions.py +++ b/src/py21cmsense/conversions.py @@ -4,12 +4,13 @@ Provides conversions between observing co-ordinates and cosmological co-ordinates. """ +from __future__ import annotations + import numpy as np from astropy import constants as cnst from astropy import units as un from astropy.cosmology import FLRW, Planck15 from astropy.cosmology.units import littleh -from typing import Union from . import units as tp @@ -17,7 +18,6 @@ f21 = 1.42040575177 * un.GHz -@un.quantity_input def f2z(fq: tp.Frequency) -> float: """ Convert frequency to redshift for 21 cm line. @@ -34,8 +34,7 @@ def f2z(fq: tp.Frequency) -> float: return float(f21 / fq - 1) -@un.quantity_input -def z2f(z: Union[float, np.array]) -> un.Quantity[un.GHz]: +def z2f(z: float | np.array) -> un.Quantity[un.GHz]: """ Convert redshift to z=0 frequency for 21 cm line. @@ -52,7 +51,7 @@ def z2f(z: Union[float, np.array]) -> un.Quantity[un.GHz]: def dL_dth( - z: Union[float, np.array], + z: float | np.array, cosmo: FLRW = Planck15, approximate=False, ) -> un.Quantity[un.Mpc / un.rad / littleh]: @@ -74,17 +73,13 @@ def dL_dth( From Furlanetto et al. (2006) """ if approximate: - return ( - (1.9 * (1.0 / un.arcmin) * ((1 + z) / 10.0) ** 0.2).to(1 / un.rad) - * un.Mpc - / littleh - ) + return (1.9 * (1.0 / un.arcmin) * ((1 + z) / 10.0) ** 0.2).to(1 / un.rad) * un.Mpc / littleh else: return cosmo.h * cosmo.comoving_transverse_distance(z) / un.rad / littleh def dL_df( - z: Union[float, np.array], + z: float | np.array, cosmo: FLRW = Planck15, approximate=False, ) -> un.Quantity[un.Mpc / un.MHz / littleh]: @@ -112,7 +107,7 @@ def dL_df( def dk_du( - z: Union[float, np.array], + z: float | np.array, cosmo: FLRW = Planck15, approximate=False, ) -> un.Quantity[littleh / un.Mpc]: @@ -133,7 +128,7 @@ def dk_du( def dk_deta( - z: Union[float, np.array], + z: float | np.array, cosmo: FLRW = Planck15, approximate=False, ) -> un.Quantity[un.MHz * littleh / un.Mpc]: @@ -149,7 +144,7 @@ def dk_deta( def X2Y( - z: Union[float, np.array], + z: float | np.array, cosmo: FLRW = Planck15, approximate=False, ) -> un.Quantity[un.Mpc**3 / littleh**3 / un.steradian / un.MHz]: @@ -167,6 +162,4 @@ def X2Y( ------- astropy.Quantity: the conversion factor. Units are Mpc^3/h^3 / (sr MHz). """ - return dL_dth(z, cosmo, approximate=approximate) ** 2 * dL_df( - z, cosmo, approximate=approximate - ) + return dL_dth(z, cosmo, approximate=approximate) ** 2 * dL_df(z, cosmo, approximate=approximate) diff --git a/src/py21cmsense/observation.py b/src/py21cmsense/observation.py index 5c3c33e..1bef077 100644 --- a/src/py21cmsense/observation.py +++ b/src/py21cmsense/observation.py @@ -2,18 +2,19 @@ from __future__ import annotations -import attr import collections +from collections import defaultdict +from functools import cached_property +from os import path +from typing import Any, Callable + +import attr import numpy as np from astropy import units as un from astropy.cosmology import LambdaCDM, Planck15 from astropy.io.misc import yaml from attr import validators as vld -from collections import defaultdict -from functools import cached_property, partial from hickleable import hickleable -from os import path -from typing import Any, Callable from . import _utils as ut from . import conversions as conv @@ -120,9 +121,7 @@ class Observation: # The following defaults are based on Mozdzen et al. 2017: 2017MNRAS.464.4995M, # figure 8, with galaxy down. - spectral_index: float = attr.ib( - default=2.6, converter=float, validator=ut.between(1.5, 4) - ) + spectral_index: float = attr.ib(default=2.6, converter=float, validator=ut.between(1.5, 4)) tsky_amplitude: tp.Temperature = attr.ib( default=260000 * un.mK, validator=ut.nonnegative, @@ -140,18 +139,14 @@ def from_yaml(cls, yaml_file): elif isinstance(yaml_file, collections.abc.Mapping): data = yaml_file else: - raise ValueError( - "yaml_file must be a string filepath or a raw dict from such a file." - ) + raise ValueError("yaml_file must be a string filepath or a raw dict from such a file.") if ( isinstance(data["observatory"], str) and isinstance(yaml_file, str) and not path.isabs(data["observatory"]) ): - data["observatory"] = path.join( - path.dirname(yaml_file), data["observatory"] - ) + data["observatory"] = path.join(path.dirname(yaml_file), data["observatory"]) observatory = obs.Observatory.from_yaml(data.pop("observatory")) return cls(observatory=observatory, **data) @@ -311,9 +306,7 @@ def redshift(self) -> float: @cached_property def eta(self) -> un.Quantity[1 / un.MHz]: """The fourier dual of the frequencies of the observation.""" - return np.fft.fftfreq( - self.n_channels, self.bandwidth.to("MHz") / self.n_channels - ) + return np.fft.fftfreq(self.n_channels, self.bandwidth.to("MHz") / self.n_channels) @cached_property def kparallel(self) -> un.Quantity[un.littleh / un.Mpc]: @@ -322,9 +315,7 @@ def kparallel(self) -> un.Quantity[un.littleh / un.Mpc]: Order of the values is the same as `fftfreq` (i.e. zero-first) """ return ( - conv.dk_deta( - self.redshift, self.cosmo, approximate=self.use_approximate_cosmo - ) + conv.dk_deta(self.redshift, self.cosmo, approximate=self.use_approximate_cosmo) * self.eta ) diff --git a/src/py21cmsense/observatory.py b/src/py21cmsense/observatory.py index 115141f..876b988 100644 --- a/src/py21cmsense/observatory.py +++ b/src/py21cmsense/observatory.py @@ -7,20 +7,21 @@ from __future__ import annotations -import attr import collections import logging +from collections import defaultdict +from functools import cached_property +from pathlib import Path +from typing import Callable + +import attr import numpy as np import tqdm from astropy import constants as cnst from astropy import units as un from astropy.io.misc import yaml from attr import validators as vld -from cached_property import cached_property -from collections import defaultdict from hickleable import hickleable -from pathlib import Path -from typing import Callable from . import _utils as ut from . import beam, config @@ -118,12 +119,11 @@ def _trcv_vld(self, att, val): y = val(1 * un.MHz) except Exception as e: raise ValueError( - "Trcv function must take a frequency Quantity and return a temperature Quantity." + "Trcv function must take a frequency Quantity and " + "return a temperature Quantity." ) from e - if not ( - isinstance(y, un.Quantity) and y.unit.physical_type == "temperature" - ): + if not (isinstance(y, un.Quantity) and y.unit.physical_type == "temperature"): raise ValueError("Trcv function must return a temperature Quantity.") else: tp.vld_physical_type("temperature")(self, att, val) @@ -153,9 +153,7 @@ def from_uvdata(cls, uvdata, beam: beam.PrimaryBeam, **kwargs) -> Observatory: ) @classmethod - def from_yaml( - cls, yaml_file: str | dict, frequency: tp.Frequency | None = None - ) -> Observatory: + def from_yaml(cls, yaml_file: str | dict, frequency: tp.Frequency | None = None) -> Observatory: """Instantiate an Observatory from a compatible YAML config file.""" if isinstance(yaml_file, (str, Path)): with open(yaml_file) as fl: @@ -165,9 +163,7 @@ def from_yaml( elif isinstance(yaml_file, collections.abc.Mapping): data = yaml_file else: - raise ValueError( - "yaml_file must be a string filepath or a raw dict from such a file." - ) + raise ValueError("yaml_file must be a string filepath or a raw dict from such a file.") # Mask out some antennas if a max_antpos is set in the YAML max_antpos = data.pop("max_antpos", np.inf * un.m) @@ -195,9 +191,7 @@ def from_yaml( return cls(antpos=antpos, beam=_beam, **data) @classmethod - def from_profile( - cls, profile: str, frequency: tp.Frequency | None = None, **kwargs - ): + def from_profile(cls, profile: str, frequency: tp.Frequency | None = None, **kwargs): """Instantiate the Observatory from a builtin profile. Parameters @@ -235,7 +229,7 @@ def baselines_metres(self) -> tp.Meters: def projected_baselines( self, baselines: tp.Length | None = None, time_offset: tp.Time = 0 * un.hour ) -> np.ndarray: - """The *projected* baseline lengths (in wavelengths). + """Compute the *projected* baseline lengths (in wavelengths). Phased to a point that has rotated off zenith by some time_offset. @@ -322,10 +316,7 @@ def get_redundant_baselines( baseline_filters = tp._tuplify(baseline_filters, 1) def filt(blm): - for filt in baseline_filters: - if not filt(blm): - return False - return True + return all(filt(blm) for filt in baseline_filters) # Everything here is in wavelengths uvw = self.projected_baselines()[:, :, :2].value @@ -485,13 +476,9 @@ def grid_baselines( bl_max = np.sqrt(np.max(np.sum(baselines**2, axis=1))) if weights is None: - raise ValueError( - "If baselines are provided, weights must also be provided." - ) + raise ValueError("If baselines are provided, weights must also be provided.") - time_offsets = self.time_offsets_from_obs_int_time( - integration_time, observation_duration - ) + time_offsets = self.time_offsets_from_obs_int_time(integration_time, observation_duration) uvws = self.projected_baselines(baselines, time_offsets).reshape( baselines.shape[0], time_offsets.size, 3 @@ -545,9 +532,7 @@ def ugrid_edges(self, bl_max: tp.Length = np.inf * un.m) -> np.ndarray: bl_max = self.longest_used_baseline(bl_max) # We're doing edges of bins here, and the first edge is at uv_res/2 - n_positive = int( - np.ceil((bl_max - self.beam.uv_resolution / 2) / self.beam.uv_resolution) - ) + n_positive = int(np.ceil((bl_max - self.beam.uv_resolution / 2) / self.beam.uv_resolution)) # Grid from uv_res/2 to just past (or equal to) bl_max, in steps of resolution. positive = np.linspace( diff --git a/src/py21cmsense/sensitivity.py b/src/py21cmsense/sensitivity.py index db587c6..b91145e 100644 --- a/src/py21cmsense/sensitivity.py +++ b/src/py21cmsense/sensitivity.py @@ -10,11 +10,17 @@ from __future__ import annotations +import importlib +import logging +from collections.abc import Mapping +from functools import cached_property +from os import path +from pathlib import Path +from typing import Callable + import attr import h5py import hickle -import importlib -import logging import numpy as np import tqdm from astropy import units as un @@ -22,13 +28,8 @@ from astropy.cosmology.units import littleh, with_H0 from astropy.io.misc import yaml from attr import validators as vld -from cached_property import cached_property -from collections.abc import Mapping from hickleable import hickleable from methodtools import lru_cache -from os import path -from pathlib import Path -from typing import Callable from . import _utils as ut from . import config @@ -67,9 +68,7 @@ def _load_yaml(yaml_file): elif isinstance(yaml_file, Mapping): data = yaml_file else: - raise ValueError( - "yaml_file must be a string filepath or a raw dict from such a file." - ) + raise ValueError("yaml_file must be a string filepath or a raw dict from such a file.") return data @classmethod @@ -89,9 +88,7 @@ def from_yaml(cls, yaml_file) -> Sensitivity: elif h5py.is_hdf5(obsfile): observation = hickle.load(obsfile) else: - raise ValueError( - "observation must be a filename with extension .yml or .h5" - ) + raise ValueError("observation must be a filename with extension .yml or .h5") return klass(observation=observation, **data) @@ -184,7 +181,7 @@ def from_yaml(cls, yaml_file) -> Sensitivity: for mdl in data.pop("plugins"): try: importlib.import_module(mdl) - except Exception as e: + except Exception as e: # noqa: PERF203 raise ImportError(f"Could not import {mdl}") from e if "theory_model" in data: @@ -309,15 +306,9 @@ def _nsamples_2d( hor = self.horizon_limit(umag) if k_perp not in sense["thermal"]: - sense["thermal"][k_perp] = ( - np.zeros(len(self.observation.kparallel)) / un.mK**4 - ) - sense["sample"][k_perp] = ( - np.zeros(len(self.observation.kparallel)) / un.mK**4 - ) - sense["both"][k_perp] = ( - np.zeros(len(self.observation.kparallel)) / un.mK**4 - ) + sense["thermal"][k_perp] = np.zeros(len(self.observation.kparallel)) / un.mK**4 + sense["sample"][k_perp] = np.zeros(len(self.observation.kparallel)) / un.mK**4 + sense["both"][k_perp] = np.zeros(len(self.observation.kparallel)) / un.mK**4 # Exclude parallel modes dominated by foregrounds kpars = self.observation.kparallel[self.observation.kparallel >= hor] @@ -381,7 +372,7 @@ def calculate_sensitivity_2d( # errors were added in inverse quadrature, now need to invert and take # square root to have error bars; also divide errors by number of indep. fields final_sense = {} - for k_perp in sense.keys(): + for k_perp in sense: mask = sense[k_perp] > 0 if self.systematics_mask is not None: mask &= self.systematics_mask(k_perp, self.observation.kparallel) @@ -392,16 +383,12 @@ def calculate_sensitivity_2d( final_sense[k_perp] = np.inf * np.ones(len(mask)) * un.mK**2 if thermal: total_std = thermal_std = 1 / np.sqrt( - self._nsamples_2d["thermal"][k_perp][mask] - * self.observation.n_lst_bins + self._nsamples_2d["thermal"][k_perp][mask] * self.observation.n_lst_bins ) if sample: total_std = sample_std = 1 / np.sqrt( self._nsamples_2d["sample"][k_perp][mask] - * ( - self.observation.time_per_day - / self.observation.beam_crossing_time - ).to("") + * (self.observation.time_per_day / self.observation.beam_crossing_time).to("") ) if thermal and sample: total_std = thermal_std + sample_std @@ -426,9 +413,7 @@ def calculate_sensitivity_2d_grid( kpar_edges The edges of the bins in kpar. """ - sense2d_inv = np.zeros((len(kperp_edges) - 1, len(kpar_edges) - 1)) << ( - 1 / un.mK**4 - ) + sense2d_inv = np.zeros((len(kperp_edges) - 1, len(kpar_edges) - 1)) << (1 / un.mK**4) sense = self.calculate_sensitivity_2d(thermal=thermal, sample=sample) assert np.all(np.diff(kperp_edges) > 0) @@ -450,9 +435,7 @@ def calculate_sensitivity_2d_grid( good_ks = kpar_indx >= 0 good_ks &= kpar_indx < len(kpar_edges) - 1 - sense2d_inv[kperp_indx][kpar_indx[good_ks]] += ( - 1.0 / sense[k_perp][good_ks] ** 2 - ) + sense2d_inv[kperp_indx][kpar_indx[good_ks]] += 1.0 / sense[k_perp][good_ks] ** 2 # invert errors and take square root again for final answer sense2d = np.ones(sense2d_inv.shape) * un.mK**2 * np.inf @@ -518,9 +501,7 @@ def _average_sense_to_1d( return sense1d @lru_cache() - def calculate_sensitivity_1d( - self, thermal: bool = True, sample: bool = True - ) -> tp.Delta: + def calculate_sensitivity_1d(self, thermal: bool = True, sample: bool = True) -> tp.Delta: """Calculate a 1D sensitivity curve. Parameters @@ -553,9 +534,7 @@ def delta_squared(self) -> tp.Delta: return self.theory_model.delta_squared(self.observation.redshift, k) @lru_cache() - def calculate_significance( - self, thermal: bool = True, sample: bool = True - ) -> float: + def calculate_significance(self, thermal: bool = True, sample: bool = True) -> float: """ Calculate significance of a detection of the default cosmological power spectrum. @@ -581,8 +560,8 @@ def plot_sense_2d(self, sense2d: dict[tp.Wavenumber, tp.Delta]): """Create a colormap plot of the sensitivity un UV bins.""" try: import matplotlib.pyplot as plt - except ImportError: # pragma: no cover - raise ImportError("matplotlib is required to make plots...") + except ImportError as e: # pragma: no cover + raise ImportError("matplotlib is required to make plots...") from e keys = sorted(sense2d.keys()) x = np.array([v.value for v in keys]) @@ -609,7 +588,7 @@ def write( filename: str | Path, thermal: bool = True, sample: bool = True, - prefix: str = None, + prefix: str | None = None, direc: str | Path = ".", ) -> Path: """Save sensitivity results to HDF5 file. @@ -656,8 +635,8 @@ def plot_sense_1d(self, sample: bool = True, thermal: bool = True): """Create a plot of the sensitivity in 1D k-bins.""" try: import matplotlib.pyplot as plt - except ImportError: # pragma: no cover - raise ImportError("matplotlib is required to make plots...") + except ImportError as e: # pragma: no cover + raise ImportError("matplotlib is required to make plots...") from e out = self._get_all_sensitivity_combos(thermal, sample) for key, value in out.items(): @@ -671,16 +650,12 @@ def plot_sense_1d(self, sample: bool = True, thermal: bool = True): return plt.gcf() - def _get_all_sensitivity_combos( - self, thermal: bool, sample: bool - ) -> dict[str, tp.Delta]: + def _get_all_sensitivity_combos(self, thermal: bool, sample: bool) -> dict[str, tp.Delta]: result = {} if thermal: result["thermal_noise"] = self.calculate_sensitivity_1d(sample=False) if sample: - result["sample_noise"] = self.calculate_sensitivity_1d( - thermal=False, sample=True - ) + result["sample_noise"] = self.calculate_sensitivity_1d(thermal=False, sample=True) if thermal and sample: result["sample+thermal_noise"] = self.calculate_sensitivity_1d( diff --git a/src/py21cmsense/theory.py b/src/py21cmsense/theory.py index c1cf28c..7f209a3 100644 --- a/src/py21cmsense/theory.py +++ b/src/py21cmsense/theory.py @@ -12,10 +12,11 @@ """ import abc -import numpy as np import warnings -from astropy import units as un from pathlib import Path + +import numpy as np +from astropy import units as un from scipy.interpolate import InterpolatedUnivariateSpline, RectBivariateSpline _ALL_THEORY_POWER_SPECTRA = {} @@ -52,7 +53,7 @@ def delta_squared(self, z: float, k: np.ndarray) -> un.Quantity[un.mK**2]: delta_squared An array of delta_squared values in units of mK^2. """ - pass # pragma: no cover + # pragma: no cover class TheorySpline(TheoryModel): @@ -86,7 +87,8 @@ def delta_squared(self, z: float, k: np.ndarray) -> un.Quantity[un.mK**2]: ) if not self.z.min() <= z <= self.z.max(): warnings.warn( - f"Extrapolating beyond simulated redshift range: {z} not in range ({self.z.min(), self.z.max()})", + f"Extrapolating beyond simulated redshift range: {z} not in " + f"range ({self.z.min(), self.z.max()})", stacklevel=2, ) diff --git a/src/py21cmsense/units.py b/src/py21cmsense/units.py index c76a9a0..279136b 100644 --- a/src/py21cmsense/units.py +++ b/src/py21cmsense/units.py @@ -2,12 +2,12 @@ from __future__ import annotations +from typing import Any, Callable + import attr -import numpy as np from astropy import constants as cnst from astropy import units as un from astropy.cosmology.units import littleh, redshift -from typing import Any, Callable, Type, Union un.add_enabled_units([littleh, redshift]) @@ -15,8 +15,6 @@ class UnitError(ValueError): """An error pertaining to having incorrect units.""" - pass - Length = un.Quantity["length"] Meters = un.Quantity["m"] @@ -60,9 +58,7 @@ def _check_unit(self, att, val): raise UnitError(f"{att.name} must be an astropy Quantity!") if not val.unit.is_equivalent(unit, equivalencies): - raise un.UnitConversionError( - f"{att.name} not convertible to {unit}. Got {val.unit}" - ) + raise un.UnitConversionError(f"{att.name} not convertible to {unit}. Got {val.unit}") return _check_unit diff --git a/src/py21cmsense/yaml.py b/src/py21cmsense/yaml.py index 505ce90..ab70019 100644 --- a/src/py21cmsense/yaml.py +++ b/src/py21cmsense/yaml.py @@ -1,12 +1,13 @@ """Module defining new YAML tags for py21cmsense.""" import inspect -import numpy as np import pickle +from functools import wraps + +import numpy as np import yaml from astropy import units as un from astropy.io.misc.yaml import AstropyLoader -from functools import wraps _DATA_LOADERS = {} @@ -16,11 +17,9 @@ class LoadError(IOError): """Error raised on trying to load data from YAML files.""" - pass - def data_loader(tag=None): - """A decorator that turns a function into a YAML tag for loading external datafiles. + """Convert a function into a YAML tag for loading external datafiles. The form of the tag is:: @@ -43,7 +42,7 @@ def wrapper(data): except OSError: raise except Exception as e: - raise LoadError(str(e)) + raise LoadError(str(e)) from e def yaml_fnc(loader, node): args = node.value.split("|") @@ -94,7 +93,7 @@ def txt_loader(data): def yaml_func(tag=None): - """A decorator that turns a function into a YAML tag.""" + """Convert a function into a YAML tag.""" def inner(fnc): new_tag = tag or fnc.__name__ diff --git a/tests/conftest.py b/tests/conftest.py index 6b83bf9..27ac935 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,9 @@ -import pytest +"""Pytest configuration file.""" from pathlib import Path +import pytest + @pytest.fixture(scope="session") def tmpdirec(tmpdir_factory): diff --git a/tests/test_antpos.py b/tests/test_antpos.py index 8defa57..c2f15f1 100644 --- a/tests/test_antpos.py +++ b/tests/test_antpos.py @@ -1,8 +1,8 @@ -import pytest +"""Test the antenna positions.""" import numpy as np +import pytest from astropy import units as un - from py21cmsense.antpos import hera @@ -37,5 +37,5 @@ def test_hera_set_row_sep(): def test_bad_hex_num(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="hex_num must be greater than 1"): hera(1) diff --git a/tests/test_baseline_filters.py b/tests/test_baseline_filters.py index 0b419b5..1b93e7b 100644 --- a/tests/test_baseline_filters.py +++ b/tests/test_baseline_filters.py @@ -1,10 +1,8 @@ """Test baseline_filters module.""" -import pytest - import numpy as np +import pytest from astropy import units as un - from py21cmsense.baseline_filters import BaselineRange # Test IDs for parametrization @@ -23,12 +21,6 @@ "magnitude_zero", ] -error_case_ids = [ - "bl_max_less_than_bl_min", - "invalid_direction", - "non_length_input", -] - # Happy path test values happy_path_values = [ (0 * un.m, np.inf * un.m, "ew", np.array([1, 0, 0]) * un.m, True), @@ -46,16 +38,9 @@ (0 * un.m, np.inf * un.m, "mag", np.array([0, 0, 0]) * un.m, True), ] -# Error case test values -error_case_values = [ - (2 * un.m, 1 * un.m, "mag", np.array([1, 1, 0]) * un.m), - (0 * un.m, np.inf * un.m, "invalid", np.array([1, 1, 0]) * un.m), - (0 * un.m, np.inf * un.m, "ew", np.array([1, 1, 0])), -] - @pytest.mark.parametrize( - "bl_min, bl_max, direction, baseline, expected", + ("bl_min", "bl_max", "direction", "baseline", "expected"), happy_path_values, ids=happy_path_ids, ) @@ -71,7 +56,9 @@ def test_happy_path(bl_min, bl_max, direction, baseline, expected): @pytest.mark.parametrize( - "bl_min, bl_max, direction, baseline, expected", edge_case_values, ids=edge_case_ids + ("bl_min", "bl_max", "direction", "baseline", "expected"), + edge_case_values, + ids=edge_case_ids, ) def test_edge_cases(bl_min, bl_max, direction, baseline, expected): # Arrange @@ -84,13 +71,14 @@ def test_edge_cases(bl_min, bl_max, direction, baseline, expected): assert result == expected -@pytest.mark.parametrize( - "bl_min, bl_max, direction, baseline", error_case_values, ids=error_case_ids -) -def test_error_cases(bl_min, bl_max, direction, baseline): - # Arrange - with pytest.raises(ValueError): - baseline_range = BaselineRange( - bl_min=bl_min, bl_max=bl_max, direction=direction - ) - baseline_range(baseline) +def test_error_cases(): + with pytest.raises(ValueError, match="bl_max must be greater than bl_min"): + baseline_range = BaselineRange(bl_min=2 * un.m, bl_max=1 * un.m, direction="mag") + + with pytest.raises(ValueError, match="must be in"): + baseline_range = BaselineRange(bl_min=1 * un.m, bl_max=2 * un.m, direction="invalid") + + baseline_range = BaselineRange(bl_min=1 * un.m, bl_max=2 * un.m, direction="ew") + + with pytest.raises(ValueError, match="Can only apply"): + baseline_range(np.array([1, 1, 0])) diff --git a/tests/test_beam.py b/tests/test_beam.py index e4b4e13..741d1f4 100644 --- a/tests/test_beam.py +++ b/tests/test_beam.py @@ -1,8 +1,8 @@ -import pytest +"""Test the beam module.""" import numpy as np +import pytest from astropy import units - from py21cmsense import GaussianBeam, beam diff --git a/tests/test_cli.py b/tests/test_cli.py index 17b1fa4..7bcb6c7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,12 +1,11 @@ -import pytest +"""Test the CLI module.""" -import glob import traceback -from astropy.io.misc import yaml -from click.testing import CliRunner from os import path -from yaml import dump +import pytest +from astropy.io.misc import yaml +from click.testing import CliRunner from py21cmsense import cli here = path.dirname(path.abspath(__file__)) @@ -60,9 +59,7 @@ def test_gridding_baselines(runner, observation_config, tmpdirec): def test_calc_sense(runner, sensitivity_config, tmpdirec): - output = runner.invoke( - cli.main, ["calc-sense", sensitivity_config, "--direc", str(tmpdirec)] - ) + output = runner.invoke(cli.main, ["calc-sense", sensitivity_config, "--direc", str(tmpdirec)]) if output.exception: traceback.print_exception(*output.exc_info) diff --git a/tests/test_conversions.py b/tests/test_conversions.py index b80f5a2..7e7305c 100644 --- a/tests/test_conversions.py +++ b/tests/test_conversions.py @@ -1,9 +1,8 @@ -import pytest +"""Test the conversions module.""" import numpy as np from astropy import units from astropy.cosmology import Planck15 - from py21cmsense import conversions as cnv diff --git a/tests/test_io.py b/tests/test_io.py index 85105f4..f6211cb 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,11 +1,13 @@ +"""Test the I/O module.""" + import hickle import numpy as np from astropy import units as un - from py21cmsense import GaussianBeam, Observation, Observatory +rng = np.random.default_rng(1234) beam = GaussianBeam(frequency=150 * un.MHz, dish_size=14 * un.m) -obs = Observatory(beam=beam, antpos=np.random.random((25, 3)) * 30 * un.m) +obs = Observatory(beam=beam, antpos=rng.random((25, 3)) * 30 * un.m) observation = Observation(observatory=obs) diff --git a/tests/test_observation.py b/tests/test_observation.py index 4ce61e7..675218d 100644 --- a/tests/test_observation.py +++ b/tests/test_observation.py @@ -1,11 +1,12 @@ -import pytest +"""Tests for the Observation class.""" import copy -import numpy as np import pickle + +import numpy as np +import pytest from astropy import units from astropy.cosmology.units import littleh - from py21cmsense import GaussianBeam, Observation, Observatory @@ -68,10 +69,11 @@ def test_equality(observatory): def test_from_yaml(observatory): + rng = np.random.default_rng(1234) obs = Observation.from_yaml( { "observatory": { - "antpos": np.random.random((20, 3)) * units.m, + "antpos": rng.random((20, 3)) * units.m, "beam": { "class": "GaussianBeam", "frequency": 150 * units.MHz, diff --git a/tests/test_observatory.py b/tests/test_observatory.py index a19001f..8d4b20d 100644 --- a/tests/test_observatory.py +++ b/tests/test_observatory.py @@ -1,12 +1,13 @@ -import pytest +"""Test the observatory module.""" + +import re +from pathlib import Path import numpy as np +import pytest import pyuvdata -import re from astropy import units from astropy.coordinates import EarthLocation -from pathlib import Path - from py21cmsense import Observatory from py21cmsense.baseline_filters import BaselineRange from py21cmsense.beam import GaussianBeam @@ -35,9 +36,7 @@ def test_antpos(bm): with pytest.raises(ValueError, match="antpos must be a 2D array"): Observatory(antpos=np.zeros(10) * units.m, beam=bm) - with pytest.raises( - ValueError, match=re.escape("antpos must have shape (Nants, 3)") - ): + with pytest.raises(ValueError, match=re.escape("antpos must have shape (Nants, 3)")): Observatory(antpos=np.zeros((10, 2)) * units.m, beam=bm) @@ -70,21 +69,13 @@ def test_observatory(bm): a = Observatory(antpos=np.zeros((3, 3)) * units.m, beam=bm) assert a.frequency == bm.frequency assert a.baselines_metres.shape == (3, 3, 3) - assert ( - a.baselines_metres * a.metres_to_wavelengths - ).unit == units.dimensionless_unscaled + assert (a.baselines_metres * a.metres_to_wavelengths).unit == units.dimensionless_unscaled assert a.baseline_lengths.shape == (3, 3) assert np.all(a.baseline_lengths == 0) - b = Observatory( - antpos=np.array([[0, 0, 0], [1, 0, 0], [3, 0, 0]]) * units.m, beam=bm - ) - assert units.isclose( - b.shortest_baseline / b.metres_to_wavelengths, 1 * units.m, rtol=1e-3 - ) - assert units.isclose( - b.longest_baseline / b.metres_to_wavelengths, 3 * units.m, rtol=1e-3 - ) + b = Observatory(antpos=np.array([[0, 0, 0], [1, 0, 0], [3, 0, 0]]) * units.m, beam=bm) + assert units.isclose(b.shortest_baseline / b.metres_to_wavelengths, 1 * units.m, rtol=1e-3) + assert units.isclose(b.longest_baseline / b.metres_to_wavelengths, 3 * units.m, rtol=1e-3) assert b.observation_duration < 1 * units.day assert len(b.get_redundant_baselines()) == 6 # including swapped ones with pytest.raises(AssertionError): @@ -97,16 +88,12 @@ def test_observatory(bm): def test_grid_baselines(bm): - a = Observatory( - antpos=np.random.normal(loc=0, scale=50, size=(20, 3)) * units.m, beam=bm - ) + rng = np.random.default_rng(1234) + a = Observatory(antpos=rng.normal(loc=0, scale=50, size=(20, 3)) * units.m, beam=bm) bl_groups = a.get_redundant_baselines() bl_coords = a.baseline_coords_from_groups(bl_groups) bl_counts = a.baseline_weights_from_groups(bl_groups) - with pytest.raises(ValueError): - a.grid_baselines(bl_coords) - grid0 = a.grid_baselines(coherent=True) grid1 = a.grid_baselines(coherent=True, baselines=bl_coords, weights=bl_counts) assert np.allclose(grid0, grid1) @@ -114,8 +101,7 @@ def test_grid_baselines(bm): def test_min_max_antpos(bm): a = Observatory( - antpos=np.array([np.linspace(0, 50, 11), np.zeros(11), np.zeros(11)]).T - * units.m, + antpos=np.array([np.linspace(0, 50, 11), np.zeros(11), np.zeros(11)]).T * units.m, beam=bm, min_antpos=7 * units.m, ) @@ -123,8 +109,7 @@ def test_min_max_antpos(bm): assert len(a.antpos) == 9 a = Observatory( - antpos=np.array([np.linspace(0, 50, 11), np.zeros(11), np.zeros(11)]).T - * units.m, + antpos=np.array([np.linspace(0, 50, 11), np.zeros(11), np.zeros(11)]).T * units.m, beam=bm, max_antpos=10 * units.m, ) @@ -134,12 +119,8 @@ def test_min_max_antpos(bm): def test_from_uvdata(bm): uv = pyuvdata.UVData() - uv.antenna_positions = ( - np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0], [40, 0, 40]]) * units.m - ) - uv.telescope_location = [ - x.value for x in EarthLocation.from_geodetic(0, 0).to_geocentric() - ] + uv.antenna_positions = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0], [40, 0, 40]]) * units.m + uv.telescope_location = [x.value for x in EarthLocation.from_geodetic(0, 0).to_geocentric()] a = Observatory.from_uvdata(uvdata=uv, beam=bm) assert np.all(a.antpos == uv.antenna_positions) @@ -162,30 +143,24 @@ def test_different_antpos_loaders(tmp_path: Path): value: 14.0 """ - yamlnpy = """ + yamlnpy = f""" antpos: !astropy.units.Quantity unit: !astropy.units.Unit {{unit: m}} - value: !npy {}/antpos.npy - {} - """.format( - tmp_path, - beamtxt, - ) + value: !npy {tmp_path}/antpos.npy + {beamtxt} + """ with open(tmp_path / "npy.yml", "w") as fl: fl.write(yamlnpy) obsnpy = Observatory.from_yaml(tmp_path / "npy.yml") - yamltxt = """ + yamltxt = f""" antpos: !astropy.units.Quantity unit: !astropy.units.Unit {{unit: m}} - value: !txt {}/antpos.txt - {} - """.format( - tmp_path, - beamtxt, - ) + value: !txt {tmp_path}/antpos.txt + {beamtxt} + """ with open(tmp_path / "txt.yml", "w") as fl: fl.write(yamltxt) @@ -195,13 +170,9 @@ def test_different_antpos_loaders(tmp_path: Path): def test_longest_used_baseline(bm): - a = Observatory( - antpos=np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0]]) * units.m, beam=bm - ) + a = Observatory(antpos=np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0]]) * units.m, beam=bm) - assert np.isclose( - a.longest_used_baseline() / a.metres_to_wavelengths, 2 * units.m, atol=1e-3 - ) + assert np.isclose(a.longest_used_baseline() / a.metres_to_wavelengths, 2 * units.m, atol=1e-3) assert np.isclose( a.longest_used_baseline(bl_max=1.5 * units.m) / a.metres_to_wavelengths, 1 * units.m, @@ -210,9 +181,10 @@ def test_longest_used_baseline(bm): def test_from_yaml(bm): + rng = np.random.default_rng(1234) obs = Observatory.from_yaml( { - "antpos": np.random.random((20, 3)) * units.m, + "antpos": rng.random((20, 3)) * units.m, "beam": { "class": "GaussianBeam", "frequency": 150 * units.MHz, @@ -222,21 +194,17 @@ def test_from_yaml(bm): ) assert obs.beam == bm - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="yaml_file must be a string filepath"): Observatory.from_yaml(3) def test_get_redundant_baselines(bm): - a = Observatory( - antpos=np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0]]) * units.m, beam=bm - ) + a = Observatory(antpos=np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0]]) * units.m, beam=bm) reds = a.get_redundant_baselines() assert len(reds) == 4 # len-1, len-2 and backwards - reds = a.get_redundant_baselines( - baseline_filters=BaselineRange(bl_max=1.5 * units.m) - ) + reds = a.get_redundant_baselines(baseline_filters=BaselineRange(bl_max=1.5 * units.m)) assert len(reds) == 2 # len-1 @@ -248,27 +216,25 @@ def test_no_up_coordinate(tmp_path: Path): with open(tmp_path / "mwa_antpos.txt", "w") as fl: np.savetxt(fl, enu[:, :2]) - new_yaml = """ + new_yaml = f""" antpos: !astropy.units.Quantity - value: !txt "%s/mwa_antpos.txt" - unit: !astropy.units.Unit {unit: m} + value: !txt "{tmp_path}/mwa_antpos.txt" + unit: !astropy.units.Unit {{unit: m}} beam: class: GaussianBeam frequency: !astropy.units.Quantity - unit: !astropy.units.Unit {unit: MHz} + unit: !astropy.units.Unit {{unit: MHz}} value: 150 dish_size: !astropy.units.Quantity - unit: !astropy.units.Unit {unit: m} + unit: !astropy.units.Unit {{unit: m}} value: 35 latitude: !astropy.units.Quantity - unit: !astropy.units.Unit {unit: rad} + unit: !astropy.units.Unit {{unit: rad}} value: -0.4681819 Trcv: !astropy.units.Quantity - unit: !astropy.units.Unit {unit: K} + unit: !astropy.units.Unit {{unit: K}} value: 100 -""" % ( - tmp_path - ) +""" with open(tmp_path / "mwa.yaml", "w") as fl: fl.write(new_yaml) diff --git a/tests/test_profiles.py b/tests/test_profiles.py index 1533b4c..dc98204 100644 --- a/tests/test_profiles.py +++ b/tests/test_profiles.py @@ -1,8 +1,8 @@ -import pytest +"""Test the profiles module.""" import astropy.units as un import numpy as np - +import pytest from py21cmsense.observatory import Observatory, get_builtin_profiles diff --git a/tests/test_sensitivity.py b/tests/test_sensitivity.py index d29d615..64bef03 100644 --- a/tests/test_sensitivity.py +++ b/tests/test_sensitivity.py @@ -1,10 +1,11 @@ -import pytest +"""Test the sensitivity module.""" -import numpy as np import warnings + +import numpy as np +import pytest from astropy import units from astropy.cosmology.units import littleh - from py21cmsense import GaussianBeam, Observation, Observatory, PowerSpectrum, theory from py21cmsense.sensitivity import Sensitivity @@ -42,17 +43,14 @@ def test_sensitivity_2d(observation): sense_full = ps.calculate_sensitivity_2d() assert all(np.all(sense_thermal[key] <= sense_full[key]) for key in sense_thermal) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Either thermal or sample must be True"): ps.calculate_sensitivity_2d(thermal=False, sample=False) def test_sensitivity_2d_grid(observation, caplog): ps = PowerSpectrum(observation=observation) sense_ungridded = ps.calculate_sensitivity_2d(thermal=True, sample=True) - kperp = ( - np.array([x.value for x in sense_ungridded.keys()]) - * list(sense_ungridded.keys())[0].unit - ) + kperp = np.array([x.value for x in sense_ungridded]) * next(iter(sense_ungridded.keys())).unit sense = ps.calculate_sensitivity_2d_grid( kperp_edges=np.linspace(kperp.min().value, kperp.max().value, 10) * kperp.unit, kpar_edges=ps.k1d, @@ -62,9 +60,7 @@ def test_sensitivity_2d_grid(observation, caplog): def test_sensitivity_1d_binned(observation): ps = PowerSpectrum(observation=observation) - assert np.all( - ps.calculate_sensitivity_1d() == ps.calculate_sensitivity_1d_binned(ps.k1d) - ) + assert np.all(ps.calculate_sensitivity_1d() == ps.calculate_sensitivity_1d_binned(ps.k1d)) def test_plots(observation): @@ -103,12 +99,13 @@ def test_load_yaml_bad(): ): Sensitivity.from_yaml(1) + rng = np.random.default_rng(1234) with pytest.raises(ImportError, match="Could not import"): PowerSpectrum.from_yaml( { "plugins": ["this.is.not.a.module"], "observatory": { - "antpos": np.random.random((20, 3)) * units.m, + "antpos": rng.random((20, 3)) * units.m, "beam": { "class": "GaussianBeam", "frequency": 150 * units.MHz, @@ -149,14 +146,10 @@ def test_at_freq(observation): ps2 = ps.at_frequency(0.9 * observation.frequency) assert ps2.frequency == 0.9 * observation.frequency - with pytest.warns( - UserWarning, match="Extrapolating above the simulated theoretical" - ): + with pytest.warns(UserWarning, match="Extrapolating above the simulated theoretical"): assert ps.calculate_significance() != ps2.calculate_significance() def test_bad_theory(observation): - with pytest.raises( - ValueError, match="The theory_model must be an instance of TheoryModel" - ): + with pytest.raises(ValueError, match="The theory_model must be an instance of TheoryModel"): PowerSpectrum(observation=observation, theory_model=3) diff --git a/tests/test_theory.py b/tests/test_theory.py index a96664d..56c9626 100644 --- a/tests/test_theory.py +++ b/tests/test_theory.py @@ -1,26 +1,20 @@ -import pytest +"""Test the theory module.""" import numpy as np - +import pytest from py21cmsense.theory import EOS2021, EOS2016Bright, EOS2016Faint, Legacy21cmFAST def test_eos_extrapolation(): eos = EOS2021() - with pytest.warns( - UserWarning, match="Extrapolating above the simulated theoretical k" - ): + with pytest.warns(UserWarning, match="Extrapolating above the simulated theoretical k"): eos.delta_squared(15, np.array([0.1, 1e6])) - with pytest.warns( - UserWarning, match="Extrapolating below the simulated theoretical k" - ): + with pytest.warns(UserWarning, match="Extrapolating below the simulated theoretical k"): eos.delta_squared(15, np.array([0.0001, 0.1])) - with pytest.warns( - UserWarning, match="Extrapolating beyond simulated redshift range" - ): + with pytest.warns(UserWarning, match="Extrapolating beyond simulated redshift range"): eos.delta_squared(50, np.array([0.1])) @@ -31,14 +25,10 @@ def test_legacy(): with pytest.warns(UserWarning, match="Theory power corresponds to z=9.5, not z"): theory.delta_squared(1.0, 1.0) - with pytest.warns( - UserWarning, match="Extrapolating above the simulated theoretical k" - ): + with pytest.warns(UserWarning, match="Extrapolating above the simulated theoretical k"): theory.delta_squared(9.5, np.array([0.1, 1e6])) - with pytest.warns( - UserWarning, match="Extrapolating below the simulated theoretical k" - ): + with pytest.warns(UserWarning, match="Extrapolating below the simulated theoretical k"): theory.delta_squared(9.5, np.array([0.0001, 0.1])) diff --git a/tests/test_types.py b/tests/test_types.py index 856dcf9..0ff9fd1 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,8 +1,8 @@ -import pytest +"""Test the units module.""" import attr +import pytest from astropy import units as u - from py21cmsense import units as tp diff --git a/tests/test_uvw.py b/tests/test_uvw.py index 038b78a..a61b2ac 100644 --- a/tests/test_uvw.py +++ b/tests/test_uvw.py @@ -1,14 +1,12 @@ """Tests of the phasing code for calculating UVWs.""" -import pytest - import numpy as np +import pytest from astropy import units as un from astropy.coordinates import EarthLocation, SkyCoord from astropy.time import Time -from pyuvdata import utils as uvutils - from py21cmsense._utils import phase_past_zenith +from pyuvdata import utils as uvutils @pytest.mark.parametrize("lat", [-1.0, -0.5, 0, 0.5, 1.0]) @@ -68,9 +66,7 @@ def test_phase_past_zenith_shape(): times = np.array([0, 0.1, 0, 0.1]) * un.day # Almost rotated to the horizon. - uvws = phase_past_zenith( - time_past_zenith=times, bls_enu=bls_enu, latitude=0 * un.rad - ) + uvws = phase_past_zenith(time_past_zenith=times, bls_enu=bls_enu, latitude=0 * un.rad) assert uvws.shape == (5, 4, 3) assert np.allclose(uvws[0], uvws[2]) # Same baselines @@ -90,9 +86,7 @@ def test_use_apparent(lat): times = np.linspace(-1, 1, 3) * un.hour # Almost rotated to the horizon. - uvws = phase_past_zenith( - time_past_zenith=times, bls_enu=bls_enu, latitude=lat * un.rad - ) + uvws = phase_past_zenith(time_past_zenith=times, bls_enu=bls_enu, latitude=lat * un.rad) uvws0 = phase_past_zenith( time_past_zenith=times, bls_enu=bls_enu, diff --git a/tests/test_yaml.py b/tests/test_yaml.py index 087c62b..548cfae 100644 --- a/tests/test_yaml.py +++ b/tests/test_yaml.py @@ -1,15 +1,16 @@ -import pytest +"""Test the yaml module.""" -import numpy as np import pickle + +import numpy as np +import pytest from astropy import units as un from astropy.io.misc import yaml - from py21cmsense.yaml import LoadError def test_file_not_found(): - with pytest.raises(IOError): + with pytest.raises(IOError, match="not found"): yaml.load("!txt non-existent-file.txt")