diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml new file mode 100644 index 0000000..1946039 --- /dev/null +++ b/.github/workflows/linting.yml @@ -0,0 +1,49 @@ +name: linting + +on: + push: + branches: + - main + pull_request: null + +jobs: + tests: + name: tests + runs-on: "ubuntu-latest" + + steps: + - uses: actions/checkout@v2 + + - uses: conda-incubator/setup-miniconda@v2 + with: + python-version: 3.9 + channels: conda-forge,defaults + channel-priority: strict + show-channel-urls: true + miniforge-version: latest + miniforge-variant: Mambaforge + + - name: configure conda and install code + # Test against latest releases of each code in the dependency chain + shell: bash -l {0} + run: | + conda config --set always_yes yes + mamba install --quiet \ + --file=requirements.txt + python -m pip install --no-deps -e . + mamba install -y -q \ + flake8 \ + pytest \ + pytest-xdist \ + pytest-cov \ + pip \ + setuptools \ + "setuptools_scm>=7,<8" \ + python-build \ + flake8-pyproject + python -m pip install --no-build-isolation --no-deps -e . + + - name: lint + shell: bash -l {0} + run: | + flake8 diffmah diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fcf7cee..c77d210 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,6 +1,10 @@ name: tests on: + workflow_dispatch: null + schedule: + # Runs "every Monday & Thursday at 3:05am Central" + - cron: '5 8 * * 1,4' push: branches: - main diff --git a/diffmah/__init__.py b/diffmah/__init__.py index dbb92e2..c91e4bb 100644 --- a/diffmah/__init__.py +++ b/diffmah/__init__.py @@ -3,6 +3,7 @@ # flake8: noqa from ._version import __version__ -from .individual_halo_assembly import calc_halo_history +from .defaults import DEFAULT_MAH_PARAMS, MAH_K, DiffmahParams +from .individual_halo_assembly import calc_halo_history, mah_halopop, mah_singlehalo from .monte_carlo_diffmah_hiz import mc_diffmah_params_hiz from .monte_carlo_halo_population import mc_halo_population diff --git a/diffmah/defaults.py b/diffmah/defaults.py index cf68dff..03d2297 100644 --- a/diffmah/defaults.py +++ b/diffmah/defaults.py @@ -9,6 +9,10 @@ LGT0 = np.log10(TODAY) -DEFAULT_MAH_PDICT = OrderedDict(logmp=12.0, logtc=0.05, early_index=2.5, late_index=1.0) +DEFAULT_MAH_PDICT = OrderedDict( + logmp=12.0, logtc=0.05, early_index=2.6137643, late_index=0.12692805 +) DiffmahParams = namedtuple("DiffmahParams", list(DEFAULT_MAH_PDICT.keys())) DEFAULT_MAH_PARAMS = DiffmahParams(*list(DEFAULT_MAH_PDICT.values())) + +MAH_K = 3.5 diff --git a/diffmah/fit_mah_helpers.py b/diffmah/fit_mah_helpers.py index 5e30d90..3960154 100644 --- a/diffmah/fit_mah_helpers.py +++ b/diffmah/fit_mah_helpers.py @@ -3,8 +3,13 @@ from jax import jit as jjit from jax import numpy as jnp from jax import value_and_grad -from .individual_halo_assembly import DEFAULT_MAH_PARAMS -from .individual_halo_assembly import _u_rolling_plaw_vs_logt, _get_early_late + +from .defaults import DEFAULT_MAH_PARAMS, MAH_K +from .individual_halo_assembly import ( + _get_early_late, + _get_ue_ul, + _u_rolling_plaw_vs_logt, +) T_FIT_MIN = 1.0 DLOGM_CUT = 2.5 @@ -14,10 +19,9 @@ def get_outline(halo_id, loss_data, p_best, loss_best): """Return the string storing fitting results that will be written to disk""" logtc, ue, ul = p_best logt0, u_k, logm0 = loss_data[-3:] - t0 = 10 ** logt0 + t0 = 10**logt0 early, late = _get_early_late(ue, ul) - fixed_k = DEFAULT_MAH_PARAMS["mah_k"] - _d = np.array((logm0, logtc, fixed_k, early, late)).astype("f4") + _d = np.array((logm0, logtc, MAH_K, early, late)).astype("f4") data_out = (halo_id, *_d, t0, float(loss_best)) out = str(halo_id) + " " + " ".join(["{:.5e}".format(x) for x in data_out[1:]]) return out + "\n" @@ -105,12 +109,15 @@ def get_loss_data( t_fit_min, ) logmp_init = log_mah_sim[-1] - lgtc_init, fixed_k, ue_init, ud_init = list(DEFAULT_MAH_PARAMS.values()) - p_init = np.array((lgtc_init, ue_init, ud_init)).astype("f4") + lgtc_init = DEFAULT_MAH_PARAMS.logtc + ue_init, ul_init = _get_ue_ul( + DEFAULT_MAH_PARAMS.early_index, DEFAULT_MAH_PARAMS.late_index + ) + p_init = np.array((lgtc_init, ue_init, ul_init)).astype("f4") logt0 = np.log10(t_sim[-1]) - loss_data = (logt_target, log_mah_target, logt0, fixed_k, logmp_init) + loss_data = (logt_target, log_mah_target, logt0, MAH_K, logmp_init) return p_init, loss_data diff --git a/diffmah/halo_population_assembly.py b/diffmah/halo_population_assembly.py index 6d57309..e243d88 100644 --- a/diffmah/halo_population_assembly.py +++ b/diffmah/halo_population_assembly.py @@ -5,11 +5,8 @@ from jax import vmap from jax.scipy.stats import multivariate_normal as jnorm -from .individual_halo_assembly import ( - DEFAULT_MAH_PARAMS, - _calc_halo_history, - _get_early_late, -) +from .defaults import MAH_K +from .individual_halo_assembly import _calc_halo_history, _get_early_late from .rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS, LGT0, _get_mah_means_and_covs CLIP = -10.0 @@ -51,7 +48,7 @@ def _get_bimodal_halo_history_kern( mu_late, cov_early, cov_late, - k=DEFAULT_MAH_PARAMS["mah_k"], + k=MAH_K, logtmp=LGT0, ): early_arr, late_arr = _get_early_late(ue_arr, ul_arr) @@ -133,7 +130,7 @@ def _get_bimodal_halo_history( chol_ue_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ue_lgtc_late_yhi"], chol_ul_lgtc_late_ylo=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_ylo"], chol_ul_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_yhi"], - k=DEFAULT_MAH_PARAMS["mah_k"], + k=MAH_K, logtmp=LGT0, ): _res = _get_mah_means_and_covs( diff --git a/diffmah/individual_halo_assembly.py b/diffmah/individual_halo_assembly.py index 9843877..5405785 100644 --- a/diffmah/individual_halo_assembly.py +++ b/diffmah/individual_halo_assembly.py @@ -1,18 +1,45 @@ """Model for individual halo mass assembly based on a power-law with rolling index.""" -from collections import OrderedDict - from jax import grad from jax import jit as jjit from jax import lax from jax import numpy as jnp from jax import vmap +from .defaults import LGT0, MAH_K from .utils import get_1d_arrays -DEFAULT_MAH_PARAMS = OrderedDict(mah_logtc=0.05, mah_k=3.5, mah_ue=2.4, mah_ul=-2.0) + +@jjit +def mah_singlehalo(mah_params, tarr, lgt0=LGT0): + lgtarr = jnp.log10(tarr) + dmhdt, log_mah = _calc_halo_history( + lgtarr, + lgt0, + mah_params.logmp, + mah_params.logtc, + MAH_K, + mah_params.early_index, + mah_params.late_index, + ) + return dmhdt, log_mah + + +@jjit +def mah_halopop(mah_params, tarr, lgt0=LGT0): + lgtarr = jnp.log10(tarr) + dmhdt, log_mah = _calc_halopop_history( + lgtarr, + lgt0, + mah_params.logmp, + mah_params.logtc, + MAH_K, + mah_params.early_index, + mah_params.late_index, + ) + return dmhdt, log_mah -def calc_halo_history(t, t0, logmp, tauc, early, late, k=DEFAULT_MAH_PARAMS["mah_k"]): +def calc_halo_history(t, t0, logmp, tauc, early, late, k=MAH_K): """Calculate individual halo assembly histories. Parameters @@ -101,6 +128,10 @@ def _calc_halo_history(logt, logt0, logmp, logtc, k, early, late): return dmhdt, log_mah +_YO = (None, None, 0, 0, None, 0, 0) +_calc_halopop_history = jjit(vmap(_calc_halo_history, in_axes=_YO)) + + @jjit def _softplus(x): return jnp.log(1 + lax.exp(x)) diff --git a/diffmah/monte_carlo_halo_population.py b/diffmah/monte_carlo_halo_population.py index f5d874e..d00dd94 100644 --- a/diffmah/monte_carlo_halo_population.py +++ b/diffmah/monte_carlo_halo_population.py @@ -8,15 +8,10 @@ from jax import random as jran from jax import vmap -from .individual_halo_assembly import ( - DEFAULT_MAH_PARAMS, - _calc_halo_history, - _get_early_late, -) +from .defaults import MAH_K +from .individual_halo_assembly import _calc_halo_history, _get_early_late from .rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS, _get_mah_means_and_covs -MAH_K = DEFAULT_MAH_PARAMS["mah_k"] - _A = (None, None, 0, 0, None, 0, 0) _calc_halo_history_vmap = jjit(vmap(_calc_halo_history, in_axes=_A)) diff --git a/diffmah/rockstar_pdf_model.py b/diffmah/rockstar_pdf_model.py index 1fa7080..3054c51 100644 --- a/diffmah/rockstar_pdf_model.py +++ b/diffmah/rockstar_pdf_model.py @@ -6,12 +6,11 @@ from jax import numpy as jnp from jax import vmap -from .individual_halo_assembly import DEFAULT_MAH_PARAMS +from .defaults import MAH_K from .utils import get_cholesky_from_params TODAY = 13.8 LGT0 = jnp.log10(TODAY) -K = DEFAULT_MAH_PARAMS["mah_k"] _LGM_X0, LGM_K = 13.0, 0.5 @@ -352,7 +351,7 @@ def _get_mah_means_and_covs( chol_ue_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ue_lgtc_late_yhi"], chol_ul_lgtc_late_ylo=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_ylo"], chol_ul_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_yhi"], - k=DEFAULT_MAH_PARAMS["mah_k"], + k=MAH_K, logtmp=LGT0, ): frac_late = frac_late_forming(logmp_arr, frac_late_ylo, frac_late_yhi) diff --git a/diffmah/tests/test_defaults.py b/diffmah/tests/test_defaults.py new file mode 100644 index 0000000..ebf4fc8 --- /dev/null +++ b/diffmah/tests/test_defaults.py @@ -0,0 +1,26 @@ +""" +""" +import numpy as np + + +def test_default_mah_params_imports_from_top_level_and_is_frozen(): + from .. import DEFAULT_MAH_PARAMS + + assert np.allclose(DEFAULT_MAH_PARAMS, (12.0, 0.05, 2.6137643, 0.12692805)) + + +def test_mah_k_imports_from_top_level(): + from .. import MAH_K + + assert np.allclose(MAH_K, 3.5) + + +def test_mah_halopop_imports_from_top_level(): + from .. import DEFAULT_MAH_PARAMS, DiffmahParams, mah_halopop + + tarr = np.linspace(0.1, 13.7, 100) + ngals = 150 + zz = np.zeros(ngals) + mah_params_halopop = DiffmahParams(*[x + zz for x in DEFAULT_MAH_PARAMS]) + dmhdt, log_mah = mah_halopop(mah_params_halopop, tarr) + assert log_mah.shape == dmhdt.shape diff --git a/diffmah/tests/test_individual_halo_assembly.py b/diffmah/tests/test_individual_halo_assembly.py index d520d3a..ef52ef4 100644 --- a/diffmah/tests/test_individual_halo_assembly.py +++ b/diffmah/tests/test_individual_halo_assembly.py @@ -5,12 +5,14 @@ import numpy as np from jax import numpy as jnp +from ..defaults import DEFAULT_MAH_PARAMS, MAH_K, DiffmahParams from ..individual_halo_assembly import ( - DEFAULT_MAH_PARAMS, _calc_halo_history, _calc_halo_history_scalar, _get_early_late, _power_law_index_vs_logt, + mah_halopop, + mah_singlehalo, ) from ..rockstar_pdf_model import _get_mean_mah_params_early, _get_mean_mah_params_late @@ -22,16 +24,15 @@ def test_calc_halo_history_evaluates(): tarr = np.linspace(0.1, 14, 500) logt = np.log10(tarr) logtmp = logt[-1] - logmp = 12.0 - lgtc, k, ue, ul = list(DEFAULT_MAH_PARAMS.values()) - early, late = _get_early_late(ue, ul) - dmhdt, log_mah = _calc_halo_history(logt, logtmp, logmp, lgtc, k, early, late) + lgmp, lgtc, early, late = DEFAULT_MAH_PARAMS + dmhdt, log_mah = _calc_halo_history(logt, logtmp, lgmp, lgtc, MAH_K, early, late) + assert np.all(np.isfinite(dmhdt)) + assert np.all(np.isfinite(log_mah)) def test_rolling_index_agrees_with_hard_coded_expectation(): lgmp_test = 12.5 - k = DEFAULT_MAH_PARAMS["mah_k"] logt_bn = "logt_testing_array.dat" logt = np.loadtxt(os.path.join(DDRN, logt_bn)) logt0 = logt[-1] @@ -44,19 +45,19 @@ def test_rolling_index_agrees_with_hard_coded_expectation(): indx_e_bn = "rolling_plaw_index_vs_time_rockstar_default_logmp_{0:.1f}_early.dat" index_early_correct = np.loadtxt(os.path.join(DDRN, indx_e_bn.format(lgmp_test))) - index_early = _power_law_index_vs_logt(logt, lgtc_e, k, early_e, late_e) + index_early = _power_law_index_vs_logt(logt, lgtc_e, MAH_K, early_e, late_e) assert np.allclose(index_early_correct, index_early, rtol=0.01) indx_l_bn = "rolling_plaw_index_vs_time_rockstar_default_logmp_{0:.1f}_late.dat" index_late_correct = np.loadtxt(os.path.join(DDRN, indx_l_bn.format(lgmp_test))) - index_late = _power_law_index_vs_logt(logt, lgtc_l, k, early_l, late_l) + index_late = _power_law_index_vs_logt(logt, lgtc_l, MAH_K, early_l, late_l) assert np.allclose(index_late_correct, index_late, rtol=0.01) dmhdt_e, log_mah_e = _calc_halo_history( - logt, logt0, lgmp_test, lgtc_e, k, early_e, late_e + logt, logt0, lgmp_test, lgtc_e, MAH_K, early_e, late_e ) dmhdt_l, log_mah_l = _calc_halo_history( - logt, logt0, lgmp_test, lgtc_l, k, early_l, late_l + logt, logt0, lgmp_test, lgtc_l, MAH_K, early_l, late_l ) log_mah_e_bn = "log_mah_vs_time_rockstar_default_logmp_{0:.1f}_early.dat" @@ -72,14 +73,34 @@ def test_calc_halo_history_scalar_agrees_with_vmap(): tarr = np.linspace(0.1, 14, 15) logt = np.log10(tarr) logtmp = logt[-1] - logmp = 12.0 - lgtc, k, ue, ul = list(DEFAULT_MAH_PARAMS.values()) - early, late = _get_early_late(ue, ul) - dmhdt, log_mah = _calc_halo_history(logt, logtmp, logmp, lgtc, k, early, late) + lgmp, lgtc, early, late = DEFAULT_MAH_PARAMS + dmhdt, log_mah = _calc_halo_history(logt, logtmp, lgmp, lgtc, MAH_K, early, late) for i, t in enumerate(tarr): lgt_i = jnp.log10(t) - res = _calc_halo_history_scalar(lgt_i, logtmp, logmp, lgtc, k, early, late) + res = _calc_halo_history_scalar(lgt_i, logtmp, lgmp, lgtc, MAH_K, early, late) dmhdt_i, log_mah_i = res assert np.allclose(dmhdt[i], dmhdt_i) assert np.allclose(log_mah[i], log_mah_i) + + +def test_mah_singlehalo_evaluates(): + nt = 100 + tarr = np.linspace(0.1, 13.8, nt) + dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, tarr) + assert dmhdt.shape == tarr.shape + assert log_mah.shape == dmhdt.shape + assert log_mah[-1] == DEFAULT_MAH_PARAMS.logmp + + +def test_mah_halopop_evaluates(): + nt = 100 + tarr = np.linspace(0.1, 13.8, nt) + + ngals = 150 + zz = np.zeros(ngals) + mah_params_halopop = DiffmahParams(*[zz + p for p in DEFAULT_MAH_PARAMS]) + dmhdt, log_mah = mah_halopop(mah_params_halopop, tarr) + assert dmhdt.shape == (ngals, nt) + assert log_mah.shape == dmhdt.shape + assert np.allclose(log_mah[:, -1], DEFAULT_MAH_PARAMS.logmp) diff --git a/diffmah/tng_pdf_model.py b/diffmah/tng_pdf_model.py index 3bb0dc0..016a0cb 100644 --- a/diffmah/tng_pdf_model.py +++ b/diffmah/tng_pdf_model.py @@ -7,12 +7,11 @@ from jax import numpy as jnp from jax import vmap -from .individual_halo_assembly import DEFAULT_MAH_PARAMS +from .defaults import MAH_K from .utils import get_cholesky_from_params TODAY = 13.8 LGT0 = jnp.log10(TODAY) -K = DEFAULT_MAH_PARAMS["mah_k"] _LGM_X0, LGM_K = 13.0, 0.5 @@ -353,7 +352,7 @@ def _get_mah_means_and_covs( chol_ue_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ue_lgtc_late_yhi"], chol_ul_lgtc_late_ylo=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_ylo"], chol_ul_lgtc_late_yhi=DEFAULT_MAH_PDF_PARAMS["chol_ul_lgtc_late_yhi"], - k=DEFAULT_MAH_PARAMS["mah_k"], + k=MAH_K, logtmp=LGT0, ): frac_late = frac_late_forming(logmp_arr, frac_late_ylo, frac_late_yhi)