Skip to content

Commit

Permalink
Eliminate inconsistent usage of DEFAULT_MAH_PARAMS throughout rest of…
Browse files Browse the repository at this point in the history
… package
  • Loading branch information
aphearin committed Jan 15, 2024
1 parent eeb79ed commit ba66137
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 47 deletions.
2 changes: 2 additions & 0 deletions diffmah/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@
DEFAULT_MAH_PDICT = OrderedDict(logmp=12.0, logtc=0.05, early_index=2.5, late_index=1.0)
DiffmahParams = namedtuple("DiffmahParams", list(DEFAULT_MAH_PDICT.keys()))
DEFAULT_MAH_PARAMS = DiffmahParams(*list(DEFAULT_MAH_PDICT.values()))

MAH_K = 3.5
23 changes: 15 additions & 8 deletions diffmah/fit_mah_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
from jax import jit as jjit
from jax import numpy as jnp
from jax import value_and_grad
from .individual_halo_assembly import DEFAULT_MAH_PARAMS
from .individual_halo_assembly import _u_rolling_plaw_vs_logt, _get_early_late

from .defaults import DEFAULT_MAH_PARAMS, MAH_K
from .individual_halo_assembly import (
_get_early_late,
_get_ue_ul,
_u_rolling_plaw_vs_logt,
)

T_FIT_MIN = 1.0
DLOGM_CUT = 2.5
Expand All @@ -14,10 +19,9 @@ def get_outline(halo_id, loss_data, p_best, loss_best):
"""Return the string storing fitting results that will be written to disk"""
logtc, ue, ul = p_best
logt0, u_k, logm0 = loss_data[-3:]
t0 = 10 ** logt0
t0 = 10**logt0

Check warning on line 22 in diffmah/fit_mah_helpers.py

View check run for this annotation

Codecov / codecov/patch

diffmah/fit_mah_helpers.py#L22

Added line #L22 was not covered by tests
early, late = _get_early_late(ue, ul)
fixed_k = DEFAULT_MAH_PARAMS["mah_k"]
_d = np.array((logm0, logtc, fixed_k, early, late)).astype("f4")
_d = np.array((logm0, logtc, MAH_K, early, late)).astype("f4")

Check warning on line 24 in diffmah/fit_mah_helpers.py

View check run for this annotation

Codecov / codecov/patch

diffmah/fit_mah_helpers.py#L24

Added line #L24 was not covered by tests
data_out = (halo_id, *_d, t0, float(loss_best))
out = str(halo_id) + " " + " ".join(["{:.5e}".format(x) for x in data_out[1:]])
return out + "\n"
Expand Down Expand Up @@ -105,12 +109,15 @@ def get_loss_data(
t_fit_min,
)
logmp_init = log_mah_sim[-1]
lgtc_init, fixed_k, ue_init, ud_init = list(DEFAULT_MAH_PARAMS.values())
p_init = np.array((lgtc_init, ue_init, ud_init)).astype("f4")
lgtc_init = DEFAULT_MAH_PARAMS.logtc
ue_init, ul_init = _get_ue_ul(

Check warning on line 113 in diffmah/fit_mah_helpers.py

View check run for this annotation

Codecov / codecov/patch

diffmah/fit_mah_helpers.py#L112-L113

Added lines #L112 - L113 were not covered by tests
DEFAULT_MAH_PARAMS.early_index, DEFAULT_MAH_PARAMS.late_index
)
p_init = np.array((lgtc_init, ue_init, ul_init)).astype("f4")

Check warning on line 116 in diffmah/fit_mah_helpers.py

View check run for this annotation

Codecov / codecov/patch

diffmah/fit_mah_helpers.py#L116

Added line #L116 was not covered by tests

logt0 = np.log10(t_sim[-1])

loss_data = (logt_target, log_mah_target, logt0, fixed_k, logmp_init)
loss_data = (logt_target, log_mah_target, logt0, MAH_K, logmp_init)

Check warning on line 120 in diffmah/fit_mah_helpers.py

View check run for this annotation

Codecov / codecov/patch

diffmah/fit_mah_helpers.py#L120

Added line #L120 was not covered by tests
return p_init, loss_data


Expand Down
11 changes: 4 additions & 7 deletions diffmah/halo_population_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
from jax import vmap
from jax.scipy.stats import multivariate_normal as jnorm

from .individual_halo_assembly import (
DEFAULT_MAH_PARAMS,
_calc_halo_history,
_get_early_late,
)
from .defaults import MAH_K
from .individual_halo_assembly import _calc_halo_history, _get_early_late
from .rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS, LGT0, _get_mah_means_and_covs

