Skip to content

Commit

Permalink
Merge pull request #147 from ArgonneCPAC/tp5
Browse files Browse the repository at this point in the history
Update diffmah and diffmahpop so that `t_peak` is a parameter
  • Loading branch information
aphearin authored Oct 24, 2024
2 parents df5e99a + 7cb4ab7 commit 94adf24
Show file tree
Hide file tree
Showing 24 changed files with 367 additions and 241 deletions.
2 changes: 1 addition & 1 deletion diffmah/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# flake8: noqa

from ._version import __version__
from .bfgs_wrapper import diffmah_fitter
from .defaults import DEFAULT_MAH_PARAMS, MAH_K, DiffmahParams
from .fitting_helpers import diffmah_fitter
from .individual_halo_assembly import calc_halo_history, mah_halopop, mah_singlehalo
from .monte_carlo_diffmah_hiz import mc_diffmah_params_hiz
from .monte_carlo_halo_population import mc_halo_population
6 changes: 3 additions & 3 deletions diffmah/bfgs_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def grad_func(p_init, loss_data):
return _res


def diffmah_fitter(val_and_grads, p_init, loss_data, nstep=200, n_warmup=1):
def bfgs_adam_fallback(val_and_grads, u_p_init, loss_data, nstep=200, n_warmup=1):
"""
Function that runs scipy's LBFGS minimizer.
If that is not successful, minimize the fit with the ADAM minimizer from JAX
Expand Down Expand Up @@ -72,7 +72,7 @@ def diffmah_fitter(val_and_grads, p_init, loss_data, nstep=200, n_warmup=1):
p_best, loss_best, fit_terminates, code_used
"""
_res = scipy_lbfgs_wrapper(val_and_grads, p_init, loss_data)
_res = scipy_lbfgs_wrapper(val_and_grads, u_p_init, loss_data)

# check if LBFGS succeeds. If yes, save those results.
# Otherwise try the Adam wrapper
Expand All @@ -84,7 +84,7 @@ def diffmah_fitter(val_and_grads, p_init, loss_data, nstep=200, n_warmup=1):
_res.append(code_used)
return _res
else:
res = jax_adam_wrapper(val_and_grads, p_init, loss_data, nstep, n_warmup)
res = jax_adam_wrapper(val_and_grads, u_p_init, loss_data, nstep, n_warmup)
p_best, loss_best, loss_arr, params_arr, fit_terminates = res
code_used = 1 # Adam
_res = [p_best, loss_best, fit_terminates, code_used]
Expand Down
75 changes: 41 additions & 34 deletions diffmah/diffmah_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
MAH_K = 3.5

DEFAULT_MAH_PDICT = OrderedDict(
logm0=12.0, logtc=0.05, early_index=2.6137643, late_index=0.12692805
logm0=12.0, logtc=0.05, early_index=2.6137643, late_index=0.12692805, t_peak=14.0
)
DiffmahParams = namedtuple("DiffmahParams", list(DEFAULT_MAH_PDICT.keys()))
DEFAULT_MAH_PARAMS = DiffmahParams(*list(DEFAULT_MAH_PDICT.values()))
Expand All @@ -21,20 +21,24 @@
DiffmahUParams = namedtuple("DiffmahUParams", _MAH_UPNAMES)

MAH_PBDICT = OrderedDict(
logm0=(0.0, 17.0), logtc=(-1.0, 1.0), early_index=(0.1, 10.0), late_index=(0.1, 5.0)
logm0=(0.0, 17.0),
logtc=(-1.0, 1.0),
early_index=(0.1, 10.0),
late_index=(0.1, 5.0),
t_peak=(0.5, 20.0),
)
MAH_PBOUNDS = DiffmahParams(*list(MAH_PBDICT.values()))


@jjit
def mah_singlehalo(mah_params, tarr, t_peak, lgt0):
dmhdt, log_mah = _diffmah_kern(mah_params, tarr, t_peak, lgt0)
def mah_singlehalo(mah_params, tarr, lgt0):
dmhdt, log_mah = _diffmah_kern(mah_params, tarr, lgt0)
return dmhdt, log_mah


@jjit
def mah_halopop(mah_params, tarr, t_peak, lgt0):
dmhdt, log_mah = _diffmah_kern_vmap(mah_params, tarr, t_peak, lgt0)
def mah_halopop(mah_params, tarr, lgt0):
dmhdt, log_mah = _diffmah_kern_vmap(mah_params, tarr, lgt0)
return dmhdt, log_mah


Expand Down Expand Up @@ -73,7 +77,7 @@ def _rolling_plaw_vs_t(t, logt0, logm0, logtc, early, late):

