diff --git a/diffmah/__init__.py b/diffmah/__init__.py index 17f184b..dbb92e2 100644 --- a/diffmah/__init__.py +++ b/diffmah/__init__.py @@ -3,7 +3,6 @@ # flake8: noqa from ._version import __version__ - -from .monte_carlo_halo_population import mc_halo_population from .individual_halo_assembly import calc_halo_history from .monte_carlo_diffmah_hiz import mc_diffmah_params_hiz +from .monte_carlo_halo_population import mc_halo_population diff --git a/diffmah/halo_population_assembly.py b/diffmah/halo_population_assembly.py index 3246382..6d57309 100644 --- a/diffmah/halo_population_assembly.py +++ b/diffmah/halo_population_assembly.py @@ -1,13 +1,16 @@ """Calculate differentiable probabilistic history of an individual halo.""" import numpy as np -from jax import numpy as jnp from jax import jit as jjit +from jax import numpy as jnp from jax import vmap from jax.scipy.stats import multivariate_normal as jnorm -from .individual_halo_assembly import _calc_halo_history, _get_early_late -from .individual_halo_assembly import DEFAULT_MAH_PARAMS -from .rockstar_pdf_model import _get_mah_means_and_covs -from .rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS, LGT0 + +from .individual_halo_assembly import ( + DEFAULT_MAH_PARAMS, + _calc_halo_history, + _get_early_late, +) +from .rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS, LGT0, _get_mah_means_and_covs CLIP = -10.0 @@ -55,7 +58,7 @@ def _get_bimodal_halo_history_kern( dmhdts, log_mahs = _halo_history_integrand( logt, logtmp, logmp, lgtc_arr, k, early_arr, late_arr ) - mahs = 10 ** log_mahs + mahs = 10**log_mahs weights_early = _get_mah_weights(ue_arr, ul_arr, lgtc_arr, mu_early, cov_early) weights_late = _get_mah_weights(ue_arr, ul_arr, lgtc_arr, mu_late, cov_late) diff --git a/diffmah/individual_halo_assembly.py b/diffmah/individual_halo_assembly.py index 16f6bde..d663d73 100644 --- a/diffmah/individual_halo_assembly.py +++ b/diffmah/individual_halo_assembly.py @@ -1,8 +1,11 @@ """Model for individual halo mass assembly based on a power-law with rolling index.""" from collections import OrderedDict -from jax import numpy as jnp + +from jax import grad from jax import jit as jjit -from jax import vmap, grad +from jax import numpy as jnp +from jax import vmap + from .utils import get_1d_arrays DEFAULT_MAH_PARAMS = OrderedDict(mah_logtc=0.05, mah_k=3.5, mah_ue=2.4, mah_ul=-2.0) diff --git a/diffmah/load_mah_data.py b/diffmah/load_mah_data.py index db50012..6567f04 100644 --- a/diffmah/load_mah_data.py +++ b/diffmah/load_mah_data.py @@ -1,8 +1,9 @@ """Load the diffmah data into memory.""" -import numpy as np import os import warnings +import numpy as np + TASSO = "/Users/aphearin/work/DATA/diffmah_data/PUBLISHED_DATA" BEBOP = "/lcrc/project/halotools/diffmah_data/PUBLISHED_DATA" diff --git a/diffmah/measure_mahs.py b/diffmah/measure_mahs.py index b0b3357..66236d5 100644 --- a/diffmah/measure_mahs.py +++ b/diffmah/measure_mahs.py @@ -1,7 +1,8 @@ """Functions used to define the target data for fitting the halo population model.""" -import numpy as np import warnings +import numpy as np + def get_clean_sample_mask(log_mah_fit, logmp_sample, it_min, lim=0.01, z_cut=3): """Calculate mask to remove halos with outlier MAH behavior. @@ -64,7 +65,7 @@ def measure_target_data(mah, dmhdt, lgt, lgt_target, logmp_sample): """ mah0 = mah[:, -1].reshape(-1, 1) - mp_sample = 10 ** logmp_sample + mp_sample = 10**logmp_sample scaled_mah = mp_sample * mah / mah0 scaled_dmhdt = mp_sample * dmhdt / mah0 with warnings.catch_warnings(): @@ -84,6 +85,6 @@ def measure_target_data(mah, dmhdt, lgt, lgt_target, logmp_sample): std_dmhdt = 10 ** np.interp(lgt_target, lgt, np.log10(std_dmhdt_table)) std_log_mah = np.interp(lgt_target, lgt, std_log_mah_table) - var_dmhdt = std_dmhdt ** 2 - var_log_mah = std_log_mah ** 2 + var_dmhdt = std_dmhdt**2 + var_log_mah = std_log_mah**2 return mean_mah, mean_log_mah, var_log_mah, mean_dmhdt, var_dmhdt diff --git a/diffmah/monte_carlo_diffmah_hiz.py b/diffmah/monte_carlo_diffmah_hiz.py index aafc36c..c852826 100644 --- a/diffmah/monte_carlo_diffmah_hiz.py +++ b/diffmah/monte_carlo_diffmah_hiz.py @@ -1,9 +1,10 @@ """Generate Diffmah parameters for halos identified at higher redshift """ import numpy as np -from .monte_carlo_halo_population import mc_halo_population + from .halo_population_assembly import LGT0 from .individual_halo_assembly import calc_halo_history +from .monte_carlo_halo_population import mc_halo_population def mc_diffmah_params_hiz(ran_key, t_obs, logmh, lgt0=LGT0, npop=int(1e5)): @@ -59,7 +60,6 @@ def mc_diffmah_params_hiz(ran_key, t_obs, logmh, lgt0=LGT0, npop=int(1e5)): late_index = np.array(halopop.late_index) lgtc = np.array(halopop.lgtc) else: - lgm0_guess = _guess_logmp_z0(t_obs, logmh, lgt0, npop) tarr = np.array((t_obs, t0)) diff --git a/diffmah/monte_carlo_halo_population.py b/diffmah/monte_carlo_halo_population.py index fdd4d59..f5d874e 100644 --- a/diffmah/monte_carlo_halo_population.py +++ b/diffmah/monte_carlo_halo_population.py @@ -1,15 +1,19 @@ """ """ import typing + import numpy as np +from jax import jit as jjit from jax import numpy as jnp from jax import random as jran -from jax import jit as jjit from jax import vmap -from .rockstar_pdf_model import _get_mah_means_and_covs, DEFAULT_MAH_PDF_PARAMS -from .individual_halo_assembly import _get_early_late -from .individual_halo_assembly import DEFAULT_MAH_PARAMS, _calc_halo_history +from .individual_halo_assembly import ( + DEFAULT_MAH_PARAMS, + _calc_halo_history, + _get_early_late, +) +from .rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS, _get_mah_means_and_covs MAH_K = DEFAULT_MAH_PARAMS["mah_k"] diff --git a/diffmah/optimize_nbody.py b/diffmah/optimize_nbody.py index 1d3b2dc..f58c0ea 100644 --- a/diffmah/optimize_nbody.py +++ b/diffmah/optimize_nbody.py @@ -1,10 +1,16 @@ """ """ from collections import OrderedDict + from jax import jit as jjit from jax import numpy as jnp -from .halo_population_assembly import _get_bimodal_halo_history -from .halo_population_assembly import UE_ARR, UL_ARR, LGTC_ARR + +from .halo_population_assembly import ( + LGTC_ARR, + UE_ARR, + UL_ARR, + _get_bimodal_halo_history, +) BOUNDS = OrderedDict( frac_late_ylo=(0.35, 0.45), diff --git a/diffmah/optimize_tng.py b/diffmah/optimize_tng.py index 1f4c277..78a0812 100644 --- a/diffmah/optimize_tng.py +++ b/diffmah/optimize_tng.py @@ -1,10 +1,16 @@ """ """ from collections import OrderedDict + from jax import jit as jjit from jax import numpy as jnp -from .halo_population_assembly import _get_bimodal_halo_history -from .halo_population_assembly import UE_ARR, UL_ARR, LGTC_ARR + +from .halo_population_assembly import ( + LGTC_ARR, + UE_ARR, + UL_ARR, + _get_bimodal_halo_history, +) BOUNDS = OrderedDict( frac_late_ylo=(0.3, 0.6), diff --git a/diffmah/tests/test_diff_mc_halos.py b/diffmah/tests/test_diff_mc_halos.py index 3a2499b..7b22029 100644 --- a/diffmah/tests/test_diff_mc_halos.py +++ b/diffmah/tests/test_diff_mc_halos.py @@ -1,14 +1,13 @@ """Unit test of the new differentiable Monte Carlo generator """ import numpy as np -from jax import random as jran from jax import jit as jjit from jax import numpy as jnp +from jax import random as jran from jax import value_and_grad -from ..monte_carlo_halo_population import mc_halo_population -from ..monte_carlo_halo_population import _mc_halo_mahs -from ..rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS +from ..monte_carlo_halo_population import _mc_halo_mahs, mc_halo_population +from ..rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS SEED = 43 diff --git a/diffmah/tests/test_fit_mah_helpers.py b/diffmah/tests/test_fit_mah_helpers.py index 53619fb..c8e95aa 100644 --- a/diffmah/tests/test_fit_mah_helpers.py +++ b/diffmah/tests/test_fit_mah_helpers.py @@ -1,6 +1,7 @@ """ """ import numpy as np + from ..fit_mah_helpers import get_target_data @@ -14,7 +15,7 @@ def test_get_target_data_no_cuts(): logt_target, log_mah_target = get_target_data( t_sim, log_mah_sim, lgm_min, dlogm_cut, t_fit_min ) - assert np.allclose(10 ** logt_target, t_sim) + assert np.allclose(10**logt_target, t_sim) assert np.allclose(log_mah_sim, log_mah_target, atol=0.01) @@ -29,5 +30,5 @@ def test_get_target_data_lgm_cut(): t_sim, log_mah_sim, lgm_min, dlogm_cut, t_fit_min ) assert logt_target.shape == log_mah_target.shape - assert np.allclose(t_sim[1:], 10 ** logt_target) + assert np.allclose(t_sim[1:], 10**logt_target) assert np.allclose(log_mah_sim[1:], log_mah_target, atol=0.01) diff --git a/diffmah/tests/test_halo_population_assembly.py b/diffmah/tests/test_halo_population_assembly.py index 3f4d438..9d4a26d 100644 --- a/diffmah/tests/test_halo_population_assembly.py +++ b/diffmah/tests/test_halo_population_assembly.py @@ -1,9 +1,15 @@ """ """ import os + import numpy as np -from ..halo_population_assembly import _get_bimodal_halo_history -from ..halo_population_assembly import UE_ARR, UL_ARR, LGTC_ARR + +from ..halo_population_assembly import ( + LGTC_ARR, + UE_ARR, + UL_ARR, + _get_bimodal_halo_history, +) from ..tng_pdf_model import DEFAULT_MAH_PDF_PARAMS as TNG_PARAMS _THIS_DRNAME = os.path.dirname(os.path.abspath(__file__)) diff --git a/diffmah/tests/test_individual_halo_assembly.py b/diffmah/tests/test_individual_halo_assembly.py index 90cf15d..d520d3a 100644 --- a/diffmah/tests/test_individual_halo_assembly.py +++ b/diffmah/tests/test_individual_halo_assembly.py @@ -1,13 +1,18 @@ """ """ import os -from jax import numpy as jnp + import numpy as np -from ..individual_halo_assembly import _calc_halo_history, DEFAULT_MAH_PARAMS -from ..individual_halo_assembly import _power_law_index_vs_logt, _get_early_late -from ..rockstar_pdf_model import _get_mean_mah_params_early, _get_mean_mah_params_late -from ..individual_halo_assembly import _calc_halo_history_scalar +from jax import numpy as jnp +from ..individual_halo_assembly import ( + DEFAULT_MAH_PARAMS, + _calc_halo_history, + _calc_halo_history_scalar, + _get_early_late, + _power_law_index_vs_logt, +) +from ..rockstar_pdf_model import _get_mean_mah_params_early, _get_mean_mah_params_late _THIS_DRNAME = os.path.dirname(os.path.abspath(__file__)) DDRN = os.path.join(_THIS_DRNAME, "testing_data") diff --git a/diffmah/tests/test_mc_halos.py b/diffmah/tests/test_mc_halos.py index def09f2..d4726c0 100644 --- a/diffmah/tests/test_mc_halos.py +++ b/diffmah/tests/test_mc_halos.py @@ -3,6 +3,7 @@ import os import numpy as np + from ..monte_carlo_halo_population import mc_halo_population from ..tng_pdf_model import DEFAULT_MAH_PDF_PARAMS as TNG_PDF_PARAMS diff --git a/diffmah/tests/test_monte_carlo_diffmah_hiz.py b/diffmah/tests/test_monte_carlo_diffmah_hiz.py index 27c41ae..7adb64e 100644 --- a/diffmah/tests/test_monte_carlo_diffmah_hiz.py +++ b/diffmah/tests/test_monte_carlo_diffmah_hiz.py @@ -1,8 +1,10 @@ """ """ import warnings -from jax import random as jran + import numpy as np +from jax import random as jran + from ..individual_halo_assembly import calc_halo_history from ..monte_carlo_diffmah_hiz import mc_diffmah_params_hiz diff --git a/diffmah/tests/test_opt_bounds.py b/diffmah/tests/test_opt_bounds.py index 94db8f9..019b48b 100644 --- a/diffmah/tests/test_opt_bounds.py +++ b/diffmah/tests/test_opt_bounds.py @@ -2,8 +2,8 @@ """ from ..optimize_nbody import BOUNDS as NBODY_BOUNDS from ..optimize_tng import BOUNDS as TNG_BOUNDS -from ..tng_pdf_model import DEFAULT_MAH_PDF_PARAMS as TNG_DEFAULTS from ..rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS as NBODY_DEFAULTS +from ..tng_pdf_model import DEFAULT_MAH_PDF_PARAMS as TNG_DEFAULTS def test_nbody_params_are_correctly_bounded():