Skip to content

Commit

Permalink
Merge pull request #113 from ArgonneCPAC/defaults_migration
Browse files Browse the repository at this point in the history
Eliminate inconsistent usage of DEFAULT_MAH_PARAMS
  • Loading branch information
aphearin authored Jan 15, 2024
2 parents eeb79ed + b4cfe0d commit 88401c2
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 49 deletions.
49 changes: 49 additions & 0 deletions .github/workflows/linting.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
name: linting

on:
push:
branches:
- main
pull_request: null

jobs:
tests:
name: tests
runs-on: "ubuntu-latest"

steps:
- uses: actions/checkout@v2

- uses: conda-incubator/setup-miniconda@v2
with:
python-version: 3.9
channels: conda-forge,defaults
channel-priority: strict
show-channel-urls: true
miniforge-version: latest
miniforge-variant: Mambaforge

- name: configure conda and install code
# Test against latest releases of each code in the dependency chain
shell: bash -l {0}
run: |
conda config --set always_yes yes
mamba install --quiet \
--file=requirements.txt
python -m pip install --no-deps -e .
mamba install -y -q \
flake8 \
pytest \
pytest-xdist \
pytest-cov \
pip \
setuptools \
"setuptools_scm>=7,<8" \
python-build \
flake8-pyproject
python -m pip install --no-build-isolation --no-deps -e .
- name: lint
shell: bash -l {0}
run: |
flake8 diffmah
4 changes: 4 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
name: tests

on:
workflow_dispatch: null
schedule:
# Runs "every Monday & Thursday at 3:05am Central"
- cron: '5 8 * * 1,4'
push:
branches:
- main
Expand Down
3 changes: 2 additions & 1 deletion diffmah/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# flake8: noqa

from ._version import __version__
from .individual_halo_assembly import calc_halo_history
from .defaults import DEFAULT_MAH_PARAMS, MAH_K, DiffmahParams
from .individual_halo_assembly import calc_halo_history, mah_halopop, mah_singlehalo
from .monte_carlo_diffmah_hiz import mc_diffmah_params_hiz
from .monte_carlo_halo_population import mc_halo_population
6 changes: 5 additions & 1 deletion diffmah/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
LGT0 = np.log10(TODAY)


DEFAULT_MAH_PDICT = OrderedDict(logmp=12.0, logtc=0.05, early_index=2.5, late_index=1.0)
DEFAULT_MAH_PDICT = OrderedDict(
logmp=12.0, logtc=0.05, early_index=2.6137643, late_index=0.12692805
)
DiffmahParams = namedtuple("DiffmahParams", list(DEFAULT_MAH_PDICT.keys()))
DEFAULT_MAH_PARAMS = DiffmahParams(*list(DEFAULT_MAH_PDICT.values()))

MAH_K = 3.5
23 changes: 15 additions & 8 deletions diffmah/fit_mah_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
from jax import jit as jjit
from jax import numpy as jnp
from jax import value_and_grad
from .individual_halo_assembly import DEFAULT_MAH_PARAMS
from .individual_halo_assembly import _u_rolling_plaw_vs_logt, _get_early_late

from .defaults import DEFAULT_MAH_PARAMS, MAH_K
from .individual_halo_assembly import (
_get_early_late,
_get_ue_ul,
_u_rolling_plaw_vs_logt,
)

T_FIT_MIN = 1.0
DLOGM_CUT = 2.5
Expand All @@ -14,10 +19,9 @@ def get_outline(halo_id, loss_data, p_best, loss_best):
"""Return the string storing fitting results that will be written to disk"""
logtc, ue, ul = p_best
logt0, u_k, logm0 = loss_data[-3:]
t0 = 10 ** logt0
t0 = 10**logt0
early, late = _get_early_late(ue, ul)
fixed_k = DEFAULT_MAH_PARAMS["mah_k"]
_d = np.array((logm0, logtc, fixed_k, early, late)).astype("f4")
_d = np.array((logm0, logtc, MAH_K, early, late)).astype("f4")
data_out = (halo_id, *_d, t0, float(loss_best))
out = str(halo_id) + " " + " ".join(["{:.5e}".format(x) for x in data_out[1:]])
return out + "\n"
Expand Down Expand Up @@ -105,12 +109,15 @@ def get_loss_data(
t_fit_min,
)
logmp_init = log_mah_sim[-1]
lgtc_init, fixed_k, ue_init, ud_init = list(DEFAULT_MAH_PARAMS.values())
p_init = np.array((lgtc_init, ue_init, ud_init)).astype("f4")
lgtc_init = DEFAULT_MAH_PARAMS.logtc
ue_init, ul_init = _get_ue_ul(
DEFAULT_MAH_PARAMS.early_index, DEFAULT_MAH_PARAMS.late_index
)
p_init = np.array((lgtc_init, ue_init, ul_init)).astype("f4")

