Skip to content

Commit

Permalink
Merge pull request #78 from ArgonneCPAC/trapz_smh
Browse files Browse the repository at this point in the history
Trapezoidal integration of stellar age weights
  • Loading branch information
aphearin authored Nov 28, 2023
2 parents 31b03ab + f3001e6 commit 20a2c40
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
13 changes: 6 additions & 7 deletions dsps/sed/stellar_age_weights.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Kernels calculating stellar age PDF-weighting of SSP tempates"""
from jax import numpy as jnp
from jax import jit as jjit
from ..utils import _jax_get_dt_array
from ..constants import SFR_MIN, T_BIRTH_MIN, N_T_LGSM_INTEGRATION
from jax import numpy as jnp

from ..constants import N_T_LGSM_INTEGRATION, SFR_MIN, T_BIRTH_MIN
from ..cosmology import TODAY
from ..utils import _jax_get_dt_array, cumulative_mstar_formed

__all__ = ("calc_age_weights_from_sfh_table",)

Expand Down Expand Up @@ -170,9 +171,7 @@ def _calc_logsm_table_from_sfh_table(gal_t_table, gal_sfr_table, sfr_min):
Minimum star formation rate in Msun/yr
"""
dt_table = _jax_get_dt_array(gal_t_table)

gal_sfr_table = jnp.where(gal_sfr_table < sfr_min, sfr_min, gal_sfr_table)
gal_mstar_table = jnp.cumsum(gal_sfr_table * dt_table) * 1e9
logsm_table = jnp.log10(gal_mstar_table)
gal_smh_table = cumulative_mstar_formed(gal_t_table, gal_sfr_table)
logsm_table = jnp.log10(gal_smh_table)
return logsm_table
12 changes: 7 additions & 5 deletions dsps/sed/tests/test_csp_sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
"""
import numpy as np
from jax import random as jran
from ..stellar_sed import calc_rest_sed_sfh_table_lognormal_mdf
from ..stellar_sed import calc_rest_sed_sfh_table_met_table
from ...constants import T_BIRTH_MIN

from ...constants import T_TABLE_MIN
from ..stellar_sed import (
calc_rest_sed_sfh_table_lognormal_mdf,
calc_rest_sed_sfh_table_met_table,
)

SEED = 43
FSPS_LG_AGES = np.arange(5.5, 10.2, 0.05) # log10 ages in years
Expand All @@ -15,7 +17,7 @@ def test_calc_rest_sed_lognormal_mdf():
ran_key = jran.PRNGKey(SEED)
t_obs = 13.0
n_t = 500
gal_t_table = np.linspace(T_BIRTH_MIN, t_obs, n_t)
gal_t_table = np.linspace(T_TABLE_MIN, t_obs, n_t)

sfr_key, met_key, sed_key = jran.split(ran_key, 3)
gal_sfr_table = jran.uniform(sfr_key, minval=0, maxval=10, shape=(n_t,))
Expand Down Expand Up @@ -63,7 +65,7 @@ def test_calc_rest_sed_lgmet_table():
ran_key = jran.PRNGKey(SEED)
t_obs = 13.0
n_t = 500
gal_t_table = np.linspace(T_BIRTH_MIN, t_obs, n_t)
gal_t_table = np.linspace(T_TABLE_MIN, t_obs, n_t)

sfr_key, met_key, sed_key = jran.split(ran_key, 3)
gal_sfr_table = jran.uniform(sfr_key, minval=0, maxval=10, shape=(n_t,))
Expand Down
12 changes: 7 additions & 5 deletions dsps/sed/tests/test_ssp_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
"""
import numpy as np
from jax import random as jran
from ..ssp_weights import calc_ssp_weights_sfh_table_lognormal_mdf
from ..ssp_weights import calc_ssp_weights_sfh_table_met_table
from ...constants import T_BIRTH_MIN

from ...constants import T_TABLE_MIN
from ..ssp_weights import (
calc_ssp_weights_sfh_table_lognormal_mdf,
calc_ssp_weights_sfh_table_met_table,
)

SEED = 43
FSPS_LG_AGES = np.arange(5.5, 10.2, 0.05) # log10 ages in years
Expand All @@ -15,7 +17,7 @@ def test_calc_ssp_weights_lognormal_mdf():
ran_key = jran.PRNGKey(SEED)
t_obs = 13.0
n_t = 500
gal_t_table = np.linspace(T_BIRTH_MIN, t_obs, n_t)
gal_t_table = np.linspace(T_TABLE_MIN, t_obs, n_t)

sfr_key, met_key = jran.split(ran_key, 2)
gal_sfr_table = jran.uniform(sfr_key, minval=0, maxval=10, shape=(n_t,))
Expand Down Expand Up @@ -59,7 +61,7 @@ def test_calc_ssp_weights_met_table():
ran_key = jran.PRNGKey(SEED)
t_obs = 13.0
n_t = 500
gal_t_table = np.linspace(T_BIRTH_MIN, t_obs, n_t)
gal_t_table = np.linspace(T_TABLE_MIN, t_obs, n_t)

sfr_key, met_key = jran.split(ran_key, 2)
gal_sfr_table = jran.uniform(sfr_key, minval=0, maxval=10, shape=(n_t,))
Expand Down

0 comments on commit 20a2c40

Please sign in to comment.