CLIP = -10.0
Expand Down Expand Up @@ -51,7 +48,7 @@ def _get_bimodal_halo_history_kern(
mu_late,
cov_early,
cov_late,
k=DEFAULT_MAH_PARAMS["mah_k"],
k=MAH_K,
logtmp=LGT0,
):
early_arr, late_arr = _get_early_late(ue_arr, ul_arr)
Expand Down Expand Up @@ -133,7 +130,7 @@ def _get_bimodal_halo_history(
chol_ue_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ue_lgtc_late_yhi"],
chol_ul_lgtc_late_ylo=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_ylo"],
chol_ul_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_yhi"],
k=DEFAULT_MAH_PARAMS["mah_k"],
k=MAH_K,
logtmp=LGT0,
):
_res = _get_mah_means_and_covs(
Expand Down
39 changes: 35 additions & 4 deletions diffmah/individual_halo_assembly.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,45 @@
"""Model for individual halo mass assembly based on a power-law with rolling index."""
from collections import OrderedDict

from jax import grad
from jax import jit as jjit
from jax import lax
from jax import numpy as jnp
from jax import vmap

from .defaults import LGT0, MAH_K
from .utils import get_1d_arrays

DEFAULT_MAH_PARAMS = OrderedDict(mah_logtc=0.05, mah_k=3.5, mah_ue=2.4, mah_ul=-2.0)

@jjit
def mah_singlehalo(mah_params, tarr, lgt0=LGT0):
lgtarr = jnp.log10(tarr)
dmhdt, log_mah = _calc_halo_history(

Check warning on line 15 in diffmah/individual_halo_assembly.py

View check run for this annotation

Codecov / codecov/patch

diffmah/individual_halo_assembly.py#L14-L15

Added lines #L14 - L15 were not covered by tests
lgtarr,
lgt0,
mah_params.logmp,
mah_params.logtc,
MAH_K,
mah_params.early_index,
mah_params.late_index,
)
return dmhdt, log_mah

Check warning on line 24 in diffmah/individual_halo_assembly.py

View check run for this annotation

Codecov / codecov/patch

diffmah/individual_halo_assembly.py#L24

Added line #L24 was not covered by tests


@jjit
def mah_halopop(mah_params, tarr, lgt0=LGT0):
lgtarr = jnp.log10(tarr)
dmhdt, log_mah = _calc_halopop_history(

Check warning on line 30 in diffmah/individual_halo_assembly.py

View check run for this annotation

Codecov / codecov/patch

diffmah/individual_halo_assembly.py#L29-L30

Added lines #L29 - L30 were not covered by tests
lgtarr,
lgt0,
mah_params.logmp,
mah_params.logtc,
MAH_K,
mah_params.early_index,
mah_params.late_index,
)
return dmhdt, log_mah

Check warning on line 39 in diffmah/individual_halo_assembly.py

View check run for this annotation

Codecov / codecov/patch

diffmah/individual_halo_assembly.py#L39

Added line #L39 was not covered by tests


def calc_halo_history(t, t0, logmp, tauc, early, late, k=DEFAULT_MAH_PARAMS["mah_k"]):
def calc_halo_history(t, t0, logmp, tauc, early, late, k=MAH_K):
"""Calculate individual halo assembly histories.
Parameters
Expand Down Expand Up @@ -101,6 +128,10 @@ def _calc_halo_history(logt, logt0, logmp, logtc, k, early, late):
return dmhdt, log_mah


_YO = (None, None, 0, 0, None, 0, 0)
_calc_halopop_history = jjit(vmap(_calc_halo_history, in_axes=_YO))


@jjit
def _softplus(x):
return jnp.log(1 + lax.exp(x))
Expand Down
9 changes: 2 additions & 7 deletions diffmah/monte_carlo_halo_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,10 @@
from jax import random as jran
from jax import vmap

from .individual_halo_assembly import (
DEFAULT_MAH_PARAMS,
_calc_halo_history,
_get_early_late,
)
from .defaults import MAH_K
from .individual_halo_assembly import _calc_halo_history, _get_early_late
from .rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS, _get_mah_means_and_covs

MAH_K = DEFAULT_MAH_PARAMS["mah_k"]

_A = (None, None, 0, 0, None, 0, 0)
_calc_halo_history_vmap = jjit(vmap(_calc_halo_history, in_axes=_A))

Expand Down
5 changes: 2 additions & 3 deletions diffmah/rockstar_pdf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from jax import numpy as jnp
from jax import vmap

from .individual_halo_assembly import DEFAULT_MAH_PARAMS
from .defaults import MAH_K
from .utils import get_cholesky_from_params

TODAY = 13.8
LGT0 = jnp.log10(TODAY)
K = DEFAULT_MAH_PARAMS["mah_k"]

_LGM_X0, LGM_K = 13.0, 0.5

Expand Down Expand Up @@ -352,7 +351,7 @@ def _get_mah_means_and_covs(
chol_ue_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ue_lgtc_late_yhi"],
chol_ul_lgtc_late_ylo=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_ylo"],
chol_ul_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_yhi"],
k=DEFAULT_MAH_PARAMS["mah_k"],
k=MAH_K,
logtmp=LGT0,
):
frac_late = frac_late_forming(logmp_arr, frac_late_ylo, frac_late_yhi)
Expand Down
27 changes: 12 additions & 15 deletions diffmah/tests/test_individual_halo_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import numpy as np
from jax import numpy as jnp

from ..defaults import DEFAULT_MAH_PARAMS, MAH_K
from ..individual_halo_assembly import (
DEFAULT_MAH_PARAMS,
_calc_halo_history,
_calc_halo_history_scalar,
_get_early_late,
Expand All @@ -22,16 +22,15 @@ def test_calc_halo_history_evaluates():
tarr = np.linspace(0.1, 14, 500)
logt = np.log10(tarr)
logtmp = logt[-1]
logmp = 12.0
lgtc, k, ue, ul = list(DEFAULT_MAH_PARAMS.values())
early, late = _get_early_late(ue, ul)
dmhdt, log_mah = _calc_halo_history(logt, logtmp, logmp, lgtc, k, early, late)
lgmp, lgtc, early, late = DEFAULT_MAH_PARAMS
dmhdt, log_mah = _calc_halo_history(logt, logtmp, lgmp, lgtc, MAH_K, early, late)
assert np.all(np.isfinite(dmhdt))
assert np.all(np.isfinite(log_mah))


def test_rolling_index_agrees_with_hard_coded_expectation():
lgmp_test = 12.5

k = DEFAULT_MAH_PARAMS["mah_k"]
logt_bn = "logt_testing_array.dat"
logt = np.loadtxt(os.path.join(DDRN, logt_bn))
logt0 = logt[-1]
Expand All @@ -44,19 +43,19 @@ def test_rolling_index_agrees_with_hard_coded_expectation():

indx_e_bn = "rolling_plaw_index_vs_time_rockstar_default_logmp_{0:.1f}_early.dat"
index_early_correct = np.loadtxt(os.path.join(DDRN, indx_e_bn.format(lgmp_test)))
index_early = _power_law_index_vs_logt(logt, lgtc_e, k, early_e, late_e)
index_early = _power_law_index_vs_logt(logt, lgtc_e, MAH_K, early_e, late_e)
assert np.allclose(index_early_correct, index_early, rtol=0.01)

indx_l_bn = "rolling_plaw_index_vs_time_rockstar_default_logmp_{0:.1f}_late.dat"
index_late_correct = np.loadtxt(os.path.join(DDRN, indx_l_bn.format(lgmp_test)))
index_late = _power_law_index_vs_logt(logt, lgtc_l, k, early_l, late_l)
index_late = _power_law_index_vs_logt(logt, lgtc_l, MAH_K, early_l, late_l)
assert np.allclose(index_late_correct, index_late, rtol=0.01)

dmhdt_e, log_mah_e = _calc_halo_history(
logt, logt0, lgmp_test, lgtc_e, k, early_e, late_e
logt, logt0, lgmp_test, lgtc_e, MAH_K, early_e, late_e
)
dmhdt_l, log_mah_l = _calc_halo_history(
logt, logt0, lgmp_test, lgtc_l, k, early_l, late_l
logt, logt0, lgmp_test, lgtc_l, MAH_K, early_l, late_l
)

log_mah_e_bn = "log_mah_vs_time_rockstar_default_logmp_{0:.1f}_early.dat"
Expand All @@ -72,14 +71,12 @@ def test_calc_halo_history_scalar_agrees_with_vmap():
tarr = np.linspace(0.1, 14, 15)
logt = np.log10(tarr)
logtmp = logt[-1]
logmp = 12.0
lgtc, k, ue, ul = list(DEFAULT_MAH_PARAMS.values())
early, late = _get_early_late(ue, ul)
dmhdt, log_mah = _calc_halo_history(logt, logtmp, logmp, lgtc, k, early, late)
lgmp, lgtc, early, late = DEFAULT_MAH_PARAMS
dmhdt, log_mah = _calc_halo_history(logt, logtmp, lgmp, lgtc, MAH_K, early, late)

for i, t in enumerate(tarr):
lgt_i = jnp.log10(t)
res = _calc_halo_history_scalar(lgt_i, logtmp, logmp, lgtc, k, early, late)
res = _calc_halo_history_scalar(lgt_i, logtmp, lgmp, lgtc, MAH_K, early, late)
dmhdt_i, log_mah_i = res
assert np.allclose(dmhdt[i], dmhdt_i)
assert np.allclose(log_mah[i], log_mah_i)
5 changes: 2 additions & 3 deletions diffmah/tng_pdf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
from jax import numpy as jnp
from jax import vmap

from .individual_halo_assembly import DEFAULT_MAH_PARAMS
from .defaults import MAH_K
from .utils import get_cholesky_from_params

TODAY = 13.8
LGT0 = jnp.log10(TODAY)
K = DEFAULT_MAH_PARAMS["mah_k"]

_LGM_X0, LGM_K = 13.0, 0.5

Expand Down Expand Up @@ -353,7 +352,7 @@ def _get_mah_means_and_covs(
chol_ue_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ue_lgtc_late_yhi"],
chol_ul_lgtc_late_ylo=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_ylo"],
chol_ul_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_yhi"],
k=DEFAULT_MAH_PARAMS["mah_k"],
k=MAH_K,
logtmp=LGT0,
):
frac_late = frac_late_forming(logmp_arr, frac_late_ylo, frac_late_yhi)
Expand Down

0 comments on commit ba66137

Please sign in to comment.