feat(functional): improve grad accum, fix einsum backwd, allow full patching
ClashLuke committed Nov 29, 2022
1 parent bb868cc commit 4279d61
Showing 7 changed files with 428 additions and 286 deletions.
126 changes: 111 additions & 15 deletions README.md
@@ -14,12 +14,13 @@ python3 -m pip install truegrad

TrueGrad supports various backends, each with their own tradeoffs:

| Name | Advantages | Disadvantages |
|----------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|
| [truegrad.nn](#nn) | * What you see is what you get - Modules not in truegrad.nn and truegrad.nn.functional are not supported<br/>* Custom forward/backward for some fused functions<br/>* Optimized backward passes | * Limited applicability - custom modules can't be used<br/>* Requires code modification |
| [truegrad.utils.patch_torch](#patch-torch) | * Uses truegrad.nn under the hood<br/>* Works for many (off-the-shelf!) torch models<br/>* No code modification necessary | * Uncertainty if model is compatible |
| [backpack](#backpack) | * Highest stability<br/>* Loud warnings and errors<br/>* Battle-tested<br/>* Simple to extend further | * High memory usage<br/>* High compute usage<br/>* Sparse support for torch operations |
| [truegrad.utils.patch_model](#patch-custom-models) | * Works with custom models | * Fails silently on fused functions<br/>* ~50% to 100% slower than truegrad.nn |
| [patch_torch + patch_model](#full-patching) | * Best compatibility<br/>* Reduced overheads compared to `patch_model` (by falling back to faster pre-patched `patch_torch` where available) | * Fails silently on fused functions outside of torch.nn<br/>* Slower than truegrad.nn when truegrad.nn would've been enough |

Below, you'll find examples for each of these backends, as well as a [general strategy](#partial-truegrad) allowing
partial application of TrueGrad.
@@ -47,6 +48,7 @@ while True:
    input = torch.randn((16, 1))
    model(input).mean().backward()
    optim.step()
    optim.zero_grad()
```
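
The collapsed lines above hold the model and optimizer setup. As a rough orientation, here is a minimal sketch of how such a setup might look, assuming `truegrad.nn` mirrors `torch.nn` module names such as `nn.Linear` (an illustration, not the README's exact code):

```PYTHON
import torch
from truegrad import nn  # assumption: truegrad.nn exposes drop-in torch.nn modules
from truegrad.optim import TGAdamW

# hypothetical model built from truegrad.nn layers; parameter-free modules such
# as ReLU can stay plain torch.nn
model = torch.nn.Sequential(nn.Linear(1, 16),
                            torch.nn.ReLU(),
                            nn.Linear(16, 1))
optim = TGAdamW(model.parameters())

while True:
    input = torch.randn((16, 1))
    model(input).mean().backward()
    optim.step()
    optim.zero_grad()
```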

### Patch Torch
@@ -77,11 +79,46 @@ while True:
    loss = torch.nn.functional.cross_entropy(model(inp), tgt)
    loss.backward()
    optim.step()
    optim.zero_grad()
    i += 1
    if i % 5 == 0:
        print(i, loss.item())
```

Similarly, most HuggingFace transformers work out of the box:

```PYTHON
import torch
import transformers
from torch.nn import functional as F

from truegrad.optim import TGAdamW
from truegrad.utils import patch_torch

patch_torch() # only added line to get truegrad statistics for TGAdamW

model = transformers.BertModel.from_pretrained("google/bert_uncased_L-2_H-128_A-2") # any existing model
tokenizer = transformers.BertTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2")

optim = TGAdamW(model.parameters())

# constant input to overfit
input = tokenizer(["Hello World!"], return_tensors="pt")

# training loop as normal
while True:
    out = model(**input)
    loss = F.l1_loss(out[0], torch.ones_like(out[0]))
    loss.backward()
    optim.step()
    optim.zero_grad()
    print(loss.item())
```

Note that this works even though transformers define custom modules, which could otherwise cause issues. The key
factor is that all parameters come from `torch.nn.Module`s, which are patched by `patch_torch()`, so truegrad handles
all parameter usages. Any composition of `torch.nn.Module`s therefore makes for a truegrad-compatible model.
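
To make this concrete, here is a hedged sketch of a custom module whose parameters all live in stock `torch.nn` layers and which should therefore work under `patch_torch()` (the `TinyMLP` module and its sizes are made up for illustration):

```PYTHON
import torch
from truegrad.optim import TGAdamW
from truegrad.utils import patch_torch

patch_torch()  # patch before any modules are instantiated

# hypothetical custom module: every parameter comes from a stock torch.nn layer,
# so the patched torch.nn classes already cover it
class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.up = torch.nn.Linear(4, 32)
        self.down = torch.nn.Linear(32, 4)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))

model = TinyMLP()
optim = TGAdamW(model.parameters())

x = torch.randn(8, 4)
model(x).square().mean().backward()
optim.step()
optim.zero_grad()
```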

### BackPack

The most stable, although also memory-hungry, method to compute TrueGrad statistics is to use
Expand Down Expand Up @@ -119,6 +156,7 @@ while True:
    loss = lossfunc(model(inp), tgt)
    loss.backward()
    optim.step()
    optim.zero_grad()
    i += 1
    if i % 5 == 0:
        print(i, loss.item())
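
The setup for this example is collapsed above. As an illustration of BackPack's general pattern (not necessarily the exact collapsed code), here is a hedged sketch using `extend` and the `SumGradSquared` extension; the model and loss are made up, and it is assumed here that TGAdamW consumes the resulting `sum_grad_squared` statistic:

```PYTHON
import torch
from backpack import backpack, extend
from backpack.extensions import SumGradSquared
from truegrad.optim import TGAdamW

# hypothetical model and loss; extend() registers BackPack's hooks on them
model = extend(torch.nn.Sequential(torch.nn.Linear(1, 10),
                                   torch.nn.ReLU(),
                                   torch.nn.Linear(10, 1)))
lossfunc = extend(torch.nn.MSELoss())
optim = TGAdamW(model.parameters())

# constant input/output to overfit
inp = torch.randn((16, 1))
tgt = torch.randn((16, 1))

i = 0
while True:
    loss = lossfunc(model(inp), tgt)
    # SumGradSquared stores the sum of squared per-sample gradients
    # on each parameter as p.sum_grad_squared
    with backpack(SumGradSquared()):
        loss.backward()
    optim.step()
    optim.zero_grad()
    i += 1
    if i % 5 == 0:
        print(i, loss.item())
```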
@@ -141,21 +179,78 @@ and `torch.nn.MultiheadAttention`. However, unfused functions which directly acc
work well. Therefore, torch.nn.Linear and HuggingFace's attention work as expected.

```PYTHON
import torch
from truegrad.optim import TGAdamW
from truegrad.utils import patch_model
from torchvision.models import alexnet

model = alexnet()  # patch_model can't handle fused ops like VGG's and ResNet's BatchNorm
optim = TGAdamW(model.parameters())

# replace inplace ops like nn.ReLU(inplace=True) where possible
for mod in model.modules():
    if hasattr(mod, "inplace"):
        mod.inplace = False

patch_model(model)  # replace torch.nn.Parameter with truegrad.nn.Parameter

# constant input/output to overfit
inp = torch.randn((2, 3, 224, 224))
tgt = torch.randint(0, 1000, (2,))

# standard training loop
i = 0
while True:
    loss = torch.nn.functional.cross_entropy(model(inp), tgt)
    loss.backward()
    optim.step()
    optim.zero_grad()
    i += 1
    if i % 5 == 0:
        print(i, loss.item())
```
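
To illustrate the point above about unfused functions, here is a hedged sketch of a module whose weight is a plain `torch.nn.Parameter` used through an ordinary matmul, the kind of direct parameter access `patch_model` is described as handling (the `UnfusedLinear` module and its shapes are made up for illustration):

```PYTHON
import torch
from truegrad.optim import TGAdamW
from truegrad.utils import patch_model

# hypothetical module: the weight is used via an unfused matmul + add rather
# than a fused kernel such as F.layer_norm
class UnfusedLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.bias = torch.nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        return x @ self.weight.T + self.bias

model = torch.nn.Sequential(UnfusedLinear(8, 32), torch.nn.ReLU(), UnfusedLinear(32, 8))
patch_model(model)  # swap the raw parameters for truegrad's parameter class
optim = TGAdamW(model.parameters())

x = torch.randn(4, 8)
model(x).square().mean().backward()
optim.step()
optim.zero_grad()
```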

### Full Patching

One way of avoiding [truegrad.utils.patch_model](#patch-custom-models)'s downsides when working with off-the-shelf
models containing custom parameters, such as [lucidrains' ViTs](https://github.com/lucidrains/vit-pytorch/), is to also
call `patch_torch()`. This takes care of many fused functions, such as LayerNorm, while still allowing full flexibility
in model design.

```PYTHON
import torch
from vit_pytorch.levit import LeViT
from truegrad.utils import patch_torch, patch_model
from truegrad.optim import TGAdamW

patch_torch() # before model instantiation

levit = LeViT(
    image_size=224,
    num_classes=1000,
    stages=3,              # number of stages
    dim=(256, 384, 512),   # dimensions at each stage
    depth=4,               # transformer of depth 4 at each stage
    heads=(4, 6, 8),       # heads at each stage
    mlp_mult=2,
    dropout=0.1
)

opt = TGAdamW(levit.parameters())

patch_model(levit) # replace torch.nn.Parameter with truegrad.nn.TrueGradParameter

# constant input to overfit
img = torch.randn(1, 3, 224, 224)

# standard training loop
while True:
    loss = levit(img).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    print(loss.item())
```

### Partial TrueGrad
@@ -186,6 +281,7 @@ while True:
    loss = model(input).mean()
    loss.backward()
    optim.step()
    optim.zero_grad()
    i += 1
    if i % 5 == 0:
        print(i, loss.item())
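
The prose and setup for this section are collapsed above. As one generic way to apply TrueGrad only partially (not necessarily the approach the collapsed section takes), here is a hedged sketch that splits the parameters across TGAdamW and plain AdamW; the model split and layer sizes are made up, and `truegrad.nn.Linear` is assumed to exist:

```PYTHON
import torch
from truegrad import nn  # assumption: truegrad.nn provides a drop-in Linear
from truegrad.optim import TGAdamW

# hypothetical split: the truegrad.nn part gets TGAdamW, the rest stays on AdamW
tg_part = nn.Linear(16, 16)
plain_part = torch.nn.Linear(16, 16)
model = torch.nn.Sequential(tg_part, torch.nn.ReLU(), plain_part)

optim = TGAdamW(tg_part.parameters())
plain_optim = torch.optim.AdamW(plain_part.parameters())

i = 0
while True:
    input = torch.randn((4, 16))
    loss = model(input).mean()
    loss.backward()
    optim.step()
    plain_optim.step()
    optim.zero_grad()
    plain_optim.zero_grad()
    i += 1
    if i % 5 == 0:
        print(i, loss.item())
```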
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@
name='truegrad',
license='BSD',
description='PyTorch interface for TrueGrad-AdamW',
version='2.0.0',
version='2.1.0',
long_description=README,
url='https://github.com/clashluke/truegrad',
packages=setuptools.find_packages(),