PIX is an image processing library in JAX, for JAX.
JAX is a library resulting from the union of Autograd and XLA for high-performance machine learning research. It provides NumPy- and SciPy-like APIs, automatic differentiation and first-class GPU/TPU support.
PIX is a library built on top of JAX with the goal of providing image processing
functions and tools to JAX in a way that they can be optimised and parallelised
through jax.jit, jax.vmap and jax.pmap.
PIX is written in pure Python, but depends on C++ code via JAX.
Because JAX installation differs depending on your CUDA version, PIX does
not list JAX as a dependency in pyproject.toml (it is listed there for
reference, but commented out).
First, follow JAX installation instructions to install JAX with the relevant accelerator support.
Then, install PIX using pip:
$ pip install dm-pix
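After installing, a quick sanity check is to ask JAX which backend and devices it can see. If you installed JAX with accelerator support but only a CPU device is reported, it is worth revisiting the JAX installation step. This minimal sketch uses only jax.default_backend and jax.devices and does not touch PIX itself:

```python
import jax

# Which backend JAX selected by default (e.g. "cpu", "gpu" or "tpu").
backend = jax.default_backend()
# All devices JAX can see; with accelerator support these should include
# your GPU(s) or TPU(s), not only a CPU device.
devices = jax.devices()

print(f"JAX {jax.__version__}, default backend: {backend}")
print(f"Visible devices: {devices}")
```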
To use PIX, you just need to import dm_pix as pix and use it right away!
For example, suppose we have loaded the JAX logo (available in
examples/assets/jax_logo.jpg) into a variable called image, and we want to
flip it left to right. All that is needed is the following code:
import dm_pix as pix
# Load an image into a NumPy array with your preferred library.
image = load_image()
flip_left_right_image = pix.flip_left_right(image)
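To make it concrete what "flip left to right" means for the underlying array, here is a plain-JAX sketch. It assumes a channels-last (height, width, channels) layout and simply reverses the width axis; flip_left_right_sketch is a hypothetical helper for illustration, not PIX's implementation, and PIX's own function may support other layouts.

```python
import jax.numpy as jnp

def flip_left_right_sketch(image: jnp.ndarray) -> jnp.ndarray:
  """Hypothetical stand-in for a left-right flip (not PIX's code).

  Assumes a channels-last layout (..., height, width, channels), so the
  width axis is the second-to-last one and gets reversed.
  """
  return image[..., ::-1, :]

# A tiny 2x3 single-channel "image" whose pixel values encode their position.
image = jnp.arange(6, dtype=jnp.float32).reshape(2, 3, 1)
flipped = flip_left_right_sketch(image)
# Each row of the single channel is now reversed.
print(flipped[..., 0])
```

Flipping twice returns the original image, which makes a handy property to check when experimenting.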
And here is the result!
All the functions in PIX can be jax.jit-ed, jax.vmap-ed and jax.pmap-ed, so
they can all take advantage of optimization and parallelization.
import dm_pix as pix
import jax
import numpy as np
# Load an image into a NumPy array with your preferred library.
image = load_image()
# Vanilla Python function.
flip_left_right_image = pix.flip_left_right(image)
# `jax.jit`ed function.
flip_left_right_image = jax.jit(pix.flip_left_right)(image)
# Assuming a single device, like a CPU or a single GPU, we add a single
# leading (batch) dimension so that `image` can be used with the vectorized
# (`jax.vmap`) and the multi-device parallel (`jax.pmap`) versions of
# `pix.flip_left_right`.
# To know more, please refer to the JAX documentation of `jax.vmap` and `jax.pmap`.
image = image[np.newaxis, ...]
# `jax.vmap`ed function.
flip_left_right_image = jax.vmap(pix.flip_left_right)(image)
# `jax.pmap`ed function.
flip_left_right_image = jax.pmap(pix.flip_left_right)(image)
You can check for yourself that the four versions of pix.flip_left_right
produce the same result (up to the accelerator's floating-point accuracy)!
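One way to run that check is sketched below. To stay self-contained it uses a synthetic image and a hypothetical pure-JAX stand-in (flip_left_right_sketch, assuming a channels-last layout) in place of pix.flip_left_right and the JAX logo; swapping in the real function and image should work the same way. The vmap-ed and pmap-ed outputs keep the added leading dimension, so index [0] is taken before comparing.

```python
import jax
import jax.numpy as jnp
import numpy as np

def flip_left_right_sketch(image):
  # Hypothetical stand-in for pix.flip_left_right; assumes (..., H, W, C).
  return image[..., ::-1, :]

# Synthetic 4x5 RGB image standing in for the JAX logo.
image = jnp.arange(4 * 5 * 3, dtype=jnp.float32).reshape(4, 5, 3)

vanilla = flip_left_right_sketch(image)
jitted = jax.jit(flip_left_right_sketch)(image)

# vmap and pmap map over a leading axis; for pmap its size must not exceed
# the number of available devices, so a size-1 axis works on a single device.
batched = image[np.newaxis, ...]
vmapped = jax.vmap(flip_left_right_sketch)(batched)[0]
pmapped = jax.pmap(flip_left_right_sketch)(batched)[0]

for name, result in [("jit", jitted), ("vmap", vmapped), ("pmap", pmapped)]:
  # Compare against the vanilla result within a small floating-point tolerance.
  np.testing.assert_allclose(np.asarray(result), np.asarray(vanilla), rtol=1e-6)
  print(f"{name}: matches the vanilla result")
```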
We have a few examples in the examples/ folder. They are not much more
involved than the previous example, but they may be a good starting point
for you!
We provide a suite of tests to help you both test your development
environment and learn more about the library itself! All test files have a
_test suffix and can be executed using pytest.
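If you want to add checks of your own in the same spirit, a minimal pytest-style file could look like the sketch below. The file name flip_sketch_test.py and the helper flip_left_right_sketch are hypothetical (the helper assumes a channels-last layout); pytest's default discovery picks up files ending in _test.py. PIX's own test files may be structured differently.

```python
# flip_sketch_test.py (hypothetical name following the `_test` suffix convention).
import jax
import jax.numpy as jnp
import numpy as np

def flip_left_right_sketch(image):
  # Hypothetical stand-in for the function under test; assumes (..., H, W, C).
  return image[..., ::-1, :]

def test_flip_twice_is_identity():
  image = jnp.arange(2 * 3 * 3, dtype=jnp.float32).reshape(2, 3, 3)
  np.testing.assert_array_equal(
      flip_left_right_sketch(flip_left_right_sketch(image)), image)

def test_jit_matches_vanilla():
  image = jnp.arange(4 * 4 * 3, dtype=jnp.float32).reshape(4, 4, 3)
  np.testing.assert_allclose(
      jax.jit(flip_left_right_sketch)(image), flip_left_right_sketch(image))
```

Running python -m pytest on the folder containing this file would collect and run both test functions.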
If you already have PIX installed, you just need to install some extra
dependencies and run pytest as follows:
$ pip install -e ".[test]"
$ python -m pytest [-n <NUMCPUS>] dm_pix
If you want an isolated virtual environment, you just need to run our utility
bash script as follows:
$ ./test.sh
This repository is part of the DeepMind JAX Ecosystem; to cite PIX, please use the DeepMind JAX Ecosystem citation.
We are very happy to accept contributions!
Please read our contributing guidelines and send us PRs!