This guide summarizes some tips, tricks, and practices that are useful when working with JAX for a research project. In my opinion, one key aspect that JAX is missing compared to PyTorch is a framework like PyTorch Lightning that can massively reduce code overhead while remaining flexible enough to support almost any model and task. Although such libraries exist for certain common tasks, like trax or scenic (attention-based computer vision), I have not yet come across one that was sufficiently flexible for my research. Hence, in this guide, we build a simplified version of a PyTorch Lightning trainer that bundles the training, logging, and related behavior we need for almost any model, and allows training various models with far fewer lines of code than starting from scratch. Moreover, we implement some simple examples to showcase possible training structures, and underline the trainer's flexibility by performing automatic hyperparameter tuning with Optuna. Since this guide is about code structure, it is more code-heavy than the other guides and can also be run in Google Colab if preferred.
First, let's import some standard libraries. For this guide, we will use the data loading functionalities of PyTorch, but one could also use the TensorFlow dataset API. Additionally, we integrate loggers from PyTorch Lightning, since they provide a flexible API and support the most popular logging applications (e.g. TensorBoard, Weights and Biases).
Flax already provides some basic functionality for training models. One part of it is the TrainState, which holds the model parameters and the optimizer, and allows updating them. However, there might be more model aspects that we would like to add to the TrainState. For instance, if a model uses Batch Normalization, we need to keep the batch statistics in order to evaluate the model on a test dataset. Furthermore, many models contain stochastic elements such as dropout or sampling in generative models (e.g. Normalizing Flows). Thus, we extend the TrainState class from Flax to also include the batch statistics as `batch_stats` and a pseudo-random number generator `rng`. Note that if a model does not require these elements, they can simply be None without breaking our code.
Below, we define this extended TrainState class.
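A minimal sketch of the extension, assuming only the two extra fields described above (the full class used later in the guide may carry additional fields or utility methods):

```python
from typing import Any

from flax.training import train_state


class TrainState(train_state.TrainState):
    # Batch statistics from BatchNorm layers (None if the model has none).
    batch_stats: Any = None
    # PRNG key for stochastic elements such as dropout (None if unused).
    rng: Any = None
```

Since `train_state.TrainState` is a Flax dataclass, adding new fields with `None` defaults keeps the class fully compatible with models that need neither batch statistics nor randomness.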
Now we come to the main part of this guide: the Trainer module for JAX/Flax. The module shown here is not meant to be the 'one and only' way of doing it; rather, it showcases one possible way of obtaining a Lightning-like API in JAX. The module can easily be extended with more functionality, depending on what individual users need or prefer.
First, let's make a list of the functionalities we would want the Trainer module to include:
- Logging: For basically all use cases and models, we want to log our hyperparameters, training/validation performance, and model checkpoints. For the performance metrics, we can make use of PyTorch Lightning's logger classes like `TensorBoardLogger` and `WandbLogger`. For the model checkpoints, we use `flax.checkpoints`. In terms of flexibility, the trainer should support arbitrary sets of hyperparameters, since different models may require different hyperparameters. Similarly, it should be easy to add new metrics for logging, like accuracy for classification or intersection over union for segmentation. A minimal logging sketch follows after this list.
  - Implemented in: `init_logger`, `save_model`, `load_model`, `save_metrics`
- Model state initialization: In contrast to PyTorch, JAX separates the model itself from the learnable parameters. Creating a set of parameters for a model requires some boilerplate code, like creating a PRNG for the parameter generation and creating an initial `TrainState`. At the same time, we need to allow overriding the `model.init` code, since different architectures will have different input arguments for the forward pass (e.g. models with dropout require a dropout PRNG). This initialization is sketched after this list.
  - Implemented in: `init_model`, `run_model_init`, `print_tabulate`
- Optimizer initialization: Following the parameter initialization, we also need to create an optimizer and its state (e.g. momentum and adaptive learning rate statistics in Adam). Since most models use a similar set of optimizers (SGD or Adam) and extra functionalities like gradient clipping and learning rate scheduling, we can write a template method that creates an optimizer based on some hyperparameters. However, it should be possible to override this method if very specific optimizer settings or learning rate schedulers are needed. Since some schedulers require information about the overall number of training iterations, we create the optimizer right before starting the training (see the optax sketch after this list).
  - Implemented in: `init_optimizer`
- Training loop: Most models follow a similar training procedure where we train a model for several epochs on the training dataset and evaluate it in between on the validation dataset. If a model is better than all previous models, we want to save its weights so we can load them later. Importantly, however, each model will have a very different training and validation step. Thus, similarly to PyTorch Lightning, we expect an inheriting Trainer module to define a training step function and an evaluation step function, which can be jitted and used in the training loop. This is implemented in the functions `train_model`, `train_epoch`, `eval_model`, `create_functions`, and `create_jitted_functions`. Additional aspects to consider include:
  - Whether a model is better than the previous ones depends on the task at hand. For example, classification models are usually compared by their accuracy, trying to achieve the maximum value, while regression models aim for the lowest loss. Hence, we need a flexible API to support different ways of comparing models and finding the best one (sketched after this list). Implemented in: `is_new_model_better`
  - Within the training loop, we might want to perform additional operations, like logging reconstruction examples of an autoencoder every few epochs. To do so, PyTorch Lightning provides functions that are called at different stages during training, which we can similarly integrate in our Trainer module. Implemented in: `on_training_start`, `on_training_epoch_end`, `on_validation_epoch_end`
  - Depending on whether we run the model on a cluster with no display or on our local machine, we might want progress bars that track the training progress. Hence, the Trainer module should have a switch to enable or disable these progress bars. Implemented in: `tracker`
- Inference: After we have finished training, we might want to load a model at a later time and perform inference experiments with it. To support this, two functionalities are needed: (1) loading a model from disk, including its hyperparameters (i.e. the function `load_from_checkpoint` in PyTorch Lightning), and (2) binding parameters to a model to reduce code overhead (sketched after this list). Both parts can be implemented in our Trainer module.
  - Implemented in: `load_from_checkpoint`, `bind_model`
With these requirements in mind, we can implement the Trainer module systematically. Note that the code is considerably long, since we want to support many different settings. It is recommended to take some time to go through the code, understand how all the elements are implemented, and see how one can extend it depending on their own needs.
- In the ResNet experiments, the weight_decay argument is always provided (defaulting to 0) and is always used as an additive term in the loss function. With Adam, this should always be set to zero. To use L2 regularization with Adam, it is preferable to use AdamW, but in that case the weight decay would be part of both the loss and the optax.adamw hyperparameters. This is guarded against with an exception, but ideally the optimizer should be configured and passed by the user (correctness is the user's responsibility). TODO: design a common optimizer interface that the user can configure? How would this fit with how VeLO works?
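To make the distinction concrete, the sketch below shows the two places weight decay can enter; only one of them should be active at a time. The helper name and values are illustrative, not part of the Trainer:

```python
import jax
import jax.numpy as jnp
import optax

weight_decay = 1e-4

# Option 1: L2 regularization as an additive loss term (use with SGD or plain Adam).
def l2_penalty(params):
    return 0.5 * sum(jnp.sum(p ** 2) for p in jax.tree_util.tree_leaves(params))

# loss = task_loss + weight_decay * l2_penalty(params)

# Option 2: decoupled weight decay inside the optimizer (AdamW). In this case,
# the loss must NOT contain the L2 term as well.
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=weight_decay)
```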