Skip to content

Commit

Permalink
Add test of mah_singlehalo and mah_halopop
Browse files Browse the repository at this point in the history
  • Loading branch information
aphearin committed Jan 15, 2024
1 parent ba66137 commit 9fa0bb4
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion diffmah/tests/test_individual_halo_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import numpy as np
from jax import numpy as jnp

from ..defaults import DEFAULT_MAH_PARAMS, MAH_K
from ..defaults import DEFAULT_MAH_PARAMS, MAH_K, DiffmahParams
from ..individual_halo_assembly import (
_calc_halo_history,
_calc_halo_history_scalar,
_get_early_late,
_power_law_index_vs_logt,
mah_halopop,
mah_singlehalo,
)
from ..rockstar_pdf_model import _get_mean_mah_params_early, _get_mean_mah_params_late

Expand Down Expand Up @@ -80,3 +82,25 @@ def test_calc_halo_history_scalar_agrees_with_vmap():
dmhdt_i, log_mah_i = res
assert np.allclose(dmhdt[i], dmhdt_i)
assert np.allclose(log_mah[i], log_mah_i)


def test_mah_singlehalo_evaluates():
nt = 100
tarr = np.linspace(0.1, 13.8, nt)
dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, tarr)
assert dmhdt.shape == tarr.shape
assert log_mah.shape == dmhdt.shape
assert log_mah[-1] == DEFAULT_MAH_PARAMS.logmp


def test_mah_halopop_evaluates():
nt = 100
tarr = np.linspace(0.1, 13.8, nt)

ngals = 150
zz = np.zeros(ngals)
mah_params_halopop = DiffmahParams(*[zz + p for p in DEFAULT_MAH_PARAMS])
dmhdt, log_mah = mah_halopop(mah_params_halopop, tarr)
assert dmhdt.shape == (ngals, nt)
assert log_mah.shape == dmhdt.shape
assert np.allclose(log_mah[:, -1], DEFAULT_MAH_PARAMS.logmp)

0 comments on commit 9fa0bb4

Please sign in to comment.