From 7fb84f6bbff15a05039907690121382cef329676 Mon Sep 17 00:00:00 2001 From: Cole Haus Date: Mon, 5 Sep 2022 14:40:19 -0700 Subject: [PATCH 1/3] Switch `StudentT` `cdf` to use tfp's `betainc` Jax's `betainc` doesn't have gradients defined for all parameters while tfp's does --- numpyro/distributions/continuous.py | 25 ++++++++++++++++++++++++- setup.py | 2 +- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index fa05f6951..a9951c372 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -27,6 +27,7 @@ import numpy as np +import jax from jax import lax from jax.experimental.sparse import BCOO import jax.nn as nn @@ -1909,18 +1910,40 @@ def variance(self): return jnp.broadcast_to(var, self.batch_shape) def cdf(self, value): + from tensorflow_probability.substrates.jax.math import betainc as tfp_betainc + # Ref: https://en.wikipedia.org/wiki/Student's_t-distribution#Related_distributions # X^2 ~ F(1, df) -> df / (df + X^2) ~ Beta(df/2, 0.5) scaled = (value - self.loc) / self.scale scaled_squared = scaled * scaled beta_value = self.df / (self.df + scaled_squared) + + float_type = ( + jnp.promote_types(self.df.dtype, beta_value.dtype) + if jax.config.read("jax_enable_x64") + else jnp.dtype("float32") + ) + # when scaled < 0, returns 0.5 * Beta(df/2, 0.5).cdf(beta_value) # when scaled > 0, returns 1 - 0.5 * Beta(df/2, 0.5).cdf(beta_value) scaled_sign_half = 0.5 * jnp.sign(scaled) return ( 0.5 + scaled_sign_half - - 0.5 * jnp.sign(scaled) * betainc(0.5 * self.df, 0.5, beta_value) + - 0.5 + * jnp.sign(scaled) + * tfp_betainc( + 0.5 + * ( + self.df.astype(float_type) + if isinstance(self.df, (jnp.ndarray, np.ndarray)) + else self.df + ), + 0.5, + beta_value.astype(float_type) + if isinstance(self.df, (jnp.ndarray, np.ndarray)) + else beta_value, + ) ) def icdf(self, q): diff --git a/setup.py b/setup.py index f9ccf0901..e78170d1f 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ "jaxns==1.0.0", "optax>=0.0.6", "pyyaml", # flax dependency - "tensorflow_probability>=0.15.0", + "tensorflow_probability>=0.17.0", ], "examples": [ "arviz", From d73699184b3945bef012d5fb8a5b00ca5705948b Mon Sep 17 00:00:00 2001 From: Cole Haus Date: Mon, 5 Sep 2022 19:04:08 -0700 Subject: [PATCH 2/3] Cleanup in line with PR review --- numpyro/distributions/continuous.py | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index a9951c372..96891a094 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -1910,7 +1910,10 @@ def variance(self): return jnp.broadcast_to(var, self.batch_shape) def cdf(self, value): - from tensorflow_probability.substrates.jax.math import betainc as tfp_betainc + try: + from tensorflow_probability.substrates.jax.math import betainc as betainc_fn + except ImportError: + from jax.scipy.special import betainc as betainc_fn # Ref: https://en.wikipedia.org/wiki/Student's_t-distribution#Related_distributions # X^2 ~ F(1, df) -> df / (df + X^2) ~ Beta(df/2, 0.5) @@ -1918,12 +1921,6 @@ def cdf(self, value): scaled_squared = scaled * scaled beta_value = self.df / (self.df + scaled_squared) - float_type = ( - jnp.promote_types(self.df.dtype, beta_value.dtype) - if jax.config.read("jax_enable_x64") - else jnp.dtype("float32") - ) - # when scaled < 0, returns 0.5 * Beta(df/2, 0.5).cdf(beta_value) # when scaled > 0, returns 1 - 0.5 * Beta(df/2, 0.5).cdf(beta_value) scaled_sign_half = 0.5 * jnp.sign(scaled) @@ -1932,18 +1929,7 @@ def cdf(self, value): + scaled_sign_half - 0.5 * jnp.sign(scaled) - * tfp_betainc( - 0.5 - * ( - self.df.astype(float_type) - if isinstance(self.df, (jnp.ndarray, np.ndarray)) - else self.df - ), - 0.5, - beta_value.astype(float_type) - if isinstance(self.df, (jnp.ndarray, np.ndarray)) - else beta_value, - ) + * betainc_fn(0.5 * jnp.asarray(self.df), 0.5, jnp.asarray(beta_value)) ) def icdf(self, q): From 219122308ec9aead8f55f573fea5388b65a9c07d Mon Sep 17 00:00:00 2001 From: Cole Haus Date: Tue, 6 Sep 2022 09:30:36 -0700 Subject: [PATCH 3/3] Remove unneeded import as directed by linter --- numpyro/distributions/continuous.py | 1 - 1 file changed, 1 deletion(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 96891a094..d8442794c 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -27,7 +27,6 @@ import numpy as np -import jax from jax import lax from jax.experimental.sparse import BCOO import jax.nn as nn