haydn-jones/DiffJPEG-JAX

DiffJPEG: A Jax Implementation

This is a JAX implementation of the differentiable JPEG compression algorithm. It is based on the PyTorch implementation and incorporates some of the modifications found in this repository, which improve quality at high compression rates.

Requirements

  • JAX

Installation

Install with pip:

pip install diffjpeg_jax

Usage

Unlike the PyTorch version, this implementation is not tied to any particular ML library: it is a plain function. Inputs should have values in the range [0, 255] and shape (H, W, C).

from diffjpeg_jax import diff_jpeg

img = ... # (H, W, C)
jpeg = diff_jpeg(img, quality=75)
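If your images come from a pipeline that uses a different layout, they need converting first. The sketch below assumes a channel-first (C, H, W) float image in [0, 1], which is common in PyTorch-style code; the helper names to_hwc_255 and to_chw_unit are hypothetical and not part of diffjpeg_jax.

```python
import jax.numpy as jnp


def to_hwc_255(img_chw_unit):
    """Convert a (C, H, W) image in [0, 1] to (H, W, C) in [0, 255]."""
    img = jnp.transpose(img_chw_unit, (1, 2, 0))  # move channels last
    return jnp.clip(img * 255.0, 0.0, 255.0)


def to_chw_unit(img_hwc_255):
    """Inverse of to_hwc_255: (H, W, C) in [0, 255] back to (C, H, W) in [0, 1]."""
    return jnp.transpose(img_hwc_255, (2, 0, 1)) / 255.0


# Example: a mid-gray channel-first image.
chw = jnp.full((3, 8, 16), 0.5)
hwc = to_hwc_255(chw)  # shape (8, 16, 3), ready to pass as img
```

Both helpers use only jnp operations, so they remain differentiable and can sit on either side of the diff_jpeg call.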

Note: diff_jpeg is not JIT-compiled for you; wrap it with jax.jit yourself if you want compilation. For batch processing, map it over the batch axis with jax.vmap.
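A minimal sketch of the jit and vmap pattern follows. So that it runs without the package installed, it uses a stand-in function, stand_in_codec, which is hypothetical and only mimics the (H, W, C) in, (H, W, C) out signature; it is not the real codec. Binding quality with functools.partial, so that it stays a Python constant instead of a traced value, is an assumption about what works best here, not something the package documents.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for diff_jpeg, used only so this sketch is runnable.
# It coarsely quantizes pixel values; it is NOT JPEG compression.
def stand_in_codec(img, quality=75):
    step = (101 - quality) / 10.0  # illustrative step size, not from the library
    return jnp.clip(jnp.round(img / step) * step, 0.0, 255.0)


# Bind quality up front so jit never has to trace it.
codec = partial(stand_in_codec, quality=75)

# vmap maps over the leading batch axis; jit compiles the batched function.
batched_codec = jax.jit(jax.vmap(codec))

key = jax.random.PRNGKey(0)
batch = jax.random.uniform(key, (4, 32, 32, 3), minval=0.0, maxval=255.0)
out = batched_codec(batch)  # shape (4, 32, 32, 3)
```

With the package installed, replacing stand_in_codec with diff_jpeg in the partial call should give the same structure, assuming diff_jpeg traces cleanly under jit and vmap as the note above suggests.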