PyTorch implementation of Model-Agnostic Meta-Learning[1] with gradient checkpointing[2]. It allows you to perform way (~10-100x) more MAML steps with the same GPU memory budget.
For normal installation, run

```bash
pip install torch_maml
```

For development installation, clone the repo and run

```bash
python setup.py develop
```
See examples in example.ipynb
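If you are curious what makes this memory-efficient, here is a minimal, self-contained sketch of the general trick in plain PyTorch (this is not the torch_maml API; the toy model and all helper names are made up for illustration): the differentiable inner loop is split into chunks, each chunk is wrapped in `torch.utils.checkpoint.checkpoint`, and the activations inside a chunk are recomputed during the backward pass instead of being stored.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

# Toy "functional" model: keeping parameters as explicit tensors lets the SGD
# updates stay inside the autograd graph, which is what MAML needs.
def predict(x, w, b):
    return x @ w + b

def make_chunk(x, y, lr, steps_per_chunk):
    def chunk(w, b):
        # checkpoint() calls this function under no_grad on the forward pass and
        # re-runs it under enable_grad on the backward pass, so the (expensive)
        # higher-order graph only needs to be built during that recomputation.
        build_graph = torch.is_grad_enabled()
        with torch.enable_grad():
            for _ in range(steps_per_chunk):
                loss = F.mse_loss(predict(x, w, b), y)
                gw, gb = torch.autograd.grad(loss, (w, b), create_graph=build_graph)
                # (per-step gradient clipping, if any, would go here)
                w = w - lr * gw
                b = b - lr * gb
        return w, b
    return chunk

# Meta-parameters we want to meta-learn.
w0 = torch.randn(3, 1, requires_grad=True)
b0 = torch.zeros(1, requires_grad=True)

# A toy regression task.
x = torch.randn(32, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]]) + 0.1 * torch.randn(32, 1)

# 100 inner SGD steps, checkpointed in chunks of 10: only the 10 chunk boundaries
# stay in memory; activations inside a chunk are recomputed on backward.
w, b = w0, b0
for _ in range(10):
    w, b = checkpoint(make_chunk(x, y, lr=0.01, steps_per_chunk=10),
                      w, b, use_reentrant=True)  # PyTorch >= 1.11; drop the kwarg on older versions

meta_loss = F.mse_loss(predict(x, w, b), y)
meta_loss.backward()  # backpropagates through all 100 inner steps
print(w0.grad.norm().item(), b0.grad.norm().item())
```

Without checkpointing, the activations of all 100 inner steps would have to stay in memory at once; with chunked checkpointing, peak memory scales roughly with one chunk plus the chunk boundaries, which is where the ~10-100x gain in feasible MAML steps comes from.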
A few things to keep in mind:

- Make sure that your model doesn't perform implicit in-place state updates, like torch.nn.BatchNorm2d with track_running_stats=True updating its running statistics (a small helper for switching these updates off by hand is sketched after this list). With gradient checkpointing, these updates will be performed twice (once per forward pass). If you still want these updates, take a look at `torch_maml.utils.disable_batchnorm_stats`. Note that we already support this for vanilla BatchNorm{1-3}d.
- When computing gradients through many MAML steps (e.g. 100 or 1000), you should watch out for vanishing and exploding gradients inside the optimizer, just as in RNNs. This implementation supports gradient clipping to avoid the exploding part of the problem (a differentiable clipping helper is sketched after this list).
- Also, when you deal with a large number of MAML steps, be aware of the computational error that accumulates due to limited float precision and, in particular, cuDNN operations. We recommend setting `torch.backends.cudnn.deterministic = True`. The problem appears when gradients become slightly noisy due to these errors; during backpropagation through the MAML steps, the error is likely to grow dramatically.
- You could also consider Implicit Gradient MAML [3] as a memory-efficient alternative for meta-learning. While that algorithm requires even less memory, it assumes that your inner optimization converges to an optimum, so it is inapplicable if your task does not always converge by the time you start backpropagating. In contrast, our implementation lets you meta-learn even from a partially converged state.
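For the BatchNorm tip above: if you would rather switch the running-statistics updates off by hand than rely on the bundled `torch_maml.utils.disable_batchnorm_stats`, something like the following works (`freeze_batchnorm_stats` is a made-up name, not part of the library):

```python
import torch.nn as nn

def freeze_batchnorm_stats(model: nn.Module) -> nn.Module:
    # With track_running_stats=False, a training-mode forward pass normalizes with
    # batch statistics and skips the in-place running_mean/running_var updates, so
    # the extra forward pass performed by gradient checkpointing has no side effects.
    # (Keep the model in train() mode during the inner loop so batch statistics are used.)
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.track_running_stats = False
    return model
```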
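And for the gradient clipping tip: a differentiable inner loop works with the gradients returned by `torch.autograd.grad` rather than with `.grad` attributes, so the usual in-place `torch.nn.utils.clip_grad_norm_` does not apply there; the clipping has to be done out of place so that meta-gradients can still flow through it. A minimal sketch (`clip_by_global_norm` is a hypothetical helper, not the library's API):

```python
import torch

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    # Rescale a tuple of gradients so their combined L2 norm is at most max_norm,
    # without in-place ops, keeping the computation graph intact for meta-gradients.
    total_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    scale = torch.clamp(max_norm / (total_norm + eps), max=1.0)
    return tuple(g * scale for g in grads)
```

In the checkpointed sketch above, you would call `gw, gb = clip_by_global_norm((gw, gb), max_norm=10.0)` right after `torch.autograd.grad`.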
[1] Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks