Skip to content

Commit

Permalink
Refactor to parse version spec
Browse files Browse the repository at this point in the history
Simplify one more time down to a single version comparison helper that
parses a package spec string and compares that spec to the currently
installed package.
  • Loading branch information
dcamron committed Nov 10, 2023
1 parent e7deffb commit b0c9471
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 36 deletions.
56 changes: 33 additions & 23 deletions src/metpy/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import contextlib
import functools
import importlib
import re

import numpy as np
import numpy.testing
Expand All @@ -24,40 +25,49 @@
from .units import units


def module_version_before(modname, ver):
"""Return whether the active module is before a certain version.
def module_version_check(version_spec):
"""Return comparison between the active module and a requested version number.
Parameters
----------
modname : str
The module name to import
ver : str
The version string for a certain release
version_spec : str
Module version specification to validate against installed package. Must take the form
of `f'{module_name}{comparison_operator}{version_number}'` where `comparison_operator`
must be one of `['==', '=', '!=', '<', '<=', '>', '>=']`.
Returns
-------
bool : whether the current version was released before the passed in one
bool : Whether the installed package validates against the provided specification
"""
module = importlib.import_module(modname)
return Version(module.__version__) < Version(ver)
comparison_operators = {
'==': lambda x, y: x == y, '=': lambda x, y: x == y, '!=': lambda x, y: x != y,
'<': lambda x, y: x < y, '<=': lambda x, y: x <= y,
'>': lambda x, y: x > y, '>=': lambda x, y: x >= y,
}

# Match version_spec for groups of module name,
# comparison operator, and requested module version
pattern = re.compile(r'(\w+)\s*([<>!=]+)\s*([\d.]+)')
match = pattern.match(version_spec)

def module_version_equal(modname, ver):
"""Return whether the active module is equal to a certain version.
if match:
module_name = match.group(1)
comparison = match.group(2)
version_number = match.group(3)
else:
raise ValueError('No valid version specification string matched.')

Parameters
----------
modname : str
The module name to import
ver : str
The version string for a certain release
module = importlib.import_module(module_name)

