A JAX Backbone for RL projects

Note: this folder is an exact copy of dibyaghosh/jaxrl_m.

This project serves as a "central backbone" for an RL codebase, designed to accelerate prototyping and diagnosis of new algorithms (as a secondary feature, it also contains reference implementations of SAC, CQL, IQL, and BC). It borrows heavily from Ilya Kostrikov's JaxRL codebase.

The primary goal of the codebase is to make it easy to code up a new algorithm. To that end, the primary philosophy is that

algorithms should be single-file implementations

This means that almost every component of the algorithm (update rule, network choices, hyperparameter choices) lives in one file (e.g. see the BC example). This makes the algorithm easy to read and understand, and easy to modify when testing new ideas. The code is also designed to scale as easily as possible to multi-GPU / TPU setups, with simple abstractions for distributed training.
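To make the single-file idea concrete, here is a minimal sketch of what such a file could look like for behavioral cloning, written in plain JAX. It is not the repository's BC implementation: the names CONFIG, init_params, policy, bc_loss, and update are hypothetical, the optimizer is plain gradient descent rather than optax, and the hyperparameter values are illustrative only.

```python
import jax
import jax.numpy as jnp

# Hyperparameter choices live next to the algorithm (illustrative values).
CONFIG = dict(learning_rate=0.05, hidden_dim=32)


def init_params(rng, obs_dim, action_dim, hidden_dim):
    """Network choice: a two-layer MLP stored as a plain dict of arrays."""
    k1, k2 = jax.random.split(rng)
    return {
        "w1": jax.random.normal(k1, (obs_dim, hidden_dim)) / jnp.sqrt(obs_dim),
        "b1": jnp.zeros(hidden_dim),
        "w2": jax.random.normal(k2, (hidden_dim, action_dim)) / jnp.sqrt(hidden_dim),
        "b2": jnp.zeros(action_dim),
    }


def policy(params, observations):
    """Deterministic policy: maps a batch of observations to actions."""
    hidden = jnp.tanh(observations @ params["w1"] + params["b1"])
    return hidden @ params["w2"] + params["b2"]


def bc_loss(params, batch):
    """Mean squared error between predicted and dataset actions."""
    predicted = policy(params, batch["observations"])
    return jnp.mean((predicted - batch["actions"]) ** 2)


@jax.jit
def update(params, batch):
    """Update rule: one step of gradient descent on the BC loss."""
    loss, grads = jax.value_and_grad(bc_loss)(params, batch)
    lr = CONFIG["learning_rate"]
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, {"loss": loss}
```

Because the loss, network, and hyperparameters sit side by side, trying a new idea (say, a different loss) means editing one function in one file.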

Installation

Requires jax, flax, optax, and distrax, and optionally wandb for logging. Clone this repository and install it (e.g. pip install -e .) or add it to your Python path.

Usage

The fastest way to understand how to use this skeleton is to see the reference SAC implementation:

Agent: sac.py

Launcher: mujoco_sac.py

Structure

The code is organized into the following modules:

  • jaxrl_m.common: Contains the TrainState abstraction (a fork of Flax's TrainState class with some additional syntactic features for ease of use), and some other useful general utilities (target_update, shard_batch)
  • jaxrl_m.dataset: Contains the Dataset class (which can store and sample from buffers containing arbitrarily nested dictionaries) and an equivalent ReplayBuffer class
  • jaxrl_m.networks: Contains implementations of common RL networks (MLP, Critic, ValueCritic, Policy)
  • jaxrl_m.evaluation: Contains code for running evaluation episodes of agents (e.g. with the evaluate(policy, env) function)
  • jaxrl_m.wandb: Contains code for easily setting up Weights & Biases for experiments
  • jaxrl_m.typing: Useful type aliases
  • jaxrl_m.vision: vision.models contains common vision models (e.g. ResNet, ResNetV2, Impala), vision.data_augmentations contains common augmentations (e.g. random crop, random color jitter, gaussian blur)
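Several of the utilities listed above operate on nested dictionaries (pytrees). The sketch below illustrates three such operations: sampling aligned rows from a nested buffer, a soft target-network update, and splitting a batch across devices. The function names and signatures are hypothetical, not the jaxrl_m API, and the Polyak-averaging form of the target update is an assumption about what a utility like target_update typically does.

```python
import jax
import numpy as np


def sample_batch(data, batch_size, rng):
    """Sample the same random rows from every leaf of a nested dictionary."""
    size = len(jax.tree_util.tree_leaves(data)[0])
    indices = rng.integers(0, size, size=batch_size)
    return jax.tree_util.tree_map(lambda arr: arr[indices], data)


def soft_target_update(online_params, target_params, tau):
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    return jax.tree_util.tree_map(
        lambda online, target: tau * online + (1.0 - tau) * target,
        online_params,
        target_params,
    )


def shard_leading_axis(batch, num_devices):
    """Reshape each leaf from (B, ...) to (num_devices, B // num_devices, ...)."""
    return jax.tree_util.tree_map(
        lambda arr: arr.reshape(
            (num_devices, arr.shape[0] // num_devices) + arr.shape[1:]
        ),
        batch,
    )
```

All three helpers rely on jax.tree_util.tree_map, so they work unchanged whether the buffer is a flat dictionary or an arbitrarily nested one. The sharding helper assumes the batch size is divisible by the number of devices.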

Examples

Example implementations:

  1. Continuous BC
  2. Discrete BC
  3. SAC
  4. IQL

Example Launchers:

  1. Mujoco SAC
  2. D4RL IQL