Skip to content

A numpyro Gibbs sampler that uses conditioned HMC kernels for each step.

License

Notifications You must be signed in to change notification settings

CKrawczyk/MultiHMCGibbs

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

28 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

DOI

An HMC-within-Gibbs sampler for Numpyro

This package adds a new HMC-within-Gibbs sampler to Numpyro. Unlike the HMCGibbs sampler currently available, this sampler is for situations where you do not have an analytic form for one of your conditioned distributions. Instead, it uses an HMC/NUTS sampler to estimate draws from each of the conditioned distributions.

To use MultiHMCGibbs you need to create a list of HMC or NUTS kernels that wrap the same model, but each can have its own keywords such as target_accept_prob or max_tree_depth. The other argument is a list of lists containing the free parameters for each of the inner kernels.

Internally the sampler will:

  1. Loop over the kernels in the list
  2. Conditioned it on the non-free parameters
  3. Re-calculate the likelihood and gradients at the new conditioned point
  4. Step the kernel forward
  5. Move on to the next kernel

Documentation: https://ckrawczyk.github.io/MultiHMCGibbs/

GitHub: https://github.com/CKrawczyk/MultiHMCGibbs

Installation

You can install the package with pip after cloning the repository.

git clone https://github.com/CKrawczyk/MultiHMCGibbs.git
cd MultiHMCGibbs
pip install .

Example usage

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from MultiHMCGibbs import MultiHMCGibbs

def model():
     x = numpyro.sample("x", dist.Normal(0.0, 2.0))
     y = numpyro.sample("y", dist.Normal(0.0, 2.0))
     numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

inner_kernels = [
    NUTS(model),
    NUTS(model)
]
outer_kernel = MultiHMCGibbs(
    inner_kernels,
    [['y'], ['x']]
)
mcmc = MCMC(
    kernel,
    num_warmup=100,
    num_samples=100,
    progress_bar=False
)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()

Contributing

Install all the development dependencies:

pip install -e .[dev]

Run tests with:

coverage run
coverage report

Build documentation with:

./build_docs

Citation

If you use this sampler in your publication you can cite the software as:

Coleman Krawczyk. (2024). CKrawczyk/MultiHMCGibbs: v1.0.0 (v1.0.0). Zenodo. https://doi.org/10.5281/zenodo.12167630

Or with bibtex:

@software{coleman_krawczyk_2024_12167630,
  author       = {Coleman Krawczyk},
  title        = {CKrawczyk/MultiHMCGibbs: v1.0.0},
  month        = jun,
  year         = 2024,
  publisher    = {Zenodo},
  version      = {v1.0.0},
  doi          = {10.5281/zenodo.12167630},
  url          = {https://doi.org/10.5281/zenodo.12167630}
}

Full citation information can be found on https://zenodo.org/records/12167630

About

A numpyro Gibbs sampler that uses conditioned HMC kernels for each step.

Resources

License

Stars

Watchers

Forks

Packages

No packages published

Contributors 3

  •  
  •  
  •