@jjit
def _log_mah_noq_kern(mah_params, t, logt0):
logm0, logtc, early, late = mah_params
logm0, logtc, early, late = mah_params[:4]
log_mah_noq = _rolling_plaw_vs_t(t, logt0, logm0, logtc, early, late)
return log_mah_noq

Expand All @@ -99,16 +103,17 @@ def _dmhdt_noq_kern_scalar(mah_params, t, logt0):


@jjit
def _log_mah_kern(mah_params, t, t_peak, logt0):
def _log_mah_kern(mah_params, t, logt0):
log_mah_noq = _log_mah_noq_kern(mah_params, t, logt0)
lgmhq = _log_mah_noq_kern(mah_params, t_peak, logt0)
log_mah = jnp.where(t < t_peak, log_mah_noq, lgmhq) # clip growth at t_peak
lgmhq = _log_mah_noq_kern(mah_params, mah_params.t_peak, logt0)
msk = t < mah_params.t_peak
log_mah = jnp.where(msk, log_mah_noq, lgmhq) # clip growth at t_peak
return log_mah


@jjit
def _mah_kern(mah_params, t, t_peak, logt0):
log_mah = _log_mah_kern(mah_params, t, t_peak, logt0)
def _mah_kern(mah_params, t, logt0):
log_mah = _log_mah_kern(mah_params, t, logt0)
mah = 10**log_mah
return mah

Expand All @@ -117,34 +122,34 @@ def _mah_kern(mah_params, t, t_peak, logt0):


@jjit
def _dmhdt_kern(mah_params, t, t_peak, logt0):
def _dmhdt_kern(mah_params, t, logt0):
dmhdt_noq = _dmhdt_noq_kern(mah_params, t, logt0)
dmhdt = jnp.where(t > t_peak, 0.0, dmhdt_noq)
dmhdt = jnp.where(t > mah_params.t_peak, 0.0, dmhdt_noq)
return dmhdt


@jjit
def _dmhdt_kern_scalar(mah_params, t, t_peak, logt0):
def _dmhdt_kern_scalar(mah_params, t, logt0):
dmhdt_noq = _dmhdt_noq_kern_scalar(mah_params, t, logt0)
dmhdt = jnp.where(t > t_peak, 0.0, dmhdt_noq)
dmhdt = jnp.where(t > mah_params.t_peak, 0.0, dmhdt_noq)
return dmhdt


@jjit
def _diffmah_kern(mah_params, t, t_peak, logt0):
dmhdt = _dmhdt_kern(mah_params, t, t_peak, logt0)
log_mah = _log_mah_kern(mah_params, t, t_peak, logt0)
def _diffmah_kern(mah_params, t, logt0):
dmhdt = _dmhdt_kern(mah_params, t, logt0)
log_mah = _log_mah_kern(mah_params, t, logt0)
return dmhdt, log_mah


@jjit
def _diffmah_kern_scalar(mah_params, t, t_peak, logt0):
dmhdt = _dmhdt_kern_scalar(mah_params, t, t_peak, logt0)
log_mah = _log_mah_kern(mah_params, t, t_peak, logt0)
def _diffmah_kern_scalar(mah_params, t, logt0):
dmhdt = _dmhdt_kern_scalar(mah_params, t, logt0)
log_mah = _log_mah_kern(mah_params, t, logt0)
return dmhdt, log_mah


_P = (0, None, 0, None)
_P = (0, None, None)
_diffmah_kern_vmap = jjit(vmap(_diffmah_kern, in_axes=_P))

##############################
Expand Down Expand Up @@ -186,41 +191,43 @@ def get_bounded_mah_params(u_params):
u_parr = jnp.array([getattr(u_params, u_pname) for u_pname in _MAH_UPNAMES])
logm0 = _get_bounded_diffmah_param(u_params.u_logm0, MAH_PBOUNDS.logm0)
logtc = _get_bounded_diffmah_param(u_params.u_logtc, MAH_PBOUNDS.logtc)
u_early, u_late = u_parr[2:]
t_peak = _get_bounded_diffmah_param(u_params.u_t_peak, MAH_PBOUNDS.t_peak)
u_early, u_late = u_parr[2:4]
early, late = _get_early_late(u_early, u_late)
params = DiffmahParams(logm0, logtc, early, late)
params = DiffmahParams(logm0, logtc, early, late, t_peak)
return params