logt0 = np.log10(t_sim[-1])

loss_data = (logt_target, log_mah_target, logt0, fixed_k, logmp_init)
loss_data = (logt_target, log_mah_target, logt0, MAH_K, logmp_init)
return p_init, loss_data


Expand Down
11 changes: 4 additions & 7 deletions diffmah/halo_population_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
from jax import vmap
from jax.scipy.stats import multivariate_normal as jnorm

from .individual_halo_assembly import (
DEFAULT_MAH_PARAMS,
_calc_halo_history,
_get_early_late,
)
from .defaults import MAH_K
from .individual_halo_assembly import _calc_halo_history, _get_early_late
from .rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS, LGT0, _get_mah_means_and_covs

CLIP = -10.0
Expand Down Expand Up @@ -51,7 +48,7 @@ def _get_bimodal_halo_history_kern(
mu_late,
cov_early,
cov_late,
k=DEFAULT_MAH_PARAMS["mah_k"],
k=MAH_K,
logtmp=LGT0,
):
early_arr, late_arr = _get_early_late(ue_arr, ul_arr)
Expand Down Expand Up @@ -133,7 +130,7 @@ def _get_bimodal_halo_history(
chol_ue_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ue_lgtc_late_yhi"],
chol_ul_lgtc_late_ylo=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_ylo"],
chol_ul_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_yhi"],
k=DEFAULT_MAH_PARAMS["mah_k"],
k=MAH_K,
logtmp=LGT0,
):
_res = _get_mah_means_and_covs(
Expand Down
39 changes: 35 additions & 4 deletions diffmah/individual_halo_assembly.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,45 @@
"""Model for individual halo mass assembly based on a power-law with rolling index."""
from collections import OrderedDict

from jax import grad
from jax import jit as jjit
from jax import lax
from jax import numpy as jnp
from jax import vmap

from .defaults import LGT0, MAH_K
from .utils import get_1d_arrays

DEFAULT_MAH_PARAMS = OrderedDict(mah_logtc=0.05, mah_k=3.5, mah_ue=2.4, mah_ul=-2.0)

@jjit
def mah_singlehalo(mah_params, tarr, lgt0=LGT0):
lgtarr = jnp.log10(tarr)
dmhdt, log_mah = _calc_halo_history(
lgtarr,
lgt0,
mah_params.logmp,
mah_params.logtc,
MAH_K,
mah_params.early_index,
mah_params.late_index,
)
return dmhdt, log_mah


@jjit
def mah_halopop(mah_params, tarr, lgt0=LGT0):
lgtarr = jnp.log10(tarr)
dmhdt, log_mah = _calc_halopop_history(
lgtarr,
lgt0,
mah_params.logmp,
mah_params.logtc,
MAH_K,
mah_params.early_index,
mah_params.late_index,
)
return dmhdt, log_mah


def calc_halo_history(t, t0, logmp, tauc, early, late, k=DEFAULT_MAH_PARAMS["mah_k"]):
def calc_halo_history(t, t0, logmp, tauc, early, late, k=MAH_K):
"""Calculate individual halo assembly histories.
Parameters
Expand Down Expand Up @@ -101,6 +128,10 @@ def _calc_halo_history(logt, logt0, logmp, logtc, k, early, late):
return dmhdt, log_mah


_YO = (None, None, 0, 0, None, 0, 0)
_calc_halopop_history = jjit(vmap(_calc_halo_history, in_axes=_YO))


@jjit
def _softplus(x):
return jnp.log(1 + lax.exp(x))
Expand Down
9 changes: 2 additions & 7 deletions diffmah/monte_carlo_halo_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,10 @@
from jax import random as jran
from jax import vmap

from .individual_halo_assembly import (
DEFAULT_MAH_PARAMS,
_calc_halo_history,
_get_early_late,
)
from .defaults import MAH_K
from .individual_halo_assembly import _calc_halo_history, _get_early_late
from .rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS, _get_mah_means_and_covs

MAH_K = DEFAULT_MAH_PARAMS["mah_k"]

_A = (None, None, 0, 0, None, 0, 0)
_calc_halo_history_vmap = jjit(vmap(_calc_halo_history, in_axes=_A))

Expand Down
5 changes: 2 additions & 3 deletions diffmah/rockstar_pdf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from jax import numpy as jnp
from jax import vmap

from .individual_halo_assembly import DEFAULT_MAH_PARAMS
from .defaults import MAH_K
from .utils import get_cholesky_from_params

TODAY = 13.8
LGT0 = jnp.log10(TODAY)
K = DEFAULT_MAH_PARAMS["mah_k"]

_LGM_X0, LGM_K = 13.0, 0.5

Expand Down Expand Up @@ -352,7 +351,7 @@ def _get_mah_means_and_covs(
chol_ue_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ue_lgtc_late_yhi"],
chol_ul_lgtc_late_ylo=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_ylo"],
chol_ul_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_yhi"],
k=DEFAULT_MAH_PARAMS["mah_k"],
k=MAH_K,
logtmp=LGT0,
):
frac_late = frac_late_forming(logmp_arr, frac_late_ylo, frac_late_yhi)
Expand Down
26 changes: 26 additions & 0 deletions diffmah/tests/test_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
"""
import numpy as np


def test_default_mah_params_imports_from_top_level_and_is_frozen():
from .. import DEFAULT_MAH_PARAMS

assert np.allclose(DEFAULT_MAH_PARAMS, (12.0, 0.05, 2.6137643, 0.12692805))


def test_mah_k_imports_from_top_level():
from .. import MAH_K

assert np.allclose(MAH_K, 3.5)


def test_mah_halopop_imports_from_top_level():
from .. import DEFAULT_MAH_PARAMS, DiffmahParams, mah_halopop

tarr = np.linspace(0.1, 13.7, 100)
ngals = 150
zz = np.zeros(ngals)
mah_params_halopop = DiffmahParams(*[x + zz for x in DEFAULT_MAH_PARAMS])
dmhdt, log_mah = mah_halopop(mah_params_halopop, tarr)
assert log_mah.shape == dmhdt.shape
51 changes: 36 additions & 15 deletions diffmah/tests/test_individual_halo_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import numpy as np
from jax import numpy as jnp

from ..defaults import DEFAULT_MAH_PARAMS, MAH_K, DiffmahParams
from ..individual_halo_assembly import (
DEFAULT_MAH_PARAMS,
_calc_halo_history,
_calc_halo_history_scalar,
_get_early_late,
_power_law_index_vs_logt,
mah_halopop,
mah_singlehalo,
)
from ..rockstar_pdf_model import _get_mean_mah_params_early, _get_mean_mah_params_late

Expand All @@ -22,16 +24,15 @@ def test_calc_halo_history_evaluates():
tarr = np.linspace(0.1, 14, 500)
logt = np.log10(tarr)
logtmp = logt[-1]
logmp = 12.0
lgtc, k, ue, ul = list(DEFAULT_MAH_PARAMS.values())
early, late = _get_early_late(ue, ul)
dmhdt, log_mah = _calc_halo_history(logt, logtmp, logmp, lgtc, k, early, late)
lgmp, lgtc, early, late = DEFAULT_MAH_PARAMS
dmhdt, log_mah = _calc_halo_history(logt, logtmp, lgmp, lgtc, MAH_K, early, late)
assert np.all(np.isfinite(dmhdt))
assert np.all(np.isfinite(log_mah))


def test_rolling_index_agrees_with_hard_coded_expectation():
lgmp_test = 12.5

k = DEFAULT_MAH_PARAMS["mah_k"]
logt_bn = "logt_testing_array.dat"
logt = np.loadtxt(os.path.join(DDRN, logt_bn))
logt0 = logt[-1]
Expand All @@ -44,19 +45,19 @@ def test_rolling_index_agrees_with_hard_coded_expectation():

indx_e_bn = "rolling_plaw_index_vs_time_rockstar_default_logmp_{0:.1f}_early.dat"
index_early_correct = np.loadtxt(os.path.join(DDRN, indx_e_bn.format(lgmp_test)))
index_early = _power_law_index_vs_logt(logt, lgtc_e, k, early_e, late_e)
index_early = _power_law_index_vs_logt(logt, lgtc_e, MAH_K, early_e, late_e)
assert np.allclose(index_early_correct, index_early, rtol=0.01)

indx_l_bn = "rolling_plaw_index_vs_time_rockstar_default_logmp_{0:.1f}_late.dat"
index_late_correct = np.loadtxt(os.path.join(DDRN, indx_l_bn.format(lgmp_test)))
index_late = _power_law_index_vs_logt(logt, lgtc_l, k, early_l, late_l)
index_late = _power_law_index_vs_logt(logt, lgtc_l, MAH_K, early_l, late_l)
assert np.allclose(index_late_correct, index_late, rtol=0.01)

dmhdt_e, log_mah_e = _calc_halo_history(
logt, logt0, lgmp_test, lgtc_e, k, early_e, late_e
logt, logt0, lgmp_test, lgtc_e, MAH_K, early_e, late_e
)
dmhdt_l, log_mah_l = _calc_halo_history(
logt, logt0, lgmp_test, lgtc_l, k, early_l, late_l
logt, logt0, lgmp_test, lgtc_l, MAH_K, early_l, late_l
)

log_mah_e_bn = "log_mah_vs_time_rockstar_default_logmp_{0:.1f}_early.dat"
Expand All @@ -72,14 +73,34 @@ def test_calc_halo_history_scalar_agrees_with_vmap():
tarr = np.linspace(0.1, 14, 15)
logt = np.log10(tarr)
logtmp = logt[-1]
logmp = 12.0
lgtc, k, ue, ul = list(DEFAULT_MAH_PARAMS.values())
early, late = _get_early_late(ue, ul)
dmhdt, log_mah = _calc_halo_history(logt, logtmp, logmp, lgtc, k, early, late)
lgmp, lgtc, early, late = DEFAULT_MAH_PARAMS
dmhdt, log_mah = _calc_halo_history(logt, logtmp, lgmp, lgtc, MAH_K, early, late)

for i, t in enumerate(tarr):
lgt_i = jnp.log10(t)
res = _calc_halo_history_scalar(lgt_i, logtmp, logmp, lgtc, k, early, late)
res = _calc_halo_history_scalar(lgt_i, logtmp, lgmp, lgtc, MAH_K, early, late)
dmhdt_i, log_mah_i = res
assert np.allclose(dmhdt[i], dmhdt_i)
assert np.allclose(log_mah[i], log_mah_i)


def test_mah_singlehalo_evaluates():
nt = 100
tarr = np.linspace(0.1, 13.8, nt)
dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, tarr)
assert dmhdt.shape == tarr.shape
assert log_mah.shape == dmhdt.shape
assert log_mah[-1] == DEFAULT_MAH_PARAMS.logmp


def test_mah_halopop_evaluates():
nt = 100
tarr = np.linspace(0.1, 13.8, nt)

ngals = 150
zz = np.zeros(ngals)
mah_params_halopop = DiffmahParams(*[zz + p for p in DEFAULT_MAH_PARAMS])
dmhdt, log_mah = mah_halopop(mah_params_halopop, tarr)
assert dmhdt.shape == (ngals, nt)
assert log_mah.shape == dmhdt.shape
assert np.allclose(log_mah[:, -1], DEFAULT_MAH_PARAMS.logmp)
Loading

0 comments on commit 88401c2

Please sign in to comment.