-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #120 from ArgonneCPAC/dev_tq
New model of diffmah that includes tpeak feature
- Loading branch information
Showing
9 changed files
with
674 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
""" | ||
Functions in this script fit MAH using an alternating wrapper. | ||
The main minimizer used is scipy's LBFGS algo. | ||
If this does not work (success=False), then it tries minimizing with JAX's ADAM algo. | ||
The main function is diffmah_fitter. | ||
""" | ||
|
||
import numpy as np | ||
from scipy.optimize import minimize | ||
|
||
from .utils import jax_adam_wrapper | ||
|
||
|
||
def scipy_lbfgs_wrapper(val_and_grads, p_init, loss_data):
    """Minimize a loss with scipy's L-BFGS-B algorithm.

    Parameters
    ----------
    val_and_grads : callable
        Called as ``val_and_grads(params, loss_data)``. Element 0 of its
        return value is the loss and element 1 is the gradient of the loss
        with respect to ``params``. Both are used by the minimizer.
    p_init : array
        Initial values of the parameters.
    loss_data : sequence
        Whatever data ``val_and_grads`` needs; it is forwarded unchanged as
        the second positional argument.

    Returns
    -------
    _res : list
        ``[p_best, loss_best, success]``: the best-fit parameters, the loss at
        those parameters, and scipy's flag reporting whether the optimizer
        terminated successfully.
    """

    def loss_and_grad(params, data):
        # One evaluation yields both quantities. Handing scipy separate loss
        # and gradient callables would run val_and_grads twice per trial
        # point and throw half of each result away.
        out = val_and_grads(params, data)
        # scipy expects a plain Python float and a float ndarray
        return float(out[0]), np.array(out[1], dtype=float)

    # jac=True tells scipy that the objective returns (loss, gradient)
    result = minimize(
        loss_and_grad, p_init, method="L-BFGS-B", jac=True, args=(loss_data,)
    )
    return [result.x, result.fun, result.success]
|
||
|
||
def diffmah_fitter(val_and_grads, p_init, loss_data, nstep=200, n_warmup=1):
    """Fit with scipy's L-BFGS-B first, falling back to the JAX Adam wrapper.

    The L-BFGS-B result is kept only when scipy reports success and the
    best-fit loss is both finite and strictly positive. Otherwise the fit is
    redone from the same starting point with ``jax_adam_wrapper``.

    Parameters
    ----------
    val_and_grads : func
        Function returning the loss along with its gradients
    p_init : array
        Initial values for the parameters
    loss_data : sequence
        Floats and arrays storing whatever data is needed to evaluate
        ``val_and_grads(p_init, loss_data)``
    nstep : int, optional
        Number of steps passed to the Adam wrapper (default = 200)
    n_warmup : int, optional
        Number of warmup steps passed to the Adam wrapper (default = 1)

    Returns
    -------
    _res : list
        ``[p_best, loss_best, fit_terminates, code_used]`` where ``code_used``
        is 0 when the L-BFGS-B result is returned and 1 when the Adam result
        is returned.
    """
    lbfgs_out = scipy_lbfgs_wrapper(val_and_grads, p_init, loss_data)
    loss_bfgs = lbfgs_out[1]
    fit_terminates = lbfgs_out[-1]

    # Accept L-BFGS-B only if it terminated cleanly on a finite, positive loss
    bfgs_succeeds = all((fit_terminates, np.isfinite(loss_bfgs), loss_bfgs > 0))
    if bfgs_succeeds:
        return [*lbfgs_out, 0]  # code_used = 0 -> BFGS

    adam_out = jax_adam_wrapper(val_and_grads, p_init, loss_data, nstep, n_warmup)
    # The loss and parameter histories are not part of the returned result
    p_best, loss_best, _loss_arr, _params_arr, fit_terminates = adam_out
    return [p_best, loss_best, fit_terminates, 1]  # code_used = 1 -> Adam
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
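""" | ||
Kernels of the diffmah rolling power-law model of halo mass assembly history, with growth halted after t_peak. | ||
""" | ||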
|
||
from collections import OrderedDict, namedtuple | ||
|
||
from jax import grad | ||
from jax import jit as jjit | ||
from jax import lax, nn | ||
from jax import numpy as jnp | ||
from jax import vmap | ||
|
||
# Steepness of the sigmoid that rolls the power-law index from early to late
MAH_K = 3.5

# Default model parameters, in the order used by the parameter namedtuples
DEFAULT_MAH_PDICT = OrderedDict(
    [
        ("logm0", 12.0),
        ("logtc", 0.05),
        ("early_index", 2.6137643),
        ("late_index", 0.12692805),
    ]
)
_MAH_PNAMES = list(DEFAULT_MAH_PDICT)
_MAH_UPNAMES = [f"u_{pname}" for pname in _MAH_PNAMES]

# Containers for bounded parameters and their unbounded ("u_") counterparts
DiffmahParams = namedtuple("DiffmahParams", _MAH_PNAMES)
DiffmahUParams = namedtuple("DiffmahUParams", _MAH_UPNAMES)
DEFAULT_MAH_PARAMS = DiffmahParams(**DEFAULT_MAH_PDICT)

# (lower, upper) bound of each parameter
MAH_PBDICT = OrderedDict(
    [
        ("logm0", (0.0, 17.0)),
        ("logtc", (-1.0, 1.0)),
        ("early_index", (0.1, 10.0)),
        ("late_index", (0.1, 5.0)),
    ]
)
MAH_PBOUNDS = DiffmahParams(**MAH_PBDICT)
|
||
|
||
@jjit
def _sigmoid(x, x0, k, ylo, yhi):
    """Sigmoid in x centered on x0 with steepness k, running from ylo to yhi."""
    frac = nn.sigmoid(k * (x - x0))
    return ylo + (yhi - ylo) * frac
|
||
|
||
@jjit
def _inverse_sigmoid(y, x0, k, ylo, yhi):
    """Return the x at which a sigmoid with the same x0, k, ylo, yhi equals y."""
    span = yhi - ylo
    offset = y - ylo
    log_arg = span / offset - 1
    return x0 - lax.log(log_arg) / k
|
||
|
||
@jjit
def _power_law_index_vs_logt(logt, logtc, early, late):
    """Power-law index at logt: a sigmoid in logt centered on logtc with
    steepness MAH_K, running from `early` to `late`."""
    return _sigmoid(logt, logtc, MAH_K, early, late)
|
||
|
||
@jjit
def _rolling_plaw_vs_logt(logt, logt0, logm0, logtc, early, late):
    """Kernel of the rolling power-law between halo mass and time.

    Returns log_mah = index(logt) * (logt - logt0) + logm0, so the result
    equals logm0 at logt = logt0.
    """
    index = _power_law_index_vs_logt(logt, logtc, early, late)
    return index * (logt - logt0) + logm0
|
||
|
||
@jjit
def _rolling_plaw_vs_t(t, logt0, logm0, logtc, early, late):
    """Convenience wrapper used to calculate d/dt of _rolling_plaw_vs_logt.

    Takes linear time t and evaluates the kernel at log10(t).
    """
    return _rolling_plaw_vs_logt(jnp.log10(t), logt0, logm0, logtc, early, late)
|
||
|
||
@jjit
def _log_mah_noq_kern(mah_params, t, logt0):
    """log10 of the mass assembly history at time t with no clipping applied.

    `mah_params` must unpack into exactly (logm0, logtc, early, late).
    """
    logm0, logtc, early_index, late_index = mah_params
    return _rolling_plaw_vs_t(t, logt0, logm0, logtc, early_index, late_index)
|
||
|
||
@jjit
def _mah_noq_kern(mah_params, t, logt0):
    """Unclipped mass assembly history in linear units: 10**log_mah_noq."""
    return 10 ** _log_mah_noq_kern(mah_params, t, logt0)


# Derivative of the unclipped MAH with respect to its second argument, t.
# jax.grad requires a scalar output, so this version is for a single time.
_dmhdt_noq_grad_kern_scalar = jjit(grad(_mah_noq_kern, argnums=1))

# Same derivative mapped over axis 0 of t; mah_params and logt0 are shared.
_dmhdt_noq_grad_kern = jjit(
    vmap(_dmhdt_noq_grad_kern_scalar, in_axes=(None, 0, None))
)
|
||
|
||
@jjit
def _dmhdt_noq_kern_scalar(mah_params, t, logt0):
    """Time derivative of the unclipped MAH at a single time, divided by 1e9."""
    dmhdt_unscaled = _dmhdt_noq_grad_kern_scalar(mah_params, t, logt0)
    # NOTE(review): the 1e9 factor presumably converts per-Gyr to per-yr - confirm
    return dmhdt_unscaled / 1e9


# Rescaled derivative mapped over axis 0 of t; mah_params and logt0 are shared.
_dmhdt_noq_kern = jjit(
    vmap(_dmhdt_noq_kern_scalar, in_axes=(None, 0, None))
)
|
||
|
||
@jjit
def _log_mah_kern(mah_params, t, t_peak, logt0):
    """log10 of the mass assembly history, held constant from t_peak onward.

    For t < t_peak the unclipped kernel is used; otherwise the value of the
    unclipped kernel evaluated at t_peak is returned.
    """
    log_mah_at_peak = _log_mah_noq_kern(mah_params, t_peak, logt0)
    log_mah_unclipped = _log_mah_noq_kern(mah_params, t, logt0)
    # clip growth at t_peak
    return jnp.where(t < t_peak, log_mah_unclipped, log_mah_at_peak)
|
||
|
||
@jjit
def _mah_kern(mah_params, t, t_peak, logt0):
    """Clipped mass assembly history in linear units: 10**log_mah."""
    return 10 ** _log_mah_kern(mah_params, t, t_peak, logt0)


# Derivative of the clipped MAH with respect to its second argument, t.
# No 1e9 rescaling is applied here, unlike _dmhdt_noq_kern_scalar.
_dmhdt_grad_kern_unscaled = jjit(grad(_mah_kern, argnums=1))
|
||
|
||
@jjit
def _dmhdt_kern(mah_params, t, t_peak, logt0):
    """Mass accretion rate, set to zero at times after t_peak.

    Uses _dmhdt_noq_kern, which is vmapped over axis 0 of t, so t must carry
    a leading axis.
    """
    dmhdt_unclipped = _dmhdt_noq_kern(mah_params, t, logt0)
    return jnp.where(t > t_peak, 0.0, dmhdt_unclipped)
|
||
|
||
@jjit
def _diffmah_kern(mah_params, t, t_peak, logt0):
    """Return the pair (dmhdt, log_mah) for bounded parameters `mah_params`."""
    log_mah = _log_mah_kern(mah_params, t, t_peak, logt0)
    dmhdt = _dmhdt_kern(mah_params, t, t_peak, logt0)
    return dmhdt, log_mah
|
||
|
||
############################## | ||
# Unbounded parameter behavior | ||
|
||
# Steepness of the sigmoids mapping between unbounded and bounded parameters
BOUNDING_K = 0.1


@jjit
def _get_bounded_diffmah_param(u_param, bound):
    """Map an unbounded value into the open interval given by `bound`.

    `bound` unpacks as (lower, upper); the mapping is a sigmoid centered on
    the midpoint of the interval with steepness BOUNDING_K.
    """
    lower, upper = bound
    midpoint = 0.5 * (lower + upper)
    return _sigmoid(u_param, midpoint, BOUNDING_K, lower, upper)
|
||
|
||
@jjit
def _get_unbounded_diffmah_param(param, bound):
    """Map a bounded value to its unbounded counterpart.

    `bound` unpacks as (lower, upper). Applies the inverse sigmoid with the
    same center (interval midpoint) and steepness (BOUNDING_K) used by
    _get_bounded_diffmah_param.
    """
    lower, upper = bound
    midpoint = 0.5 * (lower + upper)
    return _inverse_sigmoid(param, midpoint, BOUNDING_K, lower, upper)
|
||
|
||
@jjit
def _get_early_late(u_early, u_late):
    """Map unbounded (u_early, u_late) to bounded (early, late) indices.

    `late` is bounded by MAH_PBOUNDS.late_index. `early` is then bounded
    between `late` and the upper early_index bound, so early > late.
    """
    early_max = MAH_PBOUNDS.early_index[1]
    late = _get_bounded_diffmah_param(u_late, MAH_PBOUNDS.late_index)
    # the lower limit of early is the freshly computed late index
    early = _sigmoid(u_early, 0.0, BOUNDING_K, late, early_max)
    return early, late
|
||
|
||
@jjit
def _get_u_early_late(early, late):
    """Map bounded (early, late) indices to unbounded (u_early, u_late).

    Inverts _get_early_late: `early` is unbounded relative to the interval
    from `late` to the upper early_index bound.
    """
    early_max = MAH_PBOUNDS.early_index[1]
    u_early = _inverse_sigmoid(early, 0.0, BOUNDING_K, late, early_max)
    u_late = _get_unbounded_diffmah_param(late, MAH_PBOUNDS.late_index)
    return u_early, u_late
|
||
|
||
@jjit
def get_bounded_mah_params(u_params):
    """Convert unbounded diffmah parameters into bounded ones.

    Parameters
    ----------
    u_params : DiffmahUParams
        Any object exposing the attributes u_logm0, u_logtc, u_early_index
        and u_late_index.

    Returns
    -------
    params : DiffmahParams
        Namedtuple (logm0, logtc, early_index, late_index). logm0 and logtc
        lie within their MAH_PBOUNDS intervals; the two indices are produced
        by _get_early_late.
    """
    logm0 = _get_bounded_diffmah_param(u_params.u_logm0, MAH_PBOUNDS.logm0)
    logtc = _get_bounded_diffmah_param(u_params.u_logtc, MAH_PBOUNDS.logtc)
    # Read the index fields directly rather than stacking every field into a
    # single array and slicing: stacking requires all four fields to share
    # one shape, which this transformation does not otherwise need.
    early, late = _get_early_late(u_params.u_early_index, u_params.u_late_index)
    return DiffmahParams(logm0, logtc, early, late)
|
||
|
||
@jjit
def get_unbounded_mah_params(params):
    """Convert bounded diffmah parameters into unbounded ones.

    Parameters
    ----------
    params : DiffmahParams
        Any object exposing the attributes logm0, logtc, early_index and
        late_index.

    Returns
    -------
    u_params : DiffmahUParams
        Namedtuple (u_logm0, u_logtc, u_early_index, u_late_index).
    """
    u_logm0 = _get_unbounded_diffmah_param(params.logm0, MAH_PBOUNDS.logm0)
    u_logtc = _get_unbounded_diffmah_param(params.logtc, MAH_PBOUNDS.logtc)
    # Read the index fields directly rather than stacking every field into a
    # single array and slicing: stacking requires all four fields to share
    # one shape, which this transformation does not otherwise need.
    u_early, u_late = _get_u_early_late(params.early_index, params.late_index)
    return DiffmahUParams(u_logm0, u_logtc, u_early, u_late)


# Unbounded counterpart of the default bounded parameters
DEFAULT_MAH_U_PARAMS = DiffmahUParams(*get_unbounded_mah_params(DEFAULT_MAH_PARAMS))
|
||
|
||
@jjit
def _log_mah_kern_u_params(mah_u_params, t, t_peak, logt0):
    """Evaluate _log_mah_kern starting from unbounded parameters."""
    bounded = get_bounded_mah_params(mah_u_params)
    return _log_mah_kern(bounded, t, t_peak, logt0)
|
||
|
||
@jjit
def _dmhdt_kern_u_params(mah_u_params, t, t_peak, logt0):
    """Evaluate _dmhdt_kern starting from unbounded parameters."""
    bounded = get_bounded_mah_params(mah_u_params)
    return _dmhdt_kern(bounded, t, t_peak, logt0)
|
||
|
||
@jjit
def _diffmah_kern_u_params(mah_u_params, t, t_peak, logt0):
    """Return the pair (dmhdt, log_mah) starting from unbounded parameters."""
    bounded = get_bounded_mah_params(mah_u_params)
    # _diffmah_kern returns the same (dmhdt, log_mah) pair for bounded params
    return _diffmah_kern(bounded, t, t_peak, logt0)
Oops, something went wrong.