feat: add backpack backend
ClashLuke committed Nov 26, 2022
1 parent 90ed2d4 commit 0a14e38
Showing 5 changed files with 104 additions and 61 deletions.
97 changes: 67 additions & 30 deletions README.md
@@ -12,13 +12,62 @@ python3 -m pip install truegrad

## Examples

### BackPack

The preferred method to integrate TrueGrad is via [BackPack](https://github.com/f-dangel/backpack). BackPack is a
third-party library that automatically computes the sum of squared gradients and works for most models by implementing
custom backward rules for many `torch.nn.Module`s.

```PYTHON
import backpack
import torch
from torch.nn import CrossEntropyLoss
from truegrad.optim import TGAdamW
from torchvision.models import alexnet

model = alexnet()
optim = TGAdamW(model.parameters(), lr=1e-7, weight_decay=0)

# backpack can't handle inplace ops like nn.ReLU(inplace=True) and `x += y`
for mod in model.modules():
    if hasattr(mod, "inplace"):
        mod.inplace = False

# backpack relies on module-level pytorch hooks
model = backpack.extend(model)
lossfunc = backpack.extend(CrossEntropyLoss())

# constant input/output to overfit
inp = torch.randn((2, 3, 224, 224))
tgt = torch.randint(0, 1000, (2,))

# standard training loop
i = 0
while True:
    # "SumGradSquared" computes the sum of the squared per-sample gradients
    with backpack.backpack(backpack.extensions.SumGradSquared()):
        loss = lossfunc(model(inp), tgt)
        loss.backward()
    optim.step()
    optim.zero_grad()  # don't let gradients accumulate across iterations
    i += 1
    if i % 5 == 0:
        print(i, loss.item())
```

If you're using custom modules with self-defined parameters, this method will not work. Additionally, note that, if
your model has any layer called `.output` or you're using PyTorch >= 1.13, you will need to install
[BackPack-HF](https://github.com/ClashLuke/backpack-hf) via
`python3 -m pip install git+https://github.com/ClashLuke/backpack-hf`.
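
To check whether BackPack covers every parameter of your model, run a single backward pass and see which parameters
received the `sum_grad_squared` attribute that TGAdamW reads. A minimal sketch, reusing `model`, `lossfunc`, `inp` and
`tgt` from the example above:

```PYTHON
# one forward/backward pass with the SumGradSquared extension active
with backpack.backpack(backpack.extensions.SumGradSquared()):
    loss = lossfunc(model(inp), tgt)
    loss.backward()

# parameters of unsupported (e.g. custom) modules will be missing the attribute
uncovered = [name for name, param in model.named_parameters()
             if not hasattr(param, "sum_grad_squared")]
print("parameters without sum_grad_squared:", uncovered)
```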

### Patch Custom Models

Another option to integrate TrueGrad into existing models is to patch them using `truegrad.utils.patch_model()`.
`patch_model()` will go through all `torch.nn.Module`s in a PyTorch model and convert their `torch.nn.Parameter`s to
`truegrad.nn.TrueGradParameter`s. A `TrueGradParameter` acts largely the same as a `torch.nn.Parameter`, but adds the
required operations into the model's backward pass.\
Importantly, be aware that this does not work for fused functions, such as `torch.nn.LayerNorm`
and `torch.nn.MultiheadAttention`. However, unfused functions which directly access a parameter, such as
multiplication, work well. Therefore, `torch.nn.Linear` and HuggingFace's attention work as expected.

```PYTHON
import transformers
@@ -40,11 +89,13 @@ for sample in ["Hello", "World", "!"]:
```
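
The same pattern works for small self-written modules, as long as their parameters are used through unfused
operations. A rough sketch (the `Scale` module below is a made-up example; only `patch_model` and `TGAdamW` come from
TrueGrad):

```PYTHON
import torch
from truegrad.optim import TGAdamW
from truegrad.utils import patch_model


class Scale(torch.nn.Module):
    """Hypothetical custom module with a self-defined parameter."""

    def __init__(self, features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(features))

    def forward(self, x):
        return x * self.weight  # plain multiplication, so patching can track it


model = torch.nn.Sequential(torch.nn.Linear(4, 4), Scale(4))
patch_model(model)  # converts torch.nn.Parameter to truegrad.nn.TrueGradParameter
optim = TGAdamW(model.parameters())

for _ in range(10):
    model(torch.randn((16, 4))).square().mean().backward()
    optim.step()
    optim.zero_grad()
```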

### nn

Patching existing PyTorch computation graphs on the fly might add unnecessary memory and computation or even fail
unexpectedly. That's why a pre-patched alternative of `torch.nn` with hand-crafted gradients exists alongside the
`truegrad.utils` module. Compared to `truegrad.utils.patch_model()`, `truegrad.nn` offers higher speeds and lower
memory usage, although it might require code alterations and doesn't support all models. You cannot (currently) use
`truegrad.nn` with `truegrad.utils`, as both use different ways to arrive at the same value. However, you can
combine `torch.nn.Module`s and `truegrad.nn.Module`s and use the TrueGrad information only where it is available
(see [Partial TrueGrad](#Partial-TrueGrad)).

```PYTHON
import torch
@@ -58,7 +109,7 @@ model = torch.nn.Sequential(nn.Linear(1, 10),
                            nn.Linear(10, 1))
optim = TGAdamW(model.parameters())  # truegrad.optim.TGAdamW instead of torch.optim.AdamW

# standard training loop
while True:
    input = torch.randn((16, 1))
    model(input).mean().backward()
```

@@ -67,10 +118,11 @@ while True:
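
A complete, runnable version of the example above could look like the following sketch; the `from truegrad import nn`
import, the middle `torch.nn.ReLU()` layer and the `optim.step()` / `optim.zero_grad()` calls are assumptions filled
in for illustration:

```PYTHON
import torch
from truegrad import nn  # pre-patched drop-in modules with hand-crafted gradients
from truegrad.optim import TGAdamW

model = torch.nn.Sequential(nn.Linear(1, 10),
                            torch.nn.ReLU(),
                            nn.Linear(10, 1))
optim = TGAdamW(model.parameters())

# standard training loop
while True:
    input = torch.randn((16, 1))
    model(input).mean().backward()
    optim.step()
    optim.zero_grad()
```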

### Partial TrueGrad

Unfortunately, it's not always sensible to apply TrueGrad, as some backward passes are too slow, and sometimes it's
impossible to avoid a fused function.
Therefore, it can be an option to use TGAdamW only on specific subsections of the model. To do so, pass
`default_to_adam=True` to TGAdamW. This option allows TGAdamW to fall back to AdamW whenever a parameter has no
`sum_grad_squared` attribute available.
For example, the code from [#nn](#nn) could be extended in the following way:

```PYTHON
@@ -83,26 +135,11 @@ model = torch.nn.Sequential(nn.Linear(1, 10), # Weights coming from truegrad.nn
                            torch.nn.ReLU(),
                            torch.nn.Linear(10, 1))  # Weights coming from torch.nn

optim = TGAdamW(model.parameters(), default_to_adam=True)

# standard training loop
while True:
    input = torch.randn((16, 1))
    model(input).mean().backward()
    optim.step()
    optim.zero_grad()
```
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@
name='truegrad',
license='BSD',
description='PyTorch interface for TrueGrad-AdamW',
version='0.0.9',
version='0.1.0',
long_description=README,
url='https://github.com/clashluke/truegrad',
packages=setuptools.find_packages(),
24 changes: 12 additions & 12 deletions truegrad/functional.py
@@ -18,11 +18,11 @@ def backward(ctx, dy: torch.Tensor):
diff = inp.ndim - weight.ndim
summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1]
weight_grad = dy * inp
weight.square_grad = weight_grad.square()
weight.sum_grad_squared = weight_grad.square()
if summed:
weight_grad = weight_grad.sum(summed)
weight.square_grad = weight.square_grad.sum(summed)
weight.square_grad = weight.square_grad.reshape(weight.size()) * dy.size(0)
weight.sum_grad_squared = weight.sum_grad_squared.sum(summed)
weight.sum_grad_squared = weight.sum_grad_squared.reshape(weight.size()) * dy.size(0)
return dy * weight, weight_grad.reshape(weight.size())


@@ -42,11 +42,11 @@ def backward(ctx, dy: torch.Tensor):
return None, None
weight, = ctx.saved_tensors
weight_grad = dy
weight.square_grad = dy.square()
weight.sum_grad_squared = dy.square()
if ctx.summed:
weight_grad = weight_grad.sum(ctx.summed)
weight.square_grad = weight.square_grad.sum(ctx.summed)
weight.square_grad = weight.square_grad.reshape(weight.size()) * dy.size(0)
weight.sum_grad_squared = weight.sum_grad_squared.sum(ctx.summed)
weight.sum_grad_squared = weight.sum_grad_squared.reshape(weight.size()) * dy.size(0)
return dy, weight_grad.reshape(weight.size())


@@ -67,7 +67,7 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor, torch.Tensor]:
lhs, rhs = inputs.split(',')

d_wgt = torch.einsum(f'{lhs},{output}->{rhs}', inp, dy)
wgt.square_grad = torch.einsum(f'{lhs},{output}->{rhs}', inp.square(), dy.square() * inp.size(0))
wgt.sum_grad_squared = torch.einsum(f'{lhs},{output}->{rhs}', inp.square(), dy.square() * inp.size(0))
d_inp = torch.einsum(f"{rhs},{output}->{lhs}", wgt, dy)
return None, d_inp, d_wgt

@@ -85,7 +85,7 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
return None, None
inp, wgt = ctx.saved_tensors
wgt_grad = torch.zeros_like(wgt)
wgt.square_grad = wgt_grad.scatter_add(0, inp, dy.square())
wgt.sum_grad_squared = wgt_grad.scatter_add(0, inp, dy.square())
wgt_grad.scatter_add_(0, inp, dy)
return None, wgt_grad

@@ -103,8 +103,8 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
if not ctx.saved_tensors:
return None
wgt, = ctx.saved_tensors
if hasattr(wgt, "square_grad"):
wgt.square_grad = wgt.square_grad.reshape(ctx.original_shape)
if hasattr(wgt, "sum_grad_squared"):
wgt.sum_grad_squared = wgt.sum_grad_squared.reshape(ctx.original_shape)
return dy.reshape(ctx.original_shape)


@@ -121,8 +121,8 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
if not ctx.saved_tensors:
return None
wgt, = ctx.saved_tensors
if hasattr(wgt, "square_grad") and ctx.summed:
wgt.square_grad = wgt.square_grad.sum(ctx.summed)
if hasattr(wgt, "sum_grad_squared") and ctx.summed:
wgt.sum_grad_squared = wgt.sum_grad_squared.sum(ctx.summed)
return dy.sum(ctx.summed)


6 changes: 3 additions & 3 deletions truegrad/nn.py
@@ -159,10 +159,10 @@ def _square(x: Union[torch.Tensor, None]):
for p, a in zip(list(ctx.args) + list(ctx.kwargs.values()), list(args) + list(kwargs.values())):
if not isinstance(p, torch.nn.Parameter):
continue
if hasattr(p, "square_grad") and p.square_grad is not None:
p.square_grad = p.square_grad + a.grad
if hasattr(p, "sum_grad_squared") and p.sum_grad_squared is not None:
p.sum_grad_squared = p.sum_grad_squared + a.grad
else:
p.square_grad = a.grad
p.sum_grad_squared = a.grad
return None, None, None, None


36 changes: 21 additions & 15 deletions truegrad/optim.py
@@ -9,9 +9,10 @@ def __init__(self, params, lr: float = 1e-3,
eps: float = 1e-12,
weight_decay: float = 1e-2,
graft: bool = True,
decay_to_init: bool = False):
decay_to_init: bool = False,
default_to_adam: bool = False):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
decay_to_init=decay_to_init)
decay_to_init=decay_to_init, default_to_adam=default_to_adam)
super(TGAdamW, self).__init__(params, defaults)

@torch.no_grad()
@@ -31,18 +32,19 @@ def step(self, closure=None):
for p in group['params']:
if p.grad is None:
continue
if not hasattr(p, "square_grad") or p.square_grad is None:
raise ValueError(f"Parameter of shape {list(p.size())} doesn't have `square_grad` attribute. "
f"Make sure to use truegrad.utils.patch_model() or truegrad.nn for all optimized "
f"parameters.")
do_adam = not hasattr(p, "sum_grad_squared") or p.sum_grad_squared is None
if not group["default_to_adam"] and do_adam:
raise ValueError(f"Parameter of shape {list(p.size())} doesn't have `sum_grad_squared` attribute. "
f"Make sure to use backpack.")

state = self.state[p]

if len(state) == 0:
state['step'] = torch.tensor(0.)
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['exp_avg_true_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
if group["graft"]:
if not do_adam:
state['exp_avg_true_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
if do_adam or group["graft"]:
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
if group["decay_to_init"]:
state["init"] = torch.clone(p.detach())
@@ -61,22 +63,26 @@ def step(self, closure=None):
else:
p.mul_(1 - decay)

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
exp_avg_true_sq.mul_(beta3).add_(p.square_grad, alpha=1 - beta3)
p.square_grad = None

step = step_t.item()

denom = (exp_avg_true_sq / (1 - beta3 ** step)).sqrt().add_(group['eps'])
update = exp_avg / denom
alpha = -group['lr'] / (1 - beta1 ** step)

if group["graft"]:
if not do_adam:
exp_avg_true_sq.mul_(beta3).add_(p.sum_grad_squared, alpha=1 - beta3)
p.sum_grad_squared = None
denom = (exp_avg_true_sq / (1 - beta3 ** step)).sqrt().add_(group['eps'])
update = exp_avg / denom

if group["graft"] or do_adam:
exp_avg_sq = state['exp_avg_sq']
exp_avg_sq.mul_(beta2).add_(p.grad.square(), alpha=1 - beta2)
adam_update = exp_avg / (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(group['eps'])

if group["graft"] and not do_adam:
alpha = alpha * adam_update.norm() / update.norm()
elif do_adam:
update = adam_update

p.add_(update, alpha=alpha)
return loss
