Skip to content

Commit

Permalink
Merge pull request #120 from ArgonneCPAC/dev_tq
Browse files Browse the repository at this point in the history
New model of diffmah that includes tpeak feature
  • Loading branch information
aphearin authored Jun 20, 2024
2 parents 1f1ec42 + 8419fde commit 9287d28
Show file tree
Hide file tree
Showing 9 changed files with 674 additions and 2 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/monthly-warning-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ on:
schedule:
# Runs "First of every month at 3:15am Central"
- cron: '15 8 1 * *'
push:
branches:
- main
pull_request: null

jobs:
tests:
Expand All @@ -18,7 +22,7 @@ jobs:

- uses: conda-incubator/setup-miniconda@v2
with:
python-version: 3.9
python-version: 3.11
channels: conda-forge,defaults
channel-priority: strict
show-channel-urls: true
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:

- uses: conda-incubator/setup-miniconda@v2
with:
python-version: 3.9
python-version: 3.11
channels: conda-forge,defaults
channel-priority: strict
show-channel-urls: true
Expand Down
2 changes: 2 additions & 0 deletions diffmah/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""
"""

# flake8: noqa

from ._version import __version__
from .bfgs_wrapper import diffmah_fitter
from .defaults import DEFAULT_MAH_PARAMS, MAH_K, DiffmahParams
from .individual_halo_assembly import calc_halo_history, mah_halopop, mah_singlehalo
from .monte_carlo_diffmah_hiz import mc_diffmah_params_hiz
Expand Down
91 changes: 91 additions & 0 deletions diffmah/bfgs_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
Functions in this script fit MAH using an alternating wrapper.
The main minimizer used is scipy's LBFGS algo.
If this does not work (success=False), then it tries minimizing with JAX's ADAM algo.
The main function is minimize_alternate_wrappers.
"""

import numpy as np
from scipy.optimize import minimize

from .utils import jax_adam_wrapper


def scipy_lbfgs_wrapper(val_and_grads, p_init, loss_data):
"""
Function that runs scipy's LBFGS minimizer.
Args:
val_and_grads: function that returns the loss function along with the grads.
For LBFGS, one does not need to use grads.
p_init: array of initial values for parameters
loss_data: Sequence of floats and arrays storing
whatever data is needed to compute loss_func(params_init, loss_data)
Returns:
_res: list of best fit parameters, best fit loss,
and a boolean whether the fit was successful or not.
"""

# Define the loss and grad functions from value_and_grads
# scipy wants them separated
def loss_func(p_init, loss_data):
return float(val_and_grads(p_init, loss_data)[0])

def grad_func(p_init, loss_data):
return np.array(val_and_grads(p_init, loss_data)[1]).astype(float)

# run scipy's LBFGS minimizer
result = minimize(
loss_func, p_init, method="L-BFGS-B", jac=grad_func, args=(loss_data,)
)
_res = [result.x, result.fun, result.success]
return _res


def diffmah_fitter(val_and_grads, p_init, loss_data, nstep=200, n_warmup=1):
"""
Function that runs scipy's LBFGS minimizer.
If that is not successful, minimize the fit with the ADAM minimizer from JAX
Parameters
-----------
val_and_grads: func
function returns the loss function along with the grads
u_p_init: array
initial values for unbounded parameters
loss_data: Sequence of floats and arrays storing
whatever data is needed to compute loss_func(params_init, loss_data)
nstep: int, optional
Number of steps that the ADAM wrapper needs to run for (default = 200)
n_warmup: int, optional
Number of warmup steps to use (default = 1)
Returns
-------
_res: list
p_best, loss_best, fit_terminates, code_used
"""
_res = scipy_lbfgs_wrapper(val_and_grads, p_init, loss_data)