@jjit
def get_unbounded_mah_params(params):
parr = jnp.array([getattr(params, pname) for pname in _MAH_PNAMES])
early, late = parr[2:]
early, late = parr[2:4]
u_early, u_late = _get_u_early_late(early, late)
u_logm0 = _get_unbounded_diffmah_param(params.logm0, MAH_PBOUNDS.logm0)
u_logtc = _get_unbounded_diffmah_param(params.logtc, MAH_PBOUNDS.logtc)
u_params = DiffmahUParams(u_logm0, u_logtc, u_early, u_late)
u_t_peak = _get_unbounded_diffmah_param(params.t_peak, MAH_PBOUNDS.t_peak)
u_params = DiffmahUParams(u_logm0, u_logtc, u_early, u_late, u_t_peak)
return u_params


DEFAULT_MAH_U_PARAMS = DiffmahUParams(*get_unbounded_mah_params(DEFAULT_MAH_PARAMS))


@jjit
def _log_mah_kern_u_params(mah_u_params, t, t_peak, logt0):
def _log_mah_kern_u_params(mah_u_params, t, logt0):
mah_params = get_bounded_mah_params(mah_u_params)
return _log_mah_kern(mah_params, t, t_peak, logt0)
return _log_mah_kern(mah_params, t, logt0)


@jjit
def _dmhdt_kern_u_params(mah_u_params, t, t_peak, logt0):
def _dmhdt_kern_u_params(mah_u_params, t, logt0):
mah_params = get_bounded_mah_params(mah_u_params)
return _dmhdt_kern(mah_params, t, t_peak, logt0)
return _dmhdt_kern(mah_params, t, logt0)


@jjit
def _diffmah_kern_u_params(mah_u_params, t, t_peak, logt0):
def _diffmah_kern_u_params(mah_u_params, t, logt0):
mah_params = get_bounded_mah_params(mah_u_params)
dmhdt = _dmhdt_kern(mah_params, t, t_peak, logt0)
log_mah = _log_mah_kern(mah_params, t, t_peak, logt0)
dmhdt = _dmhdt_kern(mah_params, t, logt0)
log_mah = _log_mah_kern(mah_params, t, logt0)
return dmhdt, log_mah
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import numpy as jnp
from jax import value_and_grad, vmap

from ...bfgs_wrapper import diffmah_fitter
from ...bfgs_wrapper import bfgs_adam_fallback
from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid

