What is SynJax? | Installation | Examples | Citing SynJax
SynJax is a neural network library for JAX that provides structured probability distributions. The currently supported distributions are:
- Linear Chain CRF,
- Semi-Markov CRF,
- Constituency Tree CRF,
- Spanning Tree CRF -- including optional constraints for projectivity, (un)directionality and single root edges,
- Alignment CRF -- including both monotonic (1-to-many and many-to-many) and non-monotonic (1-to-1) alignments,
- CTC Alignment,
- PCFG,
- Tensor-Decomposition PCFG,
- HMM.
All of these distributions support standard operations such as computing the log-probability of a structure, computing the marginal probability of a part of the structure, finding the most likely structure, sampling, top-k, entropy, cross-entropy, and KL divergence.
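For illustration, here is a minimal sketch of these operations using the Linear Chain CRF; the constructor arguments, tensor shapes, and sizes below are assumptions made for the example, so check the paper and the notebooks for the exact API.

```python
import jax
import synjax

# Assumed toy sizes: 2 sequences of length 5 with 3 possible tags.
batch, length, tags = 2, 5, 3
key = jax.random.PRNGKey(0)

# Log-potentials scoring transitions between consecutive tags
# (shape assumed to be [batch, length, tags, tags]).
log_potentials = jax.random.normal(key, (batch, length, tags, tags))
dist = synjax.LinearChainCRF(log_potentials)

best = dist.argmax()           # most likely structure
sample = dist.sample(key)      # a sampled structure
log_p = dist.log_prob(sample)  # log-probability of a structure
marginals = dist.marginals()   # marginal probability of each part
entropy = dist.entropy()       # entropy of the distribution
```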
All operations support the standard JAX transformations `jax.vmap`, `jax.jit`, `jax.pmap` and `jax.grad`. The only exceptions are `argmax`, `sample` and `top-k`, which do not support `jax.grad`.
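As a concrete, hedged sketch of how this composes: since `log_prob` is differentiable, it can be wrapped in `jax.jit` and `jax.grad` to define a training loss over the potentials (the shapes and sizes below are placeholders, not the library's required layout).

```python
import jax
import jax.numpy as jnp
import synjax

batch, length, tags = 2, 5, 3  # placeholder sizes
key = jax.random.PRNGKey(0)
log_potentials = jax.random.normal(key, (batch, length, tags, tags))

# A sampled structure stands in for gold-standard labels in this sketch.
gold = synjax.LinearChainCRF(log_potentials).sample(key)

def neg_log_likelihood(log_potentials, gold):
  # log_prob is differentiable, so gradients flow back to the potentials.
  dist = synjax.LinearChainCRF(log_potentials)
  return -jnp.mean(dist.log_prob(gold))

# jax.jit compiles the loss; jax.value_and_grad differentiates it
# with respect to log_potentials.
loss, grads = jax.jit(jax.value_and_grad(neg_log_likelihood))(log_potentials, gold)
```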
If you would like to read about the details of SynJax, take a look at the paper.
SynJax is written in pure Python, but depends on C++ code via JAX.
Because JAX installation is different depending on your CUDA version, SynJax does not list JAX as a dependency in `requirements.txt`.
First, follow these instructions to install JAX with the relevant accelerator support.
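For example, at the time of writing, a GPU setup with CUDA 12 can be installed roughly like this (the exact extra depends on your accelerator, so check the JAX installation guide):

$ pip install -U "jax[cuda12]"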
Then, install SynJax using pip:
$ pip install git+https://github.com/google-deepmind/synjax
The notebooks directory contains examples of how SynJax works.
To cite SynJax, please use both the SynJax paper citation:
@article{synjax2023,
  title = "{SynJax: Structured Probability Distributions for JAX}",
  author = {Milo\v{s} Stanojevi\'{c} and Laurent Sartran},
  year = {2023},
  journal = {arXiv preprint arXiv:2308.03291},
  url = {https://arxiv.org/abs/2308.03291},
}
and the current DeepMind JAX Ecosystem citation.