Experiment code to reproduce the results of "Can We Remove the Square-Root in Adaptive Gradient Methods? A Second-Order Perspective" (ICML 2024).
We provide a prototype implementation of the root-free RMSProp and inverse-free Shampoo.
For a clean implementation of the inverse-free Shampoo, please check out this repository.
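For orientation, here is a minimal sketch of a root-free RMSProp-style step, in which the running second-moment estimate is applied directly as the preconditioner (no square root). The function name, the default coefficient values, and the use of `lr_cov` as the EMA rate are illustrative placeholders, not our tuned implementation:

```python
import torch

@torch.no_grad()
def rootfree_rmsprop_step(param, second_moment, lr=1e-3, lr_cov=1e-2, damping=1e-5):
    """Illustrative root-free RMSProp-style update (a sketch, not the paper's exact code).

    The running second-moment estimate is used directly as the preconditioner,
    i.e. the update divides by it without taking a square root.
    """
    grad = param.grad
    # Exponential moving average of squared gradients (beta2 = 1 - lr_cov).
    second_moment.mul_(1 - lr_cov).add_(grad * grad, alpha=lr_cov)
    # Preconditioned step without the square root, with a small damping term.
    param.add_(grad / (second_moment + damping), alpha=-lr)
```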
We use PyTorch’s built-in SGD, AdamW, and RMSProp. For Shampoo, we rely on the state-of-the-art PyTorch implementation from Meta (Shi et al., 2023). We tune hyperparameters (HPs) for each optimizer (see the HP search space below).
For the matrix adaptive methods (Shampoo and inverse-free Shampoo), we update the matrix preconditioners every two iterations. Updating the preconditioners less frequently further reduces their wall-clock time.
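The update interval can be handled in the training loop along these lines (a sketch; `update_preconditioner` is a hypothetical hook name, since the Meta Shampoo implementation and our prototypes expose this setting differently):

```python
precond_update_interval = 2  # refresh matrix preconditioners every two iterations

for step, (inputs, targets) in enumerate(train_loader):
    loss = criterion(model(inputs), targets)
    loss.backward()

    # The expensive matrix-preconditioner refresh runs only every other step;
    # the (cheap) preconditioned parameter update runs every step.
    if step % precond_update_interval == 0:
        optimizer.update_preconditioner()  # hypothetical method, for illustration only
    optimizer.step()
    optimizer.zero_grad()
```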
We employ a two-stage HP tuning protocol for all tasks and optimizers based on random search (Choi et al., 2019).
Unlike Choi et al. (2019), we only consider a small damping term (e.g., 0 < λ < 5e-4) for all methods in our HP search space, since a large damping term (e.g., λ > 1) can turn Adam into SGD.
In the first stage, we use broad search ranges for all HPs. Based on this stage, we select a narrower range for each HP, re-run the search, and report the best run for each method. We use 100 runs in each stage.
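A sketch of this two-stage random search (the concrete ranges and the `train_and_evaluate` routine are placeholders; the actual per-task search spaces are listed below):

```python
import math
import random

def sample_config(lr_range, damping_range):
    """Draw one HP configuration, sampling log-uniformly within each range."""
    log_uniform = lambda lo, hi: math.exp(random.uniform(math.log(lo), math.log(hi)))
    return {"lr": log_uniform(*lr_range), "damping": log_uniform(*damping_range)}

# Stage 1: 100 runs with broad ranges (placeholder bounds; damping kept below 5e-4).
stage1_configs = [sample_config((1e-5, 1e-1), (1e-8, 5e-4)) for _ in range(100)]
# stage1_results = [train_and_evaluate(cfg) for cfg in stage1_configs]  # placeholder routine

# Stage 2: 100 more runs on narrowed ranges chosen around the best stage-1 runs,
# then report the best run for each optimizer.
stage2_configs = [sample_config((1e-4, 1e-2), (1e-6, 1e-4)) for _ in range(100)]
```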
HP search space used in our paper: CNNs, SwinViT, FocalNet, GCViT, VMamba, LSTM, GNN
Note: `beta2` in AdamW is equivalent to `1 - lr_cov` in our notation, and `eps` in AdamW is equivalent to `damping` in our notation.
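For example, a tuned AdamW configuration can be translated into this notation as follows (a sketch with illustrative values):

```python
adamw_hps = {"lr": 1e-3, "beta2": 0.999, "eps": 1e-8}

# Map AdamW's hyperparameter names onto the notation used in our HP tables.
our_notation = {
    "lr": adamw_hps["lr"],
    "lr_cov": 1.0 - adamw_hps["beta2"],  # beta2 = 1 - lr_cov
    "damping": adamw_hps["eps"],         # eps corresponds to damping
}
print(our_notation)  # {'lr': 0.001, 'lr_cov': ~0.001, 'damping': 1e-08}
```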
For all optimizers, only the forward pass is executed in mixed precision with BFP-16 (as recommended by the official PyTorch guide). The gradients are automatically cast back to FP-32 by PyTorch. Shampoo uses these FP-32 gradients for its preconditioner and becomes unstable when they are converted to BFP-16 (Shi et al., 2023). In contrast, our IF-Shampoo converts the gradients to BFP-16, updates the preconditioner, and even takes preconditioned gradient steps (including momentum) in half-precision. Our method works well in half-precision without using matrix decompositions or matrix solves/inversions.
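A sketch of this mixed-precision setup in PyTorch (assuming a CUDA device and standard `model`/`criterion`/`optimizer` objects; only the forward pass is placed under autocast):

```python
import torch

for inputs, targets in train_loader:
    # Only the forward pass runs in BF16 via autocast, as recommended by the
    # official PyTorch mixed-precision guide.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = criterion(model(inputs), targets)

    # Parameters stay in FP-32, so autograd produces FP-32 gradients here.
    loss.backward()

    # IF-Shampoo then casts the gradients to BF16 and performs its preconditioner
    # and parameter updates in half-precision (see the description above).
    optimizer.step()
    optimizer.zero_grad()
```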
Note:
- These matrix operations (e.g., eigendecomposition, Cholesky, SVD, inversion) in half-precision are not supported in PyTorch and JAX because they are numerically unstable (see the discussions on inversion, SVD, and Cholesky).
- In practice, one eigendecomposition in Float32 is 16 times slower than one matrix multiplication in BFloat16.
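The rough 16x figure can be checked with a timing sketch along these lines (the matrix size is arbitrary and the measured ratio depends on the GPU):

```python
import torch

n = 4096
a32 = torch.randn(n, n, device="cuda", dtype=torch.float32)
sym32 = a32 @ a32.T            # symmetric matrix for the eigendecomposition
a16 = a32.to(torch.bfloat16)

def gpu_time_ms(fn, repeats=10):
    """Average GPU wall-clock time of fn() in milliseconds."""
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeats):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / repeats

t_eigh = gpu_time_ms(lambda: torch.linalg.eigh(sym32))  # eigendecomposition in FP32
t_matmul = gpu_time_ms(lambda: a16 @ a16)                # matrix multiplication in BF16
print(f"eigh FP32: {t_eigh:.2f} ms | matmul BF16: {t_matmul:.2f} ms")
```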
- TODO: add all NN models and training scripts considered in our paper