Skip to content

Commit

Permalink
wip - first working version of major refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeitsperre committed Aug 20, 2024
1 parent 8d14fa4 commit e4609a0
Show file tree
Hide file tree
Showing 11 changed files with 223 additions and 461 deletions.
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ dependencies = [
"packaging >=24.0",
"pandas >=2.2",
"pint >=0.18",
"platformdirs >=3.2",
"pooch >=1.8.0",
"pyarrow >=15.0.0", # Strongly encouraged for pandas v2.2.0+
"pyyaml >=6.0.1",
"scikit-learn >=0.21.3",
Expand Down Expand Up @@ -79,8 +79,6 @@ dev = [
"nbval >=0.11.0",
"pandas-stubs >=2.2",
"pip >=24.0",
"platformdirs >=3.2",
"pooch >=1.8.0",
"pre-commit >=3.7",
"pylint >=3.2.4",
"pytest >=8.0.0",
Expand Down
117 changes: 54 additions & 63 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from xclim.core import indicator
from xclim.core.calendar import max_doy
from xclim.testing import helpers
from xclim.testing.helpers import default_cache_dir # noqa
from xclim.testing.helpers import nimbus as _nimbus
from xclim.testing.helpers import open_dataset as _open_dataset
from xclim.testing.helpers import test_timeseries
from xclim.testing.utils import default_cache_dir # noqa
from xclim.testing.utils import open_dataset as _open_dataset


@pytest.fixture
Expand All @@ -26,21 +26,7 @@ def random() -> np.random.Generator:

@pytest.fixture
def tmp_netcdf_filename(tmpdir):
yield Path(tmpdir).joinpath("testfile.nc")


@pytest.fixture(autouse=True, scope="session")
def threadsafe_data_dir(tmp_path_factory):
    """Yield the ``data`` folder located under pytest's base temporary directory."""
    base_temp = tmp_path_factory.getbasetemp()
    yield Path(base_temp.joinpath("data"))


@pytest.fixture(autouse=True, scope="session")
def nimbus(threadsafe_data_dir):
    """Yield the object built by ``helpers.nimbus`` for the session data directory.

    The repository URL and branch are taken from ``helpers.TESTDATA_REPO_URL``
    and ``helpers.TESTDATA_BRANCH``.
    """
    fetcher = _nimbus(
        data_dir=threadsafe_data_dir,
        repo=helpers.TESTDATA_REPO_URL,
        branch=helpers.TESTDATA_BRANCH,
    )
    yield fetcher
return Path(tmpdir).joinpath("testfile.nc")


@pytest.fixture
Expand All @@ -57,6 +43,11 @@ def _lat_series(values):
return _lat_series


@pytest.fixture
def timeseries():
    """Expose the ``test_timeseries`` helper to tests as a fixture."""
    return test_timeseries


@pytest.fixture
def tas_series():
"""Return mean temperature time series."""
Expand Down Expand Up @@ -309,40 +300,30 @@ def rlus_series():


@pytest.fixture(scope="session")
def cmip3_day_tas(threadsafe_data_dir):
    """Yield the ``tas`` variable of the cmip3 test file, then close the dataset."""
    # xr.set_options(enable_cftimeindex=False)
    dataset = _open_dataset(
        "cmip3/tas.sresb1.giss_model_e_r.run1.atm.da.nc",
        cache_dir=threadsafe_data_dir,
        branch=helpers.TESTDATA_BRANCH,
        engine="h5netcdf",
    )
    tas = dataset.tas
    yield tas
    # Runs when the fixture is torn down.
    dataset.close()
def threadsafe_data_dir(tmp_path_factory):
    """Return the ``data`` folder located under pytest's base temporary directory."""
    base_temp = tmp_path_factory.getbasetemp()
    return Path(base_temp.joinpath("data"))


@pytest.fixture(scope="session")
def get_file(nimbus):
    """Return a callable that fetches a testing data file through ``nimbus``.

    Parameters
    ----------
    nimbus
        Object exposing a ``fetch(file)`` method.

    Returns
    -------
    callable
        Function taking a file name and returning whatever ``nimbus.fetch``
        returns for that file.
    """

    def _get_session_scoped_file(file: str):
        # Hand the fetch result back to the caller: tests use it as the path
        # to read, and discarding it left them with ``None``.
        return nimbus.fetch(file)

    return _get_session_scoped_file
def nimbus(threadsafe_data_dir):
    """Return the object built by ``helpers.nimbus`` for the given data directory.

    The repository URL and branch are taken from ``helpers.TESTDATA_REPO_URL``
    and ``helpers.TESTDATA_BRANCH``.
    """
    settings = {
        "data_dir": threadsafe_data_dir,
        "repo": helpers.TESTDATA_REPO_URL,
        "branch": helpers.TESTDATA_BRANCH,
    }
    return _nimbus(**settings)


@pytest.fixture(scope="session")
def open_dataset(threadsafe_data_dir):
def _open_session_scoped_file(
file: str | os.PathLike, branch: str = helpers.TESTDATA_BRANCH, **xr_kwargs
):
def open_dataset(nimbus):
def _open_session_scoped_file(file: str | os.PathLike, **xr_kwargs):
xr_kwargs.setdefault("cache", True)
xr_kwargs.setdefault("engine", "h5netcdf")
return _open_dataset(
file, cache_dir=threadsafe_data_dir, branch=branch, **xr_kwargs
)
return _open_dataset(file, cache_dir=nimbus.path, **xr_kwargs)

return _open_session_scoped_file


@pytest.fixture
@pytest.fixture(scope="session")
def official_indicators():
# Remove unofficial indicators (such as those created during the tests, and those from YAML-built modules)
registry_cp = indicator.registry.copy()
Expand All @@ -352,17 +333,39 @@ def official_indicators():
return registry_cp


@pytest.fixture(scope="function")
def atmosds(threadsafe_data_dir) -> xr.Dataset:
@pytest.fixture
def lafferty_sriver_ds(nimbus) -> xr.Dataset:
    """Get data from Lafferty & Sriver unit test.

    Notes
    -----
    https://github.com/david0811/lafferty-sriver_2023_npjCliAtm/tree/main/unit_test
    """
    csv_path = nimbus.fetch("uncertainty_partitioning/seattle_avg_tas.csv")

    # Rename the CSV columns to the names used for the index levels below.
    renames = {"ssp": "scenario", "ensemble": "downscaling"}
    table = pd.read_csv(csv_path, parse_dates=["time"]).rename(columns=renames)

    # Build the xarray dataset from the multi-indexed table.
    indexed = table.set_index(["scenario", "model", "downscaling", "time"])
    return xr.Dataset.from_dataframe(indexed)


@pytest.fixture
def atmosds(nimbus) -> xr.Dataset:
"""Get synthetic atmospheric dataset."""
return _open_dataset(
threadsafe_data_dir.joinpath("atmosds.nc"),
cache_dir=threadsafe_data_dir,
branch=helpers.TESTDATA_BRANCH,
"atmosds.nc",
cache_dir=nimbus.path,
engine="h5netcdf",
).load()


@pytest.fixture(scope="function")
@pytest.fixture(scope="session")
def ensemble_dataset_objects() -> dict[str, str]:
edo = dict()
edo["nc_files_simple"] = [
Expand All @@ -378,8 +381,8 @@ def ensemble_dataset_objects() -> dict[str, str]:
return edo


@pytest.fixture(scope="session", autouse=True)
def gather_session_data(threadsafe_data_dir, worker_id):
@pytest.fixture(autouse=True, scope="session")
def gather_session_data(request, nimbus, worker_id):
"""Gather testing data on pytest run.
When running pytest with multiple workers, one worker will copy data remotely to _default_cache_dir while
Expand All @@ -389,25 +392,13 @@ def gather_session_data(threadsafe_data_dir, worker_id):
Additionally, this fixture is also used to generate the `atmosds` synthetic testing dataset.
"""
helpers.testing_setup_warnings()
helpers.gather_testing_data(threadsafe_data_dir, worker_id)
helpers.generate_atmos(threadsafe_data_dir)


@pytest.fixture(scope="session", autouse=True)
def cleanup(request):
"""Cleanup a testing file once we are finished.
This flag prevents remote data from being downloaded multiple times in the same pytest run.
"""
helpers.gather_testing_data(nimbus.path, worker_id)
helpers.generate_atmos(nimbus.path)

def remove_data_written_flag():
"""Cleanup cache folder once we are finished."""
flag = default_cache_dir.joinpath(".data_written")
if flag.exists():
flag.unlink()

request.addfinalizer(remove_data_written_flag)


@pytest.fixture
def timeseries():
    """Expose the ``test_timeseries`` helper to tests as a fixture."""
    return test_timeseries
10 changes: 5 additions & 5 deletions tests/test_analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def test_exact_randn(exact_randn):
@pytest.mark.slow
@pytest.mark.parametrize("method", xca.metrics.keys())
def test_spatial_analogs(method, open_dataset):
diss = open_dataset("SpatialAnalogs/dissimilarity")
data = open_dataset("SpatialAnalogs/indicators")
diss = open_dataset("SpatialAnalogs/dissimilarity.nc")
data = open_dataset("SpatialAnalogs/indicators.nc")

target = data.sel(lat=46.1875, lon=-72.1875, time=slice("1970", "1990"))
candidates = data.sel(time=slice("1970", "1990"))
Expand All @@ -75,7 +75,7 @@ def test_spatial_analogs(method, open_dataset):
def test_unsupported_spatial_analog_method(open_dataset):
method = "KonMari"

data = open_dataset("SpatialAnalogs/indicators")
data = open_dataset("SpatialAnalogs/indicators.nc")
target = data.sel(lat=46.1875, lon=-72.1875, time=slice("1970", "1990"))
candidates = data.sel(time=slice("1970", "1990"))

Expand All @@ -87,8 +87,8 @@ def test_unsupported_spatial_analog_method(open_dataset):

def test_spatial_analogs_multi_index(open_dataset):
# Test multi-indexes
diss = open_dataset("SpatialAnalogs/dissimilarity")
data = open_dataset("SpatialAnalogs/indicators")
diss = open_dataset("SpatialAnalogs/dissimilarity.nc")
data = open_dataset("SpatialAnalogs/indicators.nc")

target = data.sel(lat=46.1875, lon=-72.1875, time=slice("1970", "1990"))
candidates = data.sel(time=slice("1970", "1990"))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_atmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_humidex(tas_series):


def test_heat_index(atmosds):
# Keep just Montreal values for summertime as we need tas > 20 degC
# Keep just Montreal values for summer as we need tas > 20 degC
tas = atmosds.tasmax[1][150:170]
hurs = atmosds.hurs[1][150:170]

Expand Down
8 changes: 5 additions & 3 deletions tests/test_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -2562,12 +2562,14 @@ def test_simple(self, open_dataset, ind, exp):
out = ind(ds.tas.sel(location="Victoria"))
np.testing.assert_almost_equal(out[0], exp, decimal=4)

def test_indice_against_icclim(self, cmip3_day_tas):
def test_indice_against_icclim(self, open_dataset):
from xclim.indicators import icclim # noqa

cmip3_tas = open_dataset("cmip3/tas.sresb1.giss_model_e_r.run1.atm.da.nc").tas

with set_options(cf_compliance="log"):
ind = xci.tg_mean(cmip3_day_tas)
icclim = icclim.TG(cmip3_day_tas)
ind = xci.tg_mean(cmip3_tas)
icclim = icclim.TG(cmip3_tas)

np.testing.assert_array_equal(icclim, ind)

Expand Down
16 changes: 2 additions & 14 deletions tests/test_partitioning.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import numpy as np
import pandas as pd
import xarray as xr

from xclim.ensembles import fractional_uncertainty, hawkins_sutton, lafferty_sriver
Expand Down Expand Up @@ -108,19 +107,8 @@ def test_lafferty_sriver_synthetic(random):
lafferty_sriver(da, sm=sm)


def test_lafferty_sriver(get_file):
seattle = get_file("uncertainty_partitioning/seattle_avg_tas.csv")

df = pd.read_csv(seattle, parse_dates=["time"]).rename(
columns={"ssp": "scenario", "ensemble": "downscaling"}
)

# Make xarray dataset
ds = xr.Dataset.from_dataframe(
df.set_index(["scenario", "model", "downscaling", "time"])
)

_g, u = lafferty_sriver(ds.tas)
def test_lafferty_sriver(lafferty_sriver_ds):
_g, u = lafferty_sriver(lafferty_sriver_ds.tas)

fu = fractional_uncertainty(u)

Expand Down
68 changes: 3 additions & 65 deletions tests/test_testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import platform
import sys
from pathlib import Path
from urllib.error import URLError

import numpy as np
import pytest
from xarray import Dataset

import xclim.testing.utils as utilities
from xclim import __version__ as __xclim_version__
from xclim.testing import helpers
from xclim.testing import utils as utilities
from xclim.testing.helpers import test_timeseries as timeseries


Expand Down Expand Up @@ -39,52 +39,9 @@ def file_md5_checksum(f_name):
hash_md5.update(f.read())
return hash_md5.hexdigest()

@pytest.mark.requires_internet
def test_get_failure(self, tmp_path):
    """Check that ``utilities._get`` raises ``FileNotFoundError`` for a bad repository address."""
    bad_repo_address = "https://github.com/beard/of/zeus/"
    missing_file = Path("san_diego", "60_percent_of_the_time_it_works_everytime")
    with pytest.raises(FileNotFoundError):
        utilities._get(missing_file, bad_repo_address, "main", tmp_path)

@pytest.mark.requires_internet
def test_open_dataset_with_bad_file(self, tmp_path):
    """Check that ``utilities._get`` warns about, and replaces, a bad cached file and its md5."""
    cmip3_file = "tas.sresb1.giss_model_e_r.run1.atm.da.nc"
    cmip3_md5 = f"{cmip3_file}.md5"
    bad_cmip3_md5 = "bc51206e6462fc8ed08fd4926181274c"

    # Seed the cache folder with a bogus data file and a bogus checksum file.
    cmip3_folder = tmp_path.joinpath("main", "cmip3")
    cmip3_folder.mkdir(parents=True)
    Path(cmip3_folder, cmip3_file).write_text("This file definitely isn't right.")
    Path(cmip3_folder, cmip3_md5).write_text(bad_cmip3_md5)

    # Check for raised warning for local file md5 sum and remote md5 sum
    with pytest.warns(UserWarning):
        new_cmip3_file = utilities._get(
            Path("cmip3", cmip3_file),
            github_url="https://github.com/Ouranosinc/xclim-testdata",
            branch="main",
            cache_dir=tmp_path,
        )

    fetched_path = Path(cmip3_folder, new_cmip3_file)

    # Ensure that the new cmip3 file is in the cache directory
    assert self.file_md5_checksum(fetched_path) != bad_cmip3_md5

    # Ensure that the md5 file was updated at the same time
    stored_md5 = Path(cmip3_folder, cmip3_md5).read_text()
    assert self.file_md5_checksum(fetched_path) == stored_md5

@pytest.mark.requires_internet
def test_open_testdata(self):
ds = utilities.open_dataset(
ds = helpers.open_dataset(
Path("cmip5/tas_Amon_CanESM2_rcp85_r1i1p1_200701-200712"), engine="h5netcdf"
)
assert ds.lon.size == 128
Expand Down Expand Up @@ -126,22 +83,3 @@ def test_release_notes_file_not_implemented(self, tmp_path):
temp_filename = tmp_path.joinpath("version_info.txt")
with pytest.raises(NotImplementedError):
utilities.publish_release_notes(style="qq", file=temp_filename)


class TestTestingFileAccessors:
    """Error-path tests for the URL arguments of ``utilities.open_dataset``."""

    def test_unsafe_urls(self):
        """An ``ftp://`` ``github_url`` is rejected with a ``ValueError``."""
        expected = "GitHub URL not secure: 'ftp://domain.does.not.exist/'."
        with pytest.raises(ValueError, match=expected):
            utilities.open_dataset(
                "doesnt_exist.nc", github_url="ftp://domain.does.not.exist/"
            )

    def test_malicious_urls(self):
        """A malformed ``dap_url`` is rejected with a ``URLError``."""
        expected = "urlopen error OPeNDAP URL is not well-formed: 'doesnt_exist.nc'"
        with pytest.raises(URLError, match=expected):
            utilities.open_dataset(
                "doesnt_exist.nc", dap_url="Robert'); DROP TABLE STUDENTS; --"
            )
Loading

0 comments on commit e4609a0

Please sign in to comment.