DEFAULT_LGM0POP_C0_PDICT = OrderedDict(
Expand Down Expand Up @@ -78,7 +78,7 @@ def global_loss_kern(params, global_loss_data):


def fit_global_c0_model(global_loss_data, p_init=DEFAULT_LGM0POP_C0_PARAMS):
_res = diffmah_fitter(global_loss_and_grads_kern, p_init, global_loss_data)
_res = bfgs_adam_fallback(global_loss_and_grads_kern, p_init, global_loss_data)
p_best, loss_best, fit_terminates, code_used = _res
return p_best, loss_best, fit_terminates, code_used

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import numpy as jnp
from jax import value_and_grad, vmap

from ...bfgs_wrapper import diffmah_fitter
from ...bfgs_wrapper import bfgs_adam_fallback
from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid

DEFAULT_LGM0POP_C0_PDICT = OrderedDict(
Expand Down Expand Up @@ -78,7 +78,7 @@ def global_loss_kern(params, global_loss_data):


def fit_global_c0_model(global_loss_data, p_init=DEFAULT_LGM0POP_C0_PARAMS):
_res = diffmah_fitter(global_loss_and_grads_kern, p_init, global_loss_data)
_res = bfgs_adam_fallback(global_loss_and_grads_kern, p_init, global_loss_data)
p_best, loss_best, fit_terminates, code_used = _res
return p_best, loss_best, fit_terminates, code_used

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import numpy as jnp
from jax import value_and_grad, vmap

from ...bfgs_wrapper import diffmah_fitter
from ...bfgs_wrapper import bfgs_adam_fallback
from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid

DEFAULT_LGM0POP_C1_PDICT = OrderedDict(
Expand Down Expand Up @@ -89,7 +89,7 @@ def global_loss_kern(params, global_loss_data):


def fit_global_c1_model(global_loss_data, p_init=DEFAULT_LGM0POP_C1_PARAMS):
_res = diffmah_fitter(global_loss_and_grads_kern, p_init, global_loss_data)
_res = bfgs_adam_fallback(global_loss_and_grads_kern, p_init, global_loss_data)
p_best, loss_best, fit_terminates, code_used = _res
return p_best, loss_best, fit_terminates, code_used

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import numpy as jnp
from jax import value_and_grad, vmap

from ...bfgs_wrapper import diffmah_fitter
from ...bfgs_wrapper import bfgs_adam_fallback
from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid

DEFAULT_LGM0POP_C1_PDICT = OrderedDict(
Expand Down Expand Up @@ -89,7 +89,7 @@ def global_loss_kern(params, global_loss_data):


def fit_global_c1_model(global_loss_data, p_init=DEFAULT_LGM0POP_C1_PARAMS):
_res = diffmah_fitter(global_loss_and_grads_kern, p_init, global_loss_data)
_res = bfgs_adam_fallback(global_loss_and_grads_kern, p_init, global_loss_data)
p_best, loss_best, fit_terminates, code_used = _res
return p_best, loss_best, fit_terminates, code_used

Expand Down
4 changes: 2 additions & 2 deletions diffmah/diffmahpop_kernels/bimod_logm0_sats/logm0_c0_early.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import numpy as jnp
from jax import value_and_grad, vmap

from ...bfgs_wrapper import diffmah_fitter
from ...bfgs_wrapper import bfgs_adam_fallback
from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid

DEFAULT_LGM0POP_C0_PDICT = OrderedDict(
Expand Down Expand Up @@ -81,7 +81,7 @@ def global_loss_kern(params, global_loss_data):


def fit_global_c0_model(global_loss_data, p_init=DEFAULT_LGM0POP_C0_PARAMS):
_res = diffmah_fitter(global_loss_and_grads_kern, p_init, global_loss_data)
_res = bfgs_adam_fallback(global_loss_and_grads_kern, p_init, global_loss_data)
p_best, loss_best, fit_terminates, code_used = _res
return p_best, loss_best, fit_terminates, code_used

Expand Down
4 changes: 2 additions & 2 deletions diffmah/diffmahpop_kernels/bimod_logm0_sats/logm0_c0_late.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import numpy as jnp
from jax import value_and_grad, vmap

from ...bfgs_wrapper import diffmah_fitter
from ...bfgs_wrapper import bfgs_adam_fallback
from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid

DEFAULT_LGM0POP_C0_PDICT = OrderedDict(
Expand Down Expand Up @@ -81,7 +81,7 @@ def global_loss_kern(params, global_loss_data):


def fit_global_c0_model(global_loss_data, p_init=DEFAULT_LGM0POP_C0_PARAMS):
_res = diffmah_fitter(global_loss_and_grads_kern, p_init, global_loss_data)
_res = bfgs_adam_fallback(global_loss_and_grads_kern, p_init, global_loss_data)
p_best, loss_best, fit_terminates, code_used = _res
return p_best, loss_best, fit_terminates, code_used

Expand Down
4 changes: 2 additions & 2 deletions diffmah/diffmahpop_kernels/bimod_logm0_sats/logm0_c1_early.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import numpy as jnp
from jax import value_and_grad, vmap

from ...bfgs_wrapper import diffmah_fitter
from ...bfgs_wrapper import bfgs_adam_fallback
from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid

DEFAULT_LGM0POP_C1_PDICT = OrderedDict(
Expand Down Expand Up @@ -89,7 +89,7 @@ def global_loss_kern(params, global_loss_data):


def fit_global_c1_model(global_loss_data, p_init=DEFAULT_LGM0POP_C1_PARAMS):
_res = diffmah_fitter(global_loss_and_grads_kern, p_init, global_loss_data)
_res = bfgs_adam_fallback(global_loss_and_grads_kern, p_init, global_loss_data)
p_best, loss_best, fit_terminates, code_used = _res
return p_best, loss_best, fit_terminates, code_used

Expand Down
4 changes: 2 additions & 2 deletions diffmah/diffmahpop_kernels/bimod_logm0_sats/logm0_c1_late.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import numpy as jnp
from jax import value_and_grad, vmap

from ...bfgs_wrapper import diffmah_fitter
from ...bfgs_wrapper import bfgs_adam_fallback
from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid

DEFAULT_LGM0POP_C1_PDICT = OrderedDict(
Expand Down Expand Up @@ -89,7 +89,7 @@ def global_loss_kern(params, global_loss_data):


def fit_global_c1_model(global_loss_data, p_init=DEFAULT_LGM0POP_C1_PARAMS):
_res = diffmah_fitter(global_loss_and_grads_kern, p_init, global_loss_data)
_res = bfgs_adam_fallback(global_loss_and_grads_kern, p_init, global_loss_data)
p_best, loss_best, fit_terminates, code_used = _res
return p_best, loss_best, fit_terminates, code_used

Expand Down
Loading

0 comments on commit 94adf24

Please sign in to comment.