# check if LBFGS succeeds. If yes, save those results.
# Otherwise try the Adam wrapper
fit_terminates = _res[-1]
loss_bfgs = _res[1]
bfgs_succeeds = fit_terminates & (np.isfinite(loss_bfgs)) & (loss_bfgs > 0)
if bfgs_succeeds:
code_used = 0 # BFGS
_res.append(code_used)
return _res
else:
res = jax_adam_wrapper(val_and_grads, p_init, loss_data, nstep, n_warmup)
p_best, loss_best, loss_arr, params_arr, fit_terminates = res
code_used = 1 # Adam
_res = [p_best, loss_best, fit_terminates, code_used]
return _res
197 changes: 197 additions & 0 deletions diffmah/diffmah_kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""
"""

from collections import OrderedDict, namedtuple

from jax import grad
from jax import jit as jjit
from jax import lax, nn
from jax import numpy as jnp
from jax import vmap

MAH_K = 3.5

DEFAULT_MAH_PDICT = OrderedDict(
logm0=12.0, logtc=0.05, early_index=2.6137643, late_index=0.12692805
)
DiffmahParams = namedtuple("DiffmahParams", list(DEFAULT_MAH_PDICT.keys()))
DEFAULT_MAH_PARAMS = DiffmahParams(*list(DEFAULT_MAH_PDICT.values()))
_MAH_PNAMES = list(DEFAULT_MAH_PDICT.keys())
_MAH_UPNAMES = ["u_" + key for key in _MAH_PNAMES]
DiffmahUParams = namedtuple("DiffmahUParams", _MAH_UPNAMES)

MAH_PBDICT = OrderedDict(
logm0=(0.0, 17.0), logtc=(-1.0, 1.0), early_index=(0.1, 10.0), late_index=(0.1, 5.0)
)
MAH_PBOUNDS = DiffmahParams(*list(MAH_PBDICT.values()))


@jjit
def _sigmoid(x, x0, k, ylo, yhi):
height_diff = yhi - ylo
return ylo + height_diff * nn.sigmoid(k * (x - x0))


@jjit
def _inverse_sigmoid(y, x0, k, ylo, yhi):
lnarg = (yhi - ylo) / (y - ylo) - 1
return x0 - lax.log(lnarg) / k


@jjit
def _power_law_index_vs_logt(logt, logtc, early, late):
rolling_index = _sigmoid(logt, logtc, MAH_K, early, late)
return rolling_index


@jjit
def _rolling_plaw_vs_logt(logt, logt0, logm0, logtc, early, late):
"""Kernel of the rolling power-law between halo mass and time."""
rolling_index = _power_law_index_vs_logt(logt, logtc, early, late)
log_mah = rolling_index * (logt - logt0) + logm0
return log_mah


@jjit
def _rolling_plaw_vs_t(t, logt0, logm0, logtc, early, late):
"""Convenience wrapper used to calculate d/dt of _rolling_plaw_vs_logt"""
logt = jnp.log10(t)
return _rolling_plaw_vs_logt(logt, logt0, logm0, logtc, early, late)


@jjit
def _log_mah_noq_kern(mah_params, t, logt0):
logm0, logtc, early, late = mah_params
log_mah_noq = _rolling_plaw_vs_t(t, logt0, logm0, logtc, early, late)
return log_mah_noq


@jjit
def _mah_noq_kern(mah_params, t, logt0):
log_mah_noq = _log_mah_noq_kern(mah_params, t, logt0)
mah_noq = 10**log_mah_noq
return mah_noq


_dmhdt_noq_grad_kern_scalar = jjit(grad(_mah_noq_kern, argnums=1))
_dmhdt_noq_grad_kern = jjit(vmap(_dmhdt_noq_grad_kern_scalar, in_axes=(None, 0, None)))


@jjit
def _dmhdt_noq_kern_scalar(mah_params, t, logt0):
dmhdt = _dmhdt_noq_grad_kern_scalar(mah_params, t, logt0) / 1e9
return dmhdt


_dmhdt_noq_kern = jjit(vmap(_dmhdt_noq_kern_scalar, in_axes=(None, 0, None)))


@jjit
def _log_mah_kern(mah_params, t, t_peak, logt0):
log_mah_noq = _log_mah_noq_kern(mah_params, t, logt0)
lgmhq = _log_mah_noq_kern(mah_params, t_peak, logt0)
log_mah = jnp.where(t < t_peak, log_mah_noq, lgmhq) # clip growth at t_peak
return log_mah


@jjit
def _mah_kern(mah_params, t, t_peak, logt0):
log_mah = _log_mah_kern(mah_params, t, t_peak, logt0)
mah = 10**log_mah
return mah


_dmhdt_grad_kern_unscaled = jjit(grad(_mah_kern, argnums=1))


@jjit
def _dmhdt_kern(mah_params, t, t_peak, logt0):
dmhdt_noq = _dmhdt_noq_kern(mah_params, t, logt0)
dmhdt = jnp.where(t > t_peak, 0.0, dmhdt_noq)
return dmhdt


@jjit
def _diffmah_kern(mah_params, t, t_peak, logt0):
dmhdt = _dmhdt_kern(mah_params, t, t_peak, logt0)
log_mah = _log_mah_kern(mah_params, t, t_peak, logt0)
return dmhdt, log_mah


##############################
# Unbounded parameter behavior

BOUNDING_K = 0.1


@jjit
def _get_bounded_diffmah_param(u_param, bound):
lo, hi = bound
mid = 0.5 * (lo + hi)
return _sigmoid(u_param, mid, BOUNDING_K, lo, hi)


@jjit
def _get_unbounded_diffmah_param(param, bound):
lo, hi = bound
mid = 0.5 * (lo + hi)
return _inverse_sigmoid(param, mid, BOUNDING_K, lo, hi)


@jjit
def _get_early_late(u_early, u_late):
late = _get_bounded_diffmah_param(u_late, MAH_PBOUNDS.late_index)
early = _sigmoid(u_early, 0.0, BOUNDING_K, late, MAH_PBOUNDS.early_index[1])
return early, late


@jjit
def _get_u_early_late(early, late):
u_late = _get_unbounded_diffmah_param(late, MAH_PBOUNDS.late_index)
u_early = _inverse_sigmoid(early, 0.0, BOUNDING_K, late, MAH_PBOUNDS.early_index[1])
return u_early, u_late


@jjit
def get_bounded_mah_params(u_params):
u_parr = jnp.array([getattr(u_params, u_pname) for u_pname in _MAH_UPNAMES])
logm0 = _get_bounded_diffmah_param(u_params.u_logm0, MAH_PBOUNDS.logm0)
logtc = _get_bounded_diffmah_param(u_params.u_logtc, MAH_PBOUNDS.logtc)
u_early, u_late = u_parr[2:]
early, late = _get_early_late(u_early, u_late)
params = DiffmahParams(logm0, logtc, early, late)
return params


@jjit
def get_unbounded_mah_params(params):
parr = jnp.array([getattr(params, pname) for pname in _MAH_PNAMES])
early, late = parr[2:]
u_early, u_late = _get_u_early_late(early, late)
u_logm0 = _get_unbounded_diffmah_param(params.logm0, MAH_PBOUNDS.logm0)
u_logtc = _get_unbounded_diffmah_param(params.logtc, MAH_PBOUNDS.logtc)
u_params = DiffmahUParams(u_logm0, u_logtc, u_early, u_late)
return u_params


DEFAULT_MAH_U_PARAMS = DiffmahUParams(*get_unbounded_mah_params(DEFAULT_MAH_PARAMS))


@jjit
def _log_mah_kern_u_params(mah_u_params, t, t_peak, logt0):
mah_params = get_bounded_mah_params(mah_u_params)
return _log_mah_kern(mah_params, t, t_peak, logt0)


@jjit
def _dmhdt_kern_u_params(mah_u_params, t, t_peak, logt0):
mah_params = get_bounded_mah_params(mah_u_params)
return _dmhdt_kern(mah_params, t, t_peak, logt0)


@jjit
def _diffmah_kern_u_params(mah_u_params, t, t_peak, logt0):
mah_params = get_bounded_mah_params(mah_u_params)
dmhdt = _dmhdt_kern(mah_params, t, t_peak, logt0)
log_mah = _log_mah_kern(mah_params, t, t_peak, logt0)
return dmhdt, log_mah
Loading

0 comments on commit 9287d28

Please sign in to comment.