Skip to content

Commit

Permalink
jax.lax.logistic-->jax.nn.sigmoid
Browse files Browse the repository at this point in the history
  • Loading branch information
aphearin committed Feb 9, 2024
1 parent 8cc63b2 commit f84bb93
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 29 deletions.
9 changes: 2 additions & 7 deletions diffmah/individual_halo_assembly.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Model for individual halo mass assembly based on a power-law with rolling index."""

from jax import grad
from jax import jit as jjit
from jax import lax
from jax import numpy as jnp
from jax import vmap

from .defaults import LGT0, MAH_K
from .utils import get_1d_arrays
from .utils import _sigmoid, get_1d_arrays


@jjit
Expand Down Expand Up @@ -137,12 +138,6 @@ def _softplus(x):
return jnp.log(1 + lax.exp(x))


@jjit
def _sigmoid(x, logtc, k, ymin, ymax):
height_diff = ymax - ymin
return ymin + height_diff / (1.0 + lax.exp(-k * (x - logtc)))


@jjit
def _get_early_late(ue, ul):
late = _softplus(ul)
Expand Down
9 changes: 2 additions & 7 deletions diffmah/rockstar_pdf_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Model of halo population assembly calibrated to Rockstar halos."""

from collections import OrderedDict

from jax import jit as jjit
Expand All @@ -7,7 +8,7 @@
from jax import vmap

from .defaults import MAH_K
from .utils import get_cholesky_from_params
from .utils import _sigmoid, get_cholesky_from_params

TODAY = 13.8
LGT0 = jnp.log10(TODAY)
Expand Down Expand Up @@ -56,12 +57,6 @@
)


@jjit
def _sigmoid(x, logtc, k, ymin, ymax):
height_diff = ymax - ymin
return ymin + height_diff / (1.0 + lax.exp(-k * (x - logtc)))


def _get_cov_scalar(
log10_lge_lge,
log10_lgl_lgl,
Expand Down
9 changes: 5 additions & 4 deletions diffmah/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
"""
"""

import numpy as np
from jax import jit as jax_jit
from jax import numpy as jax_np
from jax import value_and_grad

from ..utils import (
_inverse_sigmoid,
_sigmoid,
get_cholesky_from_params,
jax_adam_wrapper,
jax_inverse_sigmoid,
jax_sigmoid,
)


def test_inverse_sigmoid_actually_inverts():
""""""
x0, k, ylo, yhi = 0, 5, 1, 0
xarr = np.linspace(-1, 1, 100)
yarr = np.array(jax_sigmoid(xarr, x0, k, ylo, yhi))
xarr2 = np.array(jax_inverse_sigmoid(yarr, x0, k, ylo, yhi))
yarr = np.array(_sigmoid(xarr, x0, k, ylo, yhi))
xarr2 = np.array(_inverse_sigmoid(yarr, x0, k, ylo, yhi))
assert np.allclose(xarr, xarr2, rtol=1e-3)


Expand Down
9 changes: 2 additions & 7 deletions diffmah/tng_pdf_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
"""

from collections import OrderedDict

from jax import jit as jjit
Expand All @@ -8,7 +9,7 @@
from jax import vmap

from .defaults import MAH_K
from .utils import get_cholesky_from_params
from .utils import _sigmoid, get_cholesky_from_params

TODAY = 13.8
LGT0 = jnp.log10(TODAY)
Expand Down Expand Up @@ -57,12 +58,6 @@
)


@jjit
def _sigmoid(x, logtc, k, ymin, ymax):
height_diff = ymax - ymin
return ymin + height_diff / (1.0 + lax.exp(-k * (x - logtc)))


def _get_cov_scalar(
log10_lge_lge,
log10_lgl_lgl,
Expand Down
11 changes: 7 additions & 4 deletions diffmah/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Utility functions used throughout the package."""

import numpy as np
from jax import jit as jjit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax.example_libraries import optimizers as jax_opt

Expand All @@ -25,7 +26,8 @@ def get_1d_arrays(*args, jax_arrays=False):
return result


def jax_sigmoid(x, x0, k, ylo, yhi):
@jjit
def _sigmoid(x, x0, k, ylo, yhi):
"""Sigmoid function implemented w/ `jax.numpy.exp`.
Parameters
Expand All @@ -45,10 +47,11 @@ def jax_sigmoid(x, x0, k, ylo, yhi):
-------
sigmoid : scalar or array-like, same shape as input
"""
return ylo + (yhi - ylo) / (1 + lax.exp(-k * (x - x0)))
return ylo + (yhi - ylo) * nn.sigmoid(k * (x - x0))


def jax_inverse_sigmoid(y, x0, k, ylo, yhi):
@jjit
def _inverse_sigmoid(y, x0, k, ylo, yhi):
"""Sigmoid function implemented w/ `jax.numpy.exp`.
Parameters
Expand Down

0 comments on commit f84bb93

Please sign in to comment.