Skip to content

Commit

Permalink
start importing some things from test_grdata.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mabruzzo committed Sep 22, 2024
1 parent 8f78ac6 commit 99a770e
Showing 1 changed file with 17 additions and 57 deletions.
74 changes: 17 additions & 57 deletions src/python/tests/test_auto_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import io
import os
import shutil
import sys

import numpy as np
import pytest
Expand All @@ -30,63 +29,22 @@
from pygrackle.utilities.physical_constants import sec_per_Myr
from pygrackle.utilities.testing import assert_allequal_arraydict, ensure_dir

from test_grdata import (
_ENV_VARS, # holds list of environment variables that affect data dir location
modified_env,
)
from test_query_units import _setup_generic_chemistry_data


# we probably don't have to skip everything
if not hasattr(os, "putenv"):
pytest.skip(
"several tests need os.putenv to work properly", allow_module_level=True
)

# _ENV_VAR holds the list of environment variables that could affect the
# location of the data directory
if sys.platform.startswith("darwin"):
_ENV_VARS = ("HOME", "GRACKLE_DATA_DIR")
else:
_ENV_VARS = ("HOME", "GRACKLE_DATA_DIR", "XDG_DATA_HOME")


def _ensure_removed(d, key):
try:
del d[key]
except KeyError:
pass


@contextlib.contextmanager
def modified_env(new_env_vals, extra_cleared_variables=None):
"""
Temporarily overwrite the environment variables. This is necessary to test C
extensions that rely upon the environment variables
"""
if extra_cleared_variables is None:
extra_cleared_variables = None

# record the original values for any variable we will overwrite
original_vals = {}
try:
for var in filter(lambda e: e not in new_env_vals, extra_cleared_variables):
original_vals[var] = os.environ.get(var, None)
_ensure_removed(os.environ, var)

for var, new_val in new_env_vals.items():
original_vals[var] = os.environ.get(var, None)
if new_val is None:
_ensure_removed(os.environ, var)
else:
os.environ[var] = new_val

yield

finally:
# restore to the initial values
for var, val in original_vals.items():
if val is None:
_ensure_removed(os.environ, var)
else:
os.environ[var] = val


# it would be nice to replace the following with test_grdata.CLIApp, but that would
# definitely take some work
class DataFileManagementHarness:
"""
This is a wrapper around the cli interface provided by pygrackle.
Expand Down Expand Up @@ -247,6 +205,7 @@ def managed_datafile(request, tmp_path):
) as full_path:
yield full_path


def setup_generic_problem(parameter_overrides={}):
"""set up a really simplistic problem"""
chem = _setup_generic_chemistry_data(
Expand All @@ -267,9 +226,11 @@ def setup_generic_problem(parameter_overrides={}):

@pytest.mark.parametrize(
"managed_datafile",
([pytest.param(None, id = "default-datadir")] +
[pytest.param(var, id=f"arbitrary-{var}") for var in _ENV_VARS]),
indirect=True
(
[pytest.param(None, id="default-datadir")]
+ [pytest.param(var, id=f"arbitrary-{var}") for var in _ENV_VARS]
),
indirect=True,
)
def test_autofile_equivalence(managed_datafile):
"""
Expand Down Expand Up @@ -304,7 +265,7 @@ def test_autofile_equivalence(managed_datafile):
fc_other, _ = setup_generic_problem(
parameter_overrides={
"grackle_data_file": fname,
"grackle_data_file_options": constants.GR_DFOPT_MANAGED
"grackle_data_file_options": constants.GR_DFOPT_MANAGED,
}
)
fc_other.solve_chemistry(dt)
Expand All @@ -319,7 +280,7 @@ def test_autofile_fail_unknown_file():
skip_initialize=True,
parameter_overrides={
"grackle_data_file": "not-a-file.png",
"grackle_data_file_options": constants.GR_DFOPT_MANAGED
"grackle_data_file_options": constants.GR_DFOPT_MANAGED,
},
)
assert chem.initialize() == constants.GR_FAIL
Expand All @@ -345,7 +306,7 @@ def test_autofile_fail_known_missing_file(tmp_path):
skip_initialize=True,
parameter_overrides={
"grackle_data_file": alt_fname,
"grackle_data_file_options": constants.GR_DFOPT_MANAGED
"grackle_data_file_options": constants.GR_DFOPT_MANAGED,
},
)
assert chem.initialize() == constants.GR_FAIL
Expand Down Expand Up @@ -374,8 +335,7 @@ def test_autofile_fail_bad_checksum(tmp_path):
skip_initialize=True,
parameter_overrides={
"grackle_data_file": alt_fname,
"grackle_data_file_options": constants.GR_DFOPT_MANAGED
"grackle_data_file_options": constants.GR_DFOPT_MANAGED,
},
)
assert chem.initialize() == constants.GR_FAIL

0 comments on commit 99a770e

Please sign in to comment.