Returns
-------
bool : whether the current version is equal to the passed in one
"""
module = importlib.import_module(modname)
return Version(module.__version__) == Version(ver)
installed_version = Version(module.__version__)
specified_version = Version(version_number)

try:
return comparison_operators[comparison](installed_version, specified_version)
except KeyError:
raise ValueError(
"Comparison operator not one of ['==', '=', '!=', '<', '<=', '>', '>=']."
) from None


def needs_module(module):
Expand Down
4 changes: 2 additions & 2 deletions tests/calc/test_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
mean_pressure_weighted, precipitable_water, significant_tornado,
supercell_composite, weighted_continuous_average)
from metpy.testing import (assert_almost_equal, assert_array_almost_equal, get_upper_air_data,
module_version_before)
module_version_check)
from metpy.units import concatenate, units


Expand Down Expand Up @@ -131,7 +131,7 @@ def test_weighted_continuous_average():
assert_almost_equal(v, 6.900543760612305 * units('m/s'), 7)


@pytest.mark.xfail(condition=module_version_before('pint', '0.21'), reason='hgrecco/pint#1593')
@pytest.mark.xfail(condition=module_version_check('pint<0.21'), reason='hgrecco/pint#1593')
def test_weighted_continuous_average_temperature():
"""Test pressure-weighted mean temperature function with vertical interpolation."""
data = get_upper_air_data(datetime(2016, 5, 22, 0), 'DDC')
Expand Down
6 changes: 3 additions & 3 deletions tests/plots/test_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from metpy.io.metar import parse_metar_file
from metpy.plots import (ArrowPlot, BarbPlot, ContourPlot, FilledContourPlot, ImagePlot,
MapPanel, PanelContainer, PlotGeometry, PlotObs, RasterPlot)
from metpy.testing import module_version_before, needs_cartopy
from metpy.testing import module_version_check, needs_cartopy
from metpy.units import units


Expand Down Expand Up @@ -336,7 +336,7 @@ def test_declarative_contour_cam():

@pytest.mark.mpl_image_compare(
remove_text=True,
tolerance=3.71 if module_version_before('matplotlib', '3.8') else 0.74)
tolerance=3.71 if module_version_check('matplotlib<3.8') else 0.74)
@needs_cartopy
def test_declarative_contour_options():
"""Test making a contour plot."""
Expand Down Expand Up @@ -431,7 +431,7 @@ def test_declarative_additional_layers_plot_options():

@pytest.mark.mpl_image_compare(
remove_text=True,
tolerance=2.74 if module_version_before('matplotlib', '3.8') else 1.91)
tolerance=2.74 if module_version_check('matplotlib<3.8') else 1.91)
@needs_cartopy
def test_declarative_contour_convert_units():
"""Test making a contour plot."""
Expand Down
8 changes: 4 additions & 4 deletions tests/plots/test_skewt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest

from metpy.plots import Hodograph, SkewT
from metpy.testing import module_version_equal
from metpy.testing import module_version_check
from metpy.units import units


Expand Down Expand Up @@ -156,9 +156,9 @@ def test_skewt_units():

# On Matplotlib <= 3.6, ax[hv]line() doesn't trigger unit labels
assert skew.ax.get_xlabel() == (
'degree_Celsius' if module_version_equal('matplotlib', '3.7.0') else '')
'degree_Celsius' if module_version_check('matplotlib==3.7.0') else '')
assert skew.ax.get_ylabel() == (
'hectopascal' if module_version_equal('matplotlib', '3.7.0') else '')
'hectopascal' if module_version_check('matplotlib==3.7.0') else '')

# Clear them for the image test
skew.ax.set_xlabel('')
Expand Down Expand Up @@ -321,7 +321,7 @@ def test_hodograph_api():


@pytest.mark.mpl_image_compare(
remove_text=True, tolerance=0.6 if module_version_equal('matplotlib', '3.5') else 0.)
remove_text=True, tolerance=0.6 if module_version_check('matplotlib==3.5') else 0.)
def test_hodograph_units():
"""Test passing quantities to Hodograph."""
fig = plt.figure(figsize=(9, 9))
Expand Down
6 changes: 3 additions & 3 deletions tests/plots/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import xarray as xr

from metpy.plots import add_metpy_logo, add_timestamp, add_unidata_logo, convert_gempak_color
from metpy.testing import get_test_data, module_version_before
from metpy.testing import get_test_data, module_version_check


@pytest.mark.mpl_image_compare(tolerance=2.638, remove_text=True)
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_add_logo_invalid_size():


@pytest.mark.mpl_image_compare(
tolerance=1.072 if module_version_before('matplotlib', '3.5') else 0,
tolerance=1.072 if module_version_check('matplotlib<3.5') else 0,
remove_text=True)
def test_gempak_color_image_compare():
"""Test creating a plot with all the GEMPAK colors."""
Expand All @@ -113,7 +113,7 @@ def test_gempak_color_image_compare():


@pytest.mark.mpl_image_compare(
tolerance=1.215 if module_version_before('matplotlib', '3.5') else 0,
tolerance=1.215 if module_version_check('matplotlib<3.5') else 0,
remove_text=True)
def test_gempak_color_xw_image_compare():
"""Test creating a plot with all the GEMPAK colors using xw style."""
Expand Down
21 changes: 20 additions & 1 deletion tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from metpy.deprecation import MetpyDeprecationWarning
from metpy.testing import (assert_array_almost_equal, check_and_drop_units,
check_and_silence_deprecation)
check_and_silence_deprecation, module_version_check)


# Test #1183: numpy.testing.assert_array* ignores any masked value, so work-around
Expand Down Expand Up @@ -42,3 +42,22 @@ def test_check_and_drop_units_with_dataarray():
assert isinstance(actual, np.ndarray)
assert isinstance(desired, np.ndarray)
np.testing.assert_array_almost_equal(actual, desired)


def test_module_version_check():
"""Test parsing and version comparison of installed package."""
assert module_version_check('numpy>0.0.0')
assert module_version_check('numpy >= 0.0')
assert module_version_check('numpy!=0')


def test_module_version_check_nonsense():
"""Test failed pattern match of package specification."""
with pytest.raises(ValueError, match='No valid version '):
module_version_check('thousands of birds picking packages')


def test_module_version_check_invalid_comparison():
"""Test invalid operator in version comparison."""
with pytest.raises(ValueError, match='Comparison operator not '):
module_version_check('numpy<<36')

0 comments on commit b0c9471

Please sign in to comment.