Skip to content

Commit

Permalink
Add calc_bursty_age_weights function to diffburst.py
Browse files Browse the repository at this point in the history
  • Loading branch information
aphearin committed Jan 9, 2024
1 parent ef9805c commit 4702648
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 29 deletions.
2 changes: 1 addition & 1 deletion dsps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
"""
from ._version import __version__
from .data_loaders import load_ssp_templates, load_transmission_curve
from .data_loaders import SSPData, load_ssp_templates, load_transmission_curve
from .photometry import *
from .sed import *
from .utils import cumulative_mstar_formed
120 changes: 102 additions & 18 deletions dsps/sfh/diffburst.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,99 @@
"""
"""
import typing
from jax import numpy as jnp

from jax import jit as jjit
from ..utils import triweight_gaussian, _sigmoid, _inverse_sigmoid
from jax import numpy as jnp

from ..utils import _inverse_sigmoid, _sigmoid, triweight_gaussian

LGYR_PEAK_MIN = 5.0
LGAGE_MAX = 9.0
DLGAGE_MIN = 1.0
LGAGE_K = 0.1
LGFBURST_MIN = -8.0
LGFBURST_MAX = -1.0
LGFB_X0, LGFB_K = -2, 0.1
DEFAULT_LGFBURST = -4.0


class BurstParams(typing.NamedTuple):
lgfburst: jnp.float32
lgyr_peak: jnp.float32
lgyr_max: jnp.float32


class BurstUParams(typing.NamedTuple):
u_lgfburst: jnp.float32
u_lgyr_peak: jnp.float32
u_lgyr_max: jnp.float32


DEFAULT_PARAMS = BurstParams(5.5, 7.0)
DEFAULT_PARAMS = BurstParams(-3.0, 5.5, 7.0)


@jjit
def calc_bursty_age_weights(burst_params, smooth_age_weights, ssp_lg_age_gyr):
"""Calculate the distribution of stellar ages of a smooth+bursty population
Parameters
----------
burst_params : namedtuple
burst_params = (lgfburst, lgyr_peak, lgyr_max)
smooth_age_weights : ndarray, shape (n_age, )
Array storing the distribution of stellar ages, P(τ), from a smooth SFH
ssp_lg_age_gyr : ndarray of shape (n_age, )
Base-10 log of stellar age in Gyr
The namedtuple dsps.SPSData.ssp_lg_age_gyr stores
a grid of stellar ages in these units
Returns
-------
age_weights : ndarray, shape (n_age, )
P(τ) after adding a fractional contribution from the bursting population
"""
burst_params = BurstParams(*burst_params)

ssp_lg_age_yr = ssp_lg_age_gyr + 9
burst_weights = _age_weights_from_params(ssp_lg_age_yr, burst_params)

fb = 10**burst_params.lgfburst
age_weights = fb * burst_weights + (1 - fb) * smooth_age_weights

return age_weights


@jjit
def calc_bursty_age_weights_from_u_params(
u_burst_params, smooth_age_weights, ssp_lg_age_gyr
):
"""Calculate the distribution of stellar ages of a smooth+bursty population
when passed unbounded parameters
Parameters
----------
u_burst_params : namedtuple
burst_params = (u_lgfburst, u_lgyr_peak, u_lgyr_max)
smooth_age_weights : ndarray, shape (n_age, )
Array storing the distribution of stellar ages, P(τ), from a smooth SFH
ssp_lg_age_gyr : ndarray of shape (n_age, )
Base-10 log of stellar age in Gyr
The namedtuple dsps.SPSData.ssp_lg_age_gyr stores
a grid of stellar ages in these units
Returns
-------
age_weights : ndarray, shape (n_age, )
P(τ) after adding a fractional contribution from the bursting population
"""
burst_params = _get_params_from_u_params(u_burst_params)
return calc_bursty_age_weights(burst_params, smooth_age_weights, ssp_lg_age_gyr)


@jjit
Expand All @@ -32,14 +104,13 @@ def _zero_safe_normalize(x):


@jjit
def _age_weights_from_params(lgyr, params):
lgyr_peak, lgyr_max = params
def _age_weights_from_params(lgyr, burst_params):
burst_params = BurstParams(*burst_params)
dlgyr_support = burst_params.lgyr_max - burst_params.lgyr_peak
lgyr_min = burst_params.lgyr_peak - dlgyr_support
twx0 = 0.5 * (lgyr_min + burst_params.lgyr_max)

dlgyr_support = lgyr_max - lgyr_peak
lgyr_min = lgyr_peak - dlgyr_support
twx0 = 0.5 * (lgyr_min + lgyr_max)

dlgyr = lgyr_max - lgyr_min
dlgyr = burst_params.lgyr_max - lgyr_min
twh = dlgyr / 6

tw_gauss = triweight_gaussian(lgyr, twx0, twh)
Expand All @@ -49,26 +120,29 @@ def _age_weights_from_params(lgyr, params):


@jjit
def _age_weights_from_u_params(lgyr, u_params):
params = _get_params_from_u_params(u_params)
def _age_weights_from_u_params(lgyr, u_burst_params):
u_burst_params = BurstUParams(*u_burst_params)
params = _get_params_from_u_params(u_burst_params)
return _age_weights_from_params(lgyr, params)


@jjit
def _get_params_from_u_params(u_params):
u_lgyr_peak, u_lgyr_max = u_params
u_lgfburst, u_lgyr_peak, u_lgyr_max = u_params
lgfburst = _get_lgfburst_from_u_lgfburst(u_lgfburst)
lgyr_peak = _get_lgyr_peak_from_u_lgyr_peak(u_lgyr_peak)
lgyr_max = _get_lgyr_max_from_lgyr_peak(lgyr_peak, u_lgyr_max)
params = lgyr_peak, lgyr_max
params = lgfburst, lgyr_peak, lgyr_max
return params


@jjit
def _get_u_params_from_params(params):
lgyr_peak, lgyr_max = params
lgfburst, lgyr_peak, lgyr_max = params
u_lgfburst = _get_u_lgfburst_from_lgfburst(lgfburst)
u_lgyr_peak = _get_u_lgyr_peak_from_lgyr_peak(lgyr_peak)
u_lgyr_max = _get_u_lgyr_max_from_lgyr_peak(lgyr_peak, lgyr_max)
u_params = u_lgyr_peak, u_lgyr_max
u_params = u_lgfburst, u_lgyr_peak, u_lgyr_max
return u_params


Expand All @@ -81,6 +155,16 @@ def _get_lgyr_peak_from_u_lgyr_peak(u_lgyr_peak):
return lgyr_peak


@jjit
def _get_lgfburst_from_u_lgfburst(u_lgfburst):
return _sigmoid(u_lgfburst, LGFB_X0, LGFB_K, LGFBURST_MIN, LGFBURST_MAX)


@jjit
def _get_u_lgfburst_from_lgfburst(lgfburst):
return _inverse_sigmoid(lgfburst, LGFB_X0, LGFB_K, LGFBURST_MIN, LGFBURST_MAX)


@jjit
def _get_lgyr_max_from_lgyr_peak(lgyr_peak, u_lgyr_max):
lo, hi = lgyr_peak + DLGAGE_MIN, LGAGE_MAX
Expand All @@ -94,8 +178,8 @@ def _get_u_lgyr_peak_from_lgyr_peak(lgyr_peak):
lo = LGYR_PEAK_MIN
hi = LGAGE_MAX - DLGAGE_MIN
x0 = 0.5 * (hi + lo)
lgyr_peak = _inverse_sigmoid(lgyr_peak, x0, LGAGE_K, lo, hi)
return lgyr_peak
u_lgyr_peak = _inverse_sigmoid(lgyr_peak, x0, LGAGE_K, lo, hi)
return u_lgyr_peak


@jjit
Expand Down
67 changes: 57 additions & 10 deletions dsps/sfh/tests/test_diffburst.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
"""
"""
import numpy as np

from .. import diffburst as db


def test_params_invert():
n_tests = 10
for __ in range(n_tests):
u_p = np.random.uniform(-10, 10, 2)
u_p = np.random.uniform(-10, 10, len(db.DEFAULT_PARAMS))
p = db._get_params_from_u_params(u_p)
u_p2 = db._get_u_params_from_params(p)
assert np.allclose(u_p, u_p2, rtol=1e-3)


def test_age_weights_is_finite_and_zero_for_edge_cases():
lgyr_peak, lgyr_max = 6.5, 8.0
params = db.DEFAULT_LGFBURST, lgyr_peak, lgyr_max

# all times in the lgyr_table are after lgyr_max
lgyr_test = np.linspace(lgyr_max, lgyr_max + 2, 100)
age_weights = db._age_weights_from_params(lgyr_test, (lgyr_peak, lgyr_max))
age_weights = db._age_weights_from_params(lgyr_test, params)
assert np.all(np.isfinite(age_weights))
assert np.allclose(age_weights, 0.0)

# all times in the lgyr_table are before lgyr_min
dlgyr_support = lgyr_max - lgyr_peak
lgyr_min = lgyr_peak - dlgyr_support
lgyr_test = np.linspace(lgyr_min - 2, lgyr_min, 100)
age_weights = db._age_weights_from_params(lgyr_test, (lgyr_peak, lgyr_max))
age_weights = db._age_weights_from_params(lgyr_test, params)
assert np.all(np.isfinite(age_weights))
assert np.allclose(age_weights, 0.0)

Expand All @@ -36,11 +38,11 @@ def test_age_weights_from_params_scales_with_lgyr_max_as_expected():
lgyr_peak = 5.5
for __ in range(n_tests):
lgyr_max = np.random.uniform(7, 8)
params = lgyr_peak, lgyr_max
params = db.BurstParams(db.DEFAULT_LGFBURST, lgyr_peak, lgyr_max)

lgyr = np.linspace(4, 10.5, 100)
age_weights = db._age_weights_from_params(lgyr, params)
zmsk = lgyr > params[1]
zmsk = lgyr > params.lgyr_max
assert np.all(age_weights[zmsk] == 0)
assert np.any(age_weights > 0)

Expand All @@ -54,8 +56,12 @@ def test_age_weights_from_params_scales_with_lgyr_max_as_expected():
lgyr_peak = 5.5

lgyr_test = np.linspace(lgyr_peak, lgyr_peak + 1, 100)
age_weight_younger = db._age_weights_from_params(lgyr_test, (lgyr_peak, 6.5))
age_weights_older = db._age_weights_from_params(lgyr_test, (lgyr_peak, 8.0))
age_weight_younger = db._age_weights_from_params(
lgyr_test, (db.DEFAULT_LGFBURST, lgyr_peak, 6.5)
)
age_weights_older = db._age_weights_from_params(
lgyr_test, (db.DEFAULT_LGFBURST, lgyr_peak, 8.0)
)
assert age_weight_younger[0] > age_weights_older[0]


Expand All @@ -72,7 +78,7 @@ def test_age_weights_from_random_params_are_weights():
lgyr = np.arange(5.5, 10.35, 0.05)
n_tests = 10
for __ in range(n_tests):
u_p = np.random.uniform(-10, 10, 2)
u_p = np.random.uniform(-10, 10, len(db.DEFAULT_PARAMS))
age_weights = db._age_weights_from_u_params(lgyr, u_p)
assert np.all(np.isfinite(age_weights))
assert np.all(age_weights >= 0)
Expand All @@ -84,7 +90,7 @@ def test_age_weights_from_random_params_u_params_consistency():
lgyr = np.arange(5.5, 10.35, 0.05)
n_tests = 10
for __ in range(n_tests):
u_p = np.random.uniform(-10, 10, 2)
u_p = np.random.uniform(-10, 10, len(db.DEFAULT_PARAMS))
age_weights = db._age_weights_from_u_params(lgyr, u_p)
p = db._get_params_from_u_params(u_p)
age_weights2 = db._age_weights_from_params(lgyr, p)
Expand All @@ -110,7 +116,7 @@ def test_compute_bursty_age_weights_from_u_params():
lgyr_since_burst = np.arange(5.5, 10.35, 0.05)
n_tests = 10
for __ in range(n_tests):
u_p = np.random.uniform(-10, 10, 2)
u_p = np.random.uniform(-10, 10, len(db.DEFAULT_PARAMS))
n_age = lgyr_since_burst.size
age_weights = np.random.uniform(0, 1, n_age)
age_weights = age_weights / age_weights.sum()
Expand All @@ -133,3 +139,44 @@ def test_compute_bursty_age_weights_from_u_params():
assert np.allclose(bursty_age_weights2.sum(), 1.0, rtol=1e-3)

assert np.allclose(bursty_age_weights, bursty_age_weights2, rtol=1e-3)


def test_calc_bursty_age_weights():
burst_params = db.DEFAULT_PARAMS
n_age = 107
ssp_lg_age_gyr = np.linspace(5.0 - 9, 10.5 - 9, n_age)
smooth_age_weights = np.random.uniform(0, 1, n_age)
smooth_age_weights = smooth_age_weights / smooth_age_weights.sum()
assert np.allclose(smooth_age_weights.sum(), 1.0, rtol=1e-4)

ssp_lg_age_yr = ssp_lg_age_gyr + 9.0
burst_weights = db._age_weights_from_params(ssp_lg_age_yr, burst_params)
assert np.allclose(burst_weights.sum(), 1.0, rtol=1e-4)

bursty_age_weights = db.calc_bursty_age_weights(
burst_params, smooth_age_weights, ssp_lg_age_gyr
)
bursty_age_weights.shape == smooth_age_weights.shape
assert np.all(np.isfinite(bursty_age_weights))
assert np.allclose(bursty_age_weights.sum(), 1.0, rtol=1e-4)


def test_calc_bursty_age_weights_from_u_params():
n_age = 107
ssp_lg_age_gyr = np.linspace(5.0 - 9, 10.5 - 9, n_age)
smooth_age_weights = np.random.uniform(0, 1, n_age)
smooth_age_weights = smooth_age_weights / smooth_age_weights.sum()
n_tests = 10
for i in range(n_tests):
u_burst_params = np.random.uniform(-20, 20, len(db.DEFAULT_PARAMS))
bursty_age_weights = db.calc_bursty_age_weights_from_u_params(
u_burst_params, smooth_age_weights, ssp_lg_age_gyr
)
assert np.all(np.isfinite(bursty_age_weights))
assert np.allclose(bursty_age_weights.sum(), 1.0, rtol=1e-4)

burst_params = db._get_params_from_u_params(u_burst_params)
bursty_age_weights2 = db.calc_bursty_age_weights(
burst_params, smooth_age_weights, ssp_lg_age_gyr
)
assert np.allclose(bursty_age_weights, bursty_age_weights2, rtol=1e-4)

0 comments on commit 4702648

Please sign in to comment.