Skip to content

Commit

Permalink
Merge branch 'master' into v4-prep
Browse files Browse the repository at this point in the history
Conflicts:
	src/py21cmfast/src/ps.c
  • Loading branch information
daviesje committed Sep 11, 2024
2 parents 4938f5c + 16db064 commit 44b85e3
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 17 deletions.
11 changes: 5 additions & 6 deletions src/py21cmfast/cache_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,11 @@ def readbox(
cls = getattr(outputs, kind)

if hasattr(cls, "from_file"):
inst = cls.from_file(fname, direc=direc, load_data=load_data)
inst = cls.from_file(
fname, direc=direc, load_data=load_data
) # for OutputStruct
else:
inst = cls.read(fname, direc=direc)
inst = cls.read(fname, direc=direc) # for HighlevelOutputStruct

return inst

Expand Down Expand Up @@ -236,12 +238,9 @@ def clear_cache(**kwargs):
kwargs :
All options passed through to :func:`query_cache`.
"""
if "show" not in kwargs:
kwargs["show"] = False

direc = kwargs.get("direc", path.expanduser(config["direc"]))
number = 0
for fname, _ in query_cache(**kwargs):
for fname in list_datasets(**kwargs):
if kwargs.get("show", True):
logger.info(f"Removing {fname}")
os.remove(path.join(direc, fname))
Expand Down
6 changes: 5 additions & 1 deletion src/py21cmfast/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,11 @@ def use(self, **kwargs):

for k, val in kwargs.items():
if k.upper() not in this_attr_upper:
raise ValueError(f"{k} is not a valid parameter of global_params")
warnings.warn(
f"{k} is not a valid parameter of global_params, and will be ignored",
UserWarning,
)
continue
key = this_attr_upper[k.upper()]
prev[key] = getattr(self, key)
setattr(self, key, val)
Expand Down
63 changes: 55 additions & 8 deletions src/py21cmfast/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,10 +1129,33 @@ def save(self, fname=None, direc=".", clobber: bool = False):
return self._write(direc=direc, fname=fname, clobber=clobber)

@classmethod
def _read_inputs(cls, fname):
def _read_inputs(cls, fname, safe=True):
kwargs = {}
with h5py.File(fname, "r") as fl:
glbls = dict(fl["_globals"].attrs)
if glbls.keys() != global_params.keys():
missing_items = [
(k, v) for k, v in global_params.items() if k not in glbls.keys()
]
extra_items = [
(k, v) for k, v in glbls.items() if k not in global_params.keys()
]
message = (
f"There are extra or missing global params in the file to be read.\n"
f"EXTRAS: {extra_items}\n"
f"MISSING: {missing_items}\n"
)
# we don't save None values (we probably should) or paths so ignore these
# We also only print the warning for these fields if "safe" is turned off
if safe and any(
v is not None for k, v in missing_items if "path" not in k
):
raise ValueError(message)
else:
warnings.warn(
message
+ "\nExtras are ignored and missing are set to default (shown) values"
)
kwargs["redshift"] = float(fl.attrs["redshift"])

if "photon_nonconservation_data" in fl.keys():
Expand All @@ -1142,8 +1165,8 @@ def _read_inputs(cls, fname):
return kwargs, glbls

@classmethod
def read(cls, fname, direc="."):
"""Read a lightcone file from disk, creating a LightCone object.
def read(cls, fname, direc=".", safe=True):
"""Read the HighLevelOutput file from disk, creating a LightCone or Coeval object.
Parameters
----------
Expand All @@ -1153,6 +1176,10 @@ def read(cls, fname, direc="."):
If fname, is relative, the directory in which to find the file. By default,
both the current directory and default cache and the will be searched, in
that order.
safe : bool
If safe is true, we throw an error if the parameter structures in the file do not
match the structures in the `inputs.py` module. If false, we allow extra and missing
items, setting the missing items to the default values and ignoring extra items.
Returns
-------
Expand All @@ -1165,7 +1192,7 @@ def read(cls, fname, direc="."):
if not os.path.exists(fname):
raise FileExistsError(f"The file {fname} does not exist!")

park, glbls = cls._read_inputs(fname)
park, glbls = cls._read_inputs(fname, safe=safe)
boxk = cls._read_particular(fname)

with global_params.use(**glbls):
Expand Down Expand Up @@ -1480,7 +1507,7 @@ def make_checkpoint(self, fname, index: int, redshift: float):
self._current_index = index

@classmethod
def _read_inputs(cls, fname):
def _read_inputs(cls, fname, safe=True):
kwargs = {}
with h5py.File(fname, "r") as fl:
for k, kls in [
Expand All @@ -1489,14 +1516,34 @@ def _read_inputs(cls, fname):
("flag_options", FlagOptions),
("astro_params", AstroParams),
]:
grp = fl[k]
kwargs[k] = kls(dict(grp.attrs))
dct = dict(fl[k].attrs)
if kls._defaults_.keys() != dct.keys():
missing_items = [
(k, v) for k, v in kls._defaults_.items() if k not in dct.keys()
]
extra_items = [
(k, v) for k, v in dct.items() if k not in kls._defaults_.keys()
]
message = (
f"There are extra or missing {kls} in the file to be read.\n"
f"EXTRAS: {extra_items}\n"
f"MISSING: {missing_items}\n"
)
if safe and any(v is not None for k, v in missing_items):
raise ValueError(message)
else:
warnings.warn(
message
+ "\nExtras are ignored and missing are set to default (shown) values."
+ "\nUsing these parameter structures in further computation will give inconsistent results."
)
kwargs[k] = kls(dct)
kwargs["random_seed"] = int(fl.attrs["random_seed"])
kwargs["current_redshift"] = fl.attrs.get("current_redshift", None)
kwargs["current_index"] = fl.attrs.get("current_index", None)

# Get the standard inputs.
kw, glbls = _HighLevelOutput._read_inputs(fname)
kw, glbls = _HighLevelOutput._read_inputs(fname, safe=safe)
return {**kw, **kwargs}, glbls

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions src/py21cmfast/src/cosmology.c
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ double TF_CLASS(double k, int flag_int, int flag_dv)
if (k > kclass[CLASS_LENGTH-1]) { // k>kmax
LOG_WARNING("Called TF_CLASS with k=%f, larger than kmax! Returning value at kmax.", k);
if(flag_dv == 0){ // output is density
return (Tmclass[CLASS_LENGTH]/kclass[CLASS_LENGTH-1]/kclass[CLASS_LENGTH-1]);
return (Tmclass[CLASS_LENGTH-1]/kclass[CLASS_LENGTH-1]/kclass[CLASS_LENGTH-1]);
}
else if(flag_dv == 1){ // output is rel velocity
return (Tvclass_vcb[CLASS_LENGTH]/kclass[CLASS_LENGTH-1]/kclass[CLASS_LENGTH-1]);
return (Tvclass_vcb[CLASS_LENGTH-1]/kclass[CLASS_LENGTH-1]/kclass[CLASS_LENGTH-1]);
} //we just set it to the last value, since sometimes it wants large k for R<<cell_size, which does not matter much.
else{
LOG_ERROR("Invalid flag_dv %d passed to TF_CLASS",flag_dv);
Expand Down
86 changes: 86 additions & 0 deletions tests/test_high_level_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
InitialConditions,
LightCone,
TsBox,
UserParams,
global_params,
run_coeval,
run_lightcone,
Expand Down Expand Up @@ -56,6 +57,91 @@ def ang_lightcone(ic, lc):
)


def test_read_bad_file_lc(test_direc, lc):
    """Reading a lightcone file with extra/missing attrs: strict mode raises, lax mode warns."""
    # Save a good lightcone file, then corrupt it in place: inject unknown
    # attributes and delete known ones in both a parameter struct and globals.
    fname = lc.save(direc=test_direc)
    with h5py.File(fname, "r+") as f:
        # Unknown attributes — these should be dropped on read.
        f["user_params"].attrs["NotARealParameter"] = "fake_param"
        f["_globals"].attrs["NotARealGlobal"] = "fake_param"

        # Missing attributes — these should fall back to defaults (lax mode only).
        del f["user_params"].attrs["BOX_LEN"]
        del f["_globals"].attrs["OPTIMIZE_MIN_MASS"]

    # Strict mode must refuse the mismatched file with the right error.
    with pytest.raises(ValueError, match="There are extra or missing"):
        LightCone.read(fname, direc=test_direc, safe=True)

    # Compatibility (lax) mode reads the file but emits a warning.
    with pytest.warns(UserWarning, match="There are extra or missing"):
        loaded = LightCone.read(fname, direc=test_direc, safe=False)

    # The injected fake fields must not survive into the loaded structs.
    assert not hasattr(loaded.user_params, "NotARealParameter")
    assert "NotARealGlobal" not in loaded.global_params.keys()

    # The deleted fields must be restored to their default values.
    assert loaded.user_params.BOX_LEN == UserParams._defaults_["BOX_LEN"]
    assert loaded.global_params["OPTIMIZE_MIN_MASS"] == global_params.OPTIMIZE_MIN_MASS

    # Every field we did not touch must round-trip unchanged.
    untouched_params = (k for k in UserParams._defaults_.keys() if k != "BOX_LEN")
    assert all(
        getattr(loaded.user_params, k) == getattr(lc.user_params, k)
        for k in untouched_params
    )
    untouched_globals = (k for k in global_params.keys() if k != "OPTIMIZE_MIN_MASS")
    assert all(
        loaded.global_params[k] == lc.global_params[k] for k in untouched_globals
    )


def test_read_bad_file_coev(test_direc, coeval):
    """Reading a coeval file with extra/missing attrs: strict mode raises, lax mode warns."""
    # Save a good coeval file, then corrupt it in place: inject unknown
    # attributes and delete known ones in both a parameter struct and globals.
    fname = coeval.save(direc=test_direc)
    with h5py.File(fname, "r+") as f:
        # Unknown attributes — these should be dropped on read.
        f["user_params"].attrs["NotARealParameter"] = "fake_param"
        f["_globals"].attrs["NotARealGlobal"] = "fake_param"

        # Missing attributes — these should fall back to defaults (lax mode only).
        del f["user_params"].attrs["BOX_LEN"]
        del f["_globals"].attrs["OPTIMIZE_MIN_MASS"]

    # Strict mode must refuse the mismatched file with the right error.
    with pytest.raises(ValueError, match="There are extra or missing"):
        Coeval.read(fname, direc=test_direc, safe=True)

    # Compatibility (lax) mode reads the file but emits a warning.
    with pytest.warns(UserWarning, match="There are extra or missing"):
        loaded = Coeval.read(fname, direc=test_direc, safe=False)

    # The injected fake fields must not survive into the loaded structs.
    assert not hasattr(loaded.user_params, "NotARealParameter")
    assert "NotARealGlobal" not in loaded.global_params.keys()

    # The deleted fields must be restored to their default values.
    assert loaded.user_params.BOX_LEN == UserParams._defaults_["BOX_LEN"]
    assert loaded.global_params["OPTIMIZE_MIN_MASS"] == global_params.OPTIMIZE_MIN_MASS

    # Every field we did not touch must round-trip unchanged.
    untouched_params = (k for k in UserParams._defaults_.keys() if k != "BOX_LEN")
    assert all(
        getattr(loaded.user_params, k) == getattr(coeval.user_params, k)
        for k in untouched_params
    )
    untouched_globals = (k for k in global_params.keys() if k != "OPTIMIZE_MIN_MASS")
    assert all(
        loaded.global_params[k] == coeval.global_params[k] for k in untouched_globals
    )


def test_lightcone_roundtrip(test_direc, lc):
fname = lc.save(direc=test_direc)
lc2 = LightCone.read(fname)
Expand Down

0 comments on commit 44b85e3

Please sign in to comment.