Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft of switch StudentT cdf to use tfp's betainc #1475

Merged
merged 3 commits into from
Sep 6, 2022

Conversation

colehaus
Copy link
Contributor

@colehaus colehaus commented Sep 5, 2022

Jax's betainc doesn't have gradients defined for all parameters while tfp's does.

See the related PR here: #1471 and the initial discussion here: #1452.

I'm not sure exactly how you want to handle the dependency declarations since tensorflow and tensorflow-probability are sort of heavy dependencies to bring in (i.e. should they be promoted to install_requires?).

Also, the type casting stuff is a bit ugly but tfp checks that array types match and self.df sometimes had a float64 dtype in tests while beta_value has a float32 dtype in each test.

Jax's `betainc` doesn't have gradients defined for all parameters while tfp's does
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
@colehaus
Copy link
Contributor Author

colehaus commented Sep 6, 2022

Ah, sorry. Was slightly non-trivial to run the lint checks with a .venv, but I cherry-picked my config changes from the other PR and ran make lint locally so it should pass this time.

@fehiepsi
Copy link
Member

fehiepsi commented Sep 6, 2022

This is a great addition! Thanks, @colehaus.

@fehiepsi fehiepsi merged commit 9d5d235 into pyro-ppl:master Sep 6, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants