SynJax

What is SynJax? | Installation | Examples | Citing SynJax

What is SynJax?

SynJax is a neural network library for structured probability distributions in JAX. The distributions that are currently supported are:

All these distributions support standard operations such as computing the log-probability of a structure, computing the marginal probability of a part of the structure, finding the most likely structure (argmax), sampling, top-k, entropy, cross-entropy, and KL divergence, among others.
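To make these operations concrete, here is a minimal sketch in plain JAX of a toy linear-chain CRF, a common structured distribution over tag sequences. It is not SynJax's API. The functions `log_partition`, `score`, `log_prob`, `marginals`, `entropy`, and `argmax`, and the `(emissions, transitions)` parameterization, are hypothetical names chosen for illustration. The sketch shows how the log-probability of a structure normalizes its score by the log-partition function, how marginals come out as gradients of the log-partition, how entropy follows from those marginals, and how the most likely structure is found with the Viterbi algorithm.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical toy linear-chain CRF (illustration only, not SynJax's API).
# emissions: (n, k) score of tag j at position t.
# transitions: (k, k) score of tag i at position t-1 followed by tag j at t.


def log_partition(emissions, transitions):
    """log Z via the forward algorithm (sum-product in log space)."""
    def step(alpha, emit):
        # alpha[i] = log-sum of scores of all prefixes ending in tag i.
        return logsumexp(alpha[:, None] + transitions, axis=0) + emit, None

    alpha, _ = jax.lax.scan(step, emissions[0], emissions[1:])
    return logsumexp(alpha)


def score(emissions, transitions, tags):
    """Unnormalized log-score of one tag sequence `tags` of shape (n,)."""
    n = emissions.shape[0]
    unary = emissions[jnp.arange(n), tags].sum()
    pairwise = transitions[tags[:-1], tags[1:]].sum()
    return unary + pairwise


def log_prob(emissions, transitions, tags):
    return score(emissions, transitions, tags) - log_partition(emissions, transitions)


def marginals(emissions, transitions):
    """Gradients of log Z: per-position tag marginals (n, k) and
    expected transition counts (k, k)."""
    return jax.grad(log_partition, argnums=(0, 1))(emissions, transitions)


def entropy(emissions, transitions):
    # The score is linear in the potentials, so H = log Z - E[score],
    # and E[score] is the potentials weighted by their marginals.
    mu_e, mu_t = marginals(emissions, transitions)
    expected_score = (mu_e * emissions).sum() + (mu_t * transitions).sum()
    return log_partition(emissions, transitions) - expected_score


def argmax(emissions, transitions):
    """Most likely tag sequence via the Viterbi algorithm (max-product)."""
    def step(best, emit):
        cand = best[:, None] + transitions  # (previous tag, next tag)
        return cand.max(axis=0) + emit, cand.argmax(axis=0)

    final, backptrs = jax.lax.scan(step, emissions[0], emissions[1:])
    last = jnp.argmax(final)

    def backtrack(tag, bp):
        # bp[tag] is the best previous tag given `tag` at the next position.
        prev = bp[tag]
        return prev, prev

    _, prefix = jax.lax.scan(backtrack, last, backptrs, reverse=True)
    return jnp.concatenate([prefix, last[None]])


key_e, key_t = jax.random.split(jax.random.PRNGKey(0))
emissions = jax.random.normal(key_e, (4, 3))   # 4 positions, 3 tags
transitions = jax.random.normal(key_t, (3, 3))
best = argmax(emissions, transitions)
print(best, log_prob(emissions, transitions, best), entropy(emissions, transitions))
```

Computing marginals as gradients of log Z is a standard identity for exponential-family models, so in this sketch one differentiable forward pass serves the log-probability, the marginals, and the entropy.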

All operations support the standard JAX transformations jax.vmap, jax.jit, jax.pmap, and jax.grad. The only exceptions are argmax, sample, and top-k, which do not support jax.grad.
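A minimal sketch of how these transformations compose, again on a hypothetical toy linear-chain CRF written in plain JAX rather than with SynJax's own classes (`chain_log_prob` and `nll_grad` are illustrative names, not SynJax functions). `jax.vmap` batches over sentences while sharing one transition matrix, `jax.jit` compiles the batched function, and `jax.grad` differentiates the batch negative log-likelihood.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def chain_log_prob(emissions, transitions, tags):
    """Log-probability of `tags` (n,) under a toy linear-chain CRF.

    Hypothetical helper: emissions is (n, k), transitions is (k, k).
    """
    def step(alpha, emit):
        # Forward algorithm: log-sum over all tag prefixes ending in each tag.
        return logsumexp(alpha[:, None] + transitions, axis=0) + emit, None

    alpha, _ = jax.lax.scan(step, emissions[0], emissions[1:])
    n = emissions.shape[0]
    score = emissions[jnp.arange(n), tags].sum() + transitions[tags[:-1], tags[1:]].sum()
    return score - logsumexp(alpha)


# vmap over a batch of (emissions, tags); transitions are shared (in_axes=None).
batched_log_prob = jax.jit(jax.vmap(chain_log_prob, in_axes=(0, None, 0)))

# Gradient of the batch negative log-likelihood w.r.t. emissions and transitions.
nll_grad = jax.jit(jax.grad(
    lambda e, t, y: -batched_log_prob(e, t, y).sum(), argnums=(0, 1)))

key_e, key_t = jax.random.split(jax.random.PRNGKey(0))
emissions = jax.random.normal(key_e, (8, 5, 3))  # batch of 8, length 5, 3 tags
transitions = jax.random.normal(key_t, (3, 3))
tags = jnp.zeros((8, 5), dtype=jnp.int32)        # gold tag sequences (all tag 0 here)

lp = batched_log_prob(emissions, transitions, tags)  # shape (8,)
g_e, g_t = nll_grad(emissions, transitions, tags)    # same shapes as the inputs
```

An argmax, sample, or top-k result is a set of discrete indices, so its derivative with respect to the potentials is zero almost everywhere; that is presumably why those three operations are excluded from jax.grad support.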

If you would like to read about the details of SynJax, take a look at the paper (arXiv:2308.03291).

Installation

SynJax is written in pure Python but depends on C++ code via JAX. Because the JAX installation differs depending on your CUDA version, SynJax does not list JAX as a dependency in requirements.txt.

First, follow the JAX installation instructions to install JAX with support for your accelerator.

Then, install SynJax using pip:

$ pip install git+https://github.com/google-deepmind/synjax

Examples

The notebooks directory contains examples of how SynJax works:

  • Introductory notebook demonstrating SynJax functionality (can also be opened in Colab).

Citing SynJax

To cite SynJax, please use both the SynJax paper citation:

@article{synjax2023,
      title="{SynJax: Structured Probability Distributions for JAX}",
      author={Milo\v{s} Stanojevi\'{c} and Laurent Sartran},
      year={2023},
      journal={arXiv preprint arXiv:2308.03291},
      url={https://arxiv.org/abs/2308.03291},
}

and the current DeepMind JAX Ecosystem citation.