From 4279d61f40e2348a5e622408ddb99e350502f69a Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Tue, 29 Nov 2022 08:27:46 +0100
Subject: [PATCH] feat(functional): improve grad accum, fix einsum backwd,
allow full patching
---
README.md | 126 ++++++++++--
setup.py | 2 +-
truegrad/functional.py | 393 +++++++++++++++++++++++++-------------
truegrad/nn/__init__.py | 79 ++------
truegrad/nn/functional.py | 85 ++-------
truegrad/optim.py | 2 +-
truegrad/utils.py | 27 ++-
7 files changed, 428 insertions(+), 286 deletions(-)
diff --git a/README.md b/README.md
index e782201..0cd4cb8 100644
--- a/README.md
+++ b/README.md
@@ -14,12 +14,13 @@ python3 -m pip install truegrad
TrueGrad supports various backends, each with their own tradeoffs:
-| Name | Advantages | Disadvantages |
-|----------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
-| [truegrad.nn](#nn)                                 | * What you see is what you get - Modules not in truegrad.nn and truegrad.nn.functional are not supported<br/>* Custom forward/backward for some fused functions<br/>* Optimized backward passes | * Limited applicability - custom modules can't be used<br/>* Requires code modification |
-| [truegrad.utils.patch_torch](#patch-torch)         | * Uses truegrad.nn under the hood<br/>* Works for many (off-the-shelf!) torch models<br/>* No code modification necessary | * Uncertainty if model is compatible |
-| [backpack](#backpack)                              | * Highest stability<br/>* Loud warnings and errors<br/>* Battle-tested<br/>* Simple to extend further | * High memory usage<br/>* High compute usage<br/>* Sparse support for torch operations |
-| [truegrad.utils.patch_model](#patch-custom-models) | * Best compatibility | * Fails silently on fused functions<br/>* More costly than truegrad.nn |
+| Name | Advantages | Disadvantages |
+|----------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|
+| [truegrad.nn](#nn)                                 | * What you see is what you get - Modules not in truegrad.nn and truegrad.nn.functional are not supported<br/>* Custom forward/backward for some fused functions<br/>* Optimized backward passes | * Limited applicability - custom modules can't be used<br/>* Requires code modification |
+| [truegrad.utils.patch_torch](#patch-torch)         | * Uses truegrad.nn under the hood<br/>* Works for many (off-the-shelf!) torch models<br/>* No code modification necessary | * Uncertainty if model is compatible |
+| [backpack](#backpack)                              | * Highest stability<br/>* Loud warnings and errors<br/>* Battle-tested<br/>* Simple to extend further | * High memory usage<br/>* High compute usage<br/>* Sparse support for torch operations |
+| [truegrad.utils.patch_model](#patch-custom-models) | * Works with custom models | * Fails silently on fused functions<br/>* ~50% to 100% slower than truegrad.nn |
+| [patch_torch + patch_model](#full-patching)        | * Best compatibility<br/>* Reduced overheads compared to `patch_model` (by falling back to faster pre-patched `patch_torch` where available) | * Fails silently on fused functions outside of torch.nn<br/>* Slower than truegrad.nn when truegrad.nn would've been enough |
Below, you'll find examples for each of these backends, as well as a [general strategy](#partial-truegrad) allowing
partial application of TrueGrad.
@@ -47,6 +48,7 @@ while True:
input = torch.randn((16, 1))
model(input).mean().backward()
optim.step()
+ optim.zero_grad()
```
### Patch Torch
@@ -77,11 +79,46 @@ while True:
loss = torch.nn.functional.cross_entropy(model(inp), tgt)
loss.backward()
optim.step()
+ optim.zero_grad()
i += 1
if i % 5 == 0:
print(i, loss.item())
```
+Similarly, most huggingface transformers work out of the box:
+
+```PYTHON
+import torch
+import transformers
+from torch.nn import functional as F
+
+from truegrad.optim import TGAdamW
+from truegrad.utils import patch_torch
+
+patch_torch() # only added line to get truegrad statistics for TGAdamW
+
+model = transformers.BertModel.from_pretrained("google/bert_uncased_L-2_H-128_A-2") # any existing model
+tokenizer = transformers.BertTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2")
+
+optim = TGAdamW(model.parameters())
+
+# constant input to overfit
+input = tokenizer(["Hello World!"], return_tensors="pt")
+
+# training loop as normal
+while True:
+ out = model(**input)
+ loss = F.l1_loss(out[0], torch.ones_like(out[0]))
+ loss.backward()
+ optim.step()
+ optim.zero_grad()
+ print(loss.item())
+```
+
+Note that this works even though transformers contain custom modules, which could otherwise cause issues. The key
+factor is that all parameters come from `torch.nn.Module`s, which `patch_torch()` patches, so truegrad handles every
+parameter usage. Therefore, any composition of `torch.nn.Module`s makes for a truegrad-compatible model.
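+
+As a minimal illustration of that claim (the `TinyNet` module below is a made-up example, not part of this repository),
+a custom module built purely from stock `torch.nn` submodules should train with TrueGrad statistics after a single
+`patch_torch()` call:
+
+```PYTHON
+import torch
+
+from truegrad.optim import TGAdamW
+from truegrad.utils import patch_torch
+
+patch_torch()  # patch torch.nn before the model is instantiated
+
+
+class TinyNet(torch.nn.Module):
+    # custom container module; every parameter still comes from patched torch.nn submodules
+    def __init__(self):
+        super().__init__()
+        self.body = torch.nn.Sequential(torch.nn.Linear(1, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
+
+    def forward(self, x):
+        return self.body(x)
+
+
+model = TinyNet()
+optim = TGAdamW(model.parameters())
+
+# constant input to overfit
+inp = torch.randn((16, 1))
+
+# training loop as normal
+for _ in range(10):
+    loss = model(inp).square().mean()
+    loss.backward()
+    optim.step()
+    optim.zero_grad()
+    print(loss.item())
+```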
+
### BackPack
The most stable although also memory hungry method to compute TrueGrad statistics is to use
@@ -119,6 +156,7 @@ while True:
loss = lossfunc(model(inp), tgt)
loss.backward()
optim.step()
+ optim.zero_grad()
i += 1
if i % 5 == 0:
print(i, loss.item())
@@ -141,21 +179,78 @@ and `torch.nn.MultiheadAttention`. However, unfused functions which directly acc
work well. Therefore, torch.nn.Linear and HuggingFace's attention work as expected.
```PYTHON
-import transformers
-from truegrad.utils import patch_model
+import torch
from truegrad.optim import TGAdamW
+from truegrad.utils import patch_model
+from torchvision.models import alexnet
-model = transformers.BertModel.from_pretrained("google/bert_uncased_L-2_H-128_A-2") # any existing model
-tokenizer = transformers.BertTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2")
+model = alexnet() # patch_model can't handle fused ops like VGG's and ResNet's BatchNorm
+optim = TGAdamW(model.parameters())
+
+# replace inplace ops like nn.ReLU(inplace=True) where possible
+for mod in model.modules():
+ if hasattr(mod, "inplace"):
+ mod.inplace = False
patch_model(model) # replace torch.nn.Parameter with truegrad.nn.Parameter
-optim = TGAdamW(model.parameters()) # truegrad.optim.TGAdamW instead of torch.optim.AdamW
-# training loop as normal
-for sample in ["Hello", "World", "!"]:
- out = model(**tokenizer([sample], return_tensors="pt"))
- out[0].mean().backward()
+# constant input/output to overfit
+inp = torch.randn((2, 3, 224, 224))
+tgt = torch.randint(0, 1000, (2,))
+
+# standard training loop
+i = 0
+while True:
+ # "SumGradSquared" computes the sum of the squared gradient
+ loss = torch.nn.functional.cross_entropy(model(inp), tgt)
+ loss.backward()
optim.step()
+ optim.zero_grad()
+ i += 1
+ if i % 5 == 0:
+ print(i, loss.item())
+```
+
+### Full Patching
+
+One way of avoiding [truegrad.utils.patch_model](#patch-custom-models)'s downsides when working with off-the-shelf
+models containing custom parameters, such as [lucidrains' ViTs](https://github.com/lucidrains/vit-pytorch/), is to also
+apply `patch_torch`. This takes care of many fused functions, such as LayerNorm, while still allowing full flexibility
+in model design.
+
+```PYTHON
+import torch
+from vit_pytorch.levit import LeViT
+from truegrad.utils import patch_torch, patch_model
+from truegrad.optim import TGAdamW
+
+patch_torch() # before model instantiation
+
+levit = LeViT(
+ image_size=224,
+ num_classes=1000,
+ stages=3, # number of stages
+ dim=(256, 384, 512), # dimensions at each stage
+ depth=4, # transformer of depth 4 at each stage
+ heads=(4, 6, 8), # heads at each stage
+ mlp_mult=2,
+ dropout=0.1
+ )
+
+opt = TGAdamW(levit.parameters())
+
+patch_model(levit) # replace torch.nn.Parameter with truegrad.nn.TrueGradParameter
+
+# constant input to overfit
+img = torch.randn(1, 3, 224, 224)
+
+# standard training loop
+while True:
+ loss = levit(img).square().mean()
+ loss.backward()
+ opt.step()
+ opt.zero_grad()
+ print(loss.item())
```
### Partial TrueGrad
@@ -186,6 +281,7 @@ while True:
loss = model(input).mean()
loss.backward()
optim.step()
+ optim.zero_grad()
i += 1
if i % 5 == 0:
print(i, loss.item())
diff --git a/setup.py b/setup.py
index ebc72b5..01d6d7e 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
name='truegrad',
license='BSD',
description='PyTorch interface for TrueGrad-AdamW',
- version='2.0.0',
+ version='2.1.0',
long_description=README,
url='https://github.com/clashluke/truegrad',
packages=setuptools.find_packages(),
diff --git a/truegrad/functional.py b/truegrad/functional.py
index f2a2141..c28e751 100644
--- a/truegrad/functional.py
+++ b/truegrad/functional.py
@@ -1,31 +1,97 @@
+import contextlib
import typing
-from typing import Any, Callable, List, Tuple
+from typing import Any, Callable, List, Optional, Tuple
import torch
+from torch import Tensor, nn
+from torch.nn import functional as F, grad
from torch.utils._pytree import tree_map
-def _unpack(x: Any) -> Any:
+# TrueGradParameter
+
+
+def is_tgparam(param: nn.Parameter):
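+    # Treat both direct TrueGradParameter instances and parameters that merely carry the "activated" flag as TrueGrad parameters.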
+ if isinstance(param, TrueGradParameter):
+ return True
+ if isinstance(param, nn.Parameter) and hasattr(param, "activated"):
+ return True
+ return False
+
+
+def unpack_tg_param(x: Any) -> Any:
+ if is_tgparam(x) and not x.activated:
+ x.activated = True
+ return x
+
+
+@contextlib.contextmanager
+def activate_tg_params(*tensors):
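+    # Temporarily flag TrueGradParameters as activated so nested __torch_function__ dispatch sees them as plain tensors, then clear the flag on exit.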
+ for t in tensors:
+ unpack_tg_param(t)
+ yield
+ for t in tensors:
+ if is_tgparam(t):
+ t.activated = False
+
+
+_parameter_function = nn.Parameter.__torch_function__
+
+
+class TrueGradParameter(nn.Parameter):
+ activated: bool
+
+ @staticmethod
+ def __new__(cls, data=None, requires_grad=True):
+ if data is None:
+ data = torch.zeros(())
+ out = torch.nn.Parameter._make_subclass(cls, data, requires_grad)
+ out.activated = False
+ return out
+
+ def __repr__(self):
+ return f"TrueGradParameter({self.data})"
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+
+ def base(a, k):
+ return _parameter_function(func, types, a, k)
+
+ if all(not is_tgparam(a) or a.activated for a in list(args) + list(kwargs.values())):
+ return base(args, kwargs)
+ with activate_tg_params(*args, *kwargs.values()):
+ out = base(tree_map(unpack_tg_param, args), tree_map(unpack_tg_param, kwargs))
+ if not isinstance(out, Tensor):
+ return out
+ return wrap(base, out, args, kwargs)
+
+
+# TrueGradTensor
+
+def unpack_tg_tensor(x: Any) -> Any:
if isinstance(x, TrueGradTensor):
return x.data
return x
-_base_torch_function = torch.Tensor.__torch_function__
+_tensor_function = Tensor.__torch_function__
-class TrueGradTensor(torch.Tensor):
- sum_grad_squared: torch.Tensor
- data: torch.Tensor
+class TrueGradTensor(Tensor):
+ sum_grad_squared: Tensor
+ data: Tensor
requires_grad: bool
__slots__ = ['sum_grad_squared', "data", "requires_grad"]
@staticmethod
- def __new__(cls, data: torch.Tensor):
+ def __new__(cls, data: Tensor):
meta = data.new_empty((0,))
meta.set_(meta.storage(), 0, data.size(), data.stride())
- r = torch.Tensor._make_subclass(cls, meta, data.requires_grad)
+ r = Tensor._make_subclass(cls, meta, data.requires_grad)
r.data = data
r.sum_grad_squared = None
r.activated = False
@@ -39,222 +105,253 @@ def __repr__(self):
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
- out = _base_torch_function(func, [], tree_map(_unpack, args), tree_map(_unpack, kwargs))
- return out
+ return _tensor_function(func, [], tree_map(unpack_tg_tensor, args), tree_map(unpack_tg_tensor, kwargs))
+# utils
+
+def valid_attr(wgt: nn.Parameter, attr: str = "sum_grad_squared"):
+ return hasattr(wgt, attr) and getattr(wgt, attr) is not None
+
+
+def add_or_set(wgt: nn.Parameter, new: torch.Tensor, attr: str = "sum_grad_squared"):
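+    # Accumulate into an existing statistic (e.g. across gradient-accumulation steps) instead of overwriting it; set it if absent.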
+ if hasattr(wgt, attr) and getattr(wgt, attr) is not None:
+ new = getattr(wgt, attr) + new
+ setattr(wgt, attr, new)
+
+
+# Autograd Functions
+
class MulFn(torch.autograd.Function):
@staticmethod
- def forward(ctx, inp: torch.Tensor, weight: torch.Tensor):
- if weight.requires_grad:
- ctx.save_for_backward(inp)
- ctx.weight = weight
- return inp * weight
+ def forward(ctx, inp: Tensor, weight: Tensor):
+ with activate_tg_params(inp, weight):
+ if weight.requires_grad:
+ ctx.save_for_backward(inp)
+ ctx.weight = weight
+ return inp * weight
@staticmethod
- def backward(ctx, dy: torch.Tensor):
+ def backward(ctx, dy: Tensor):
if not ctx.saved_tensors:
return None, None
inp, = ctx.saved_tensors
weight = ctx.weight
- diff = inp.ndim - weight.ndim
- summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1]
- weight_grad = dy * inp
- weight.sum_grad_squared = weight_grad.square()
- if summed:
- weight_grad = weight_grad.sum(summed)
- weight.sum_grad_squared = weight.sum_grad_squared.sum(summed)
- weight.sum_grad_squared = weight.sum_grad_squared.reshape(weight.size()) * dy.size(0)
+ with activate_tg_params(inp, weight):
+ diff = inp.ndim - weight.ndim
+ summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1]
+ weight_grad = dy * inp
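+            # Square the per-element gradients, reduce over broadcast dimensions, and scale by the batch size to form the sum_grad_squared statistic consumed by TGAdamW.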
+ sum_grad_squared = weight_grad.square()
+ if summed:
+ weight_grad = weight_grad.sum(summed)
+ sum_grad_squared = sum_grad_squared.sum(summed)
+ add_or_set(weight, sum_grad_squared.reshape(weight.size()) * dy.size(0))
return dy * weight, weight_grad.reshape(weight.size())
class AddFn(torch.autograd.Function):
@staticmethod
- def forward(ctx, inp: torch.Tensor, weight: torch.Tensor):
- if weight.requires_grad:
- diff = inp.ndim - weight.ndim
- ctx.summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1]
- ctx.batch_size = inp.size(0)
- ctx.weight = weight
+ def forward(ctx, inp: Tensor, weight: Tensor):
+ with activate_tg_params(inp, weight):
+ if weight.requires_grad:
+ diff = inp.ndim - weight.ndim
+ ctx.summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1]
+ ctx.batch_size = inp.size(0)
+ ctx.weight = weight
return inp + weight
@staticmethod
- def backward(ctx, dy: torch.Tensor):
+ def backward(ctx, dy: Tensor):
if not hasattr(ctx, "weight"):
return None, None
weight = ctx.weight
- weight_grad = dy
- weight.sum_grad_squared = dy.square()
- if ctx.summed:
- weight_grad = weight_grad.sum(ctx.summed)
- weight.sum_grad_squared = weight.sum_grad_squared.sum(ctx.summed)
- weight.sum_grad_squared = weight.sum_grad_squared.reshape(weight.size()) * dy.size(0)
+ with activate_tg_params(weight):
+ weight_grad = dy
+ sum_grad_squared = dy.square()
+ if ctx.summed:
+ weight_grad = weight_grad.sum(ctx.summed)
+ sum_grad_squared = sum_grad_squared.sum(ctx.summed)
+ add_or_set(weight, sum_grad_squared.reshape(weight.size()) * dy.size(0))
return dy, weight_grad.reshape(weight.size())
class EinsumFn(torch.autograd.Function):
@staticmethod
- def forward(ctx, spec: str, inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
- if weight.requires_grad:
- ctx.save_for_backward(inp)
- ctx.weight = weight
- ctx.spec = spec
- return torch.einsum(spec, inp, weight)
+ def forward(ctx, spec: str, inp: Tensor, weight: Tensor) -> Tensor:
+ with activate_tg_params(inp, weight):
+ if weight.requires_grad:
+ ctx.save_for_backward(inp, weight)
+ ctx.spec = spec
+ return torch.clone(torch.einsum(spec, inp, weight))
@staticmethod
- def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor, torch.Tensor]:
+ def backward(ctx, dy: Tensor) -> Tuple[None, Tensor, Tensor]:
if not ctx.saved_tensors:
return None, None, None
- inp, = ctx.saved_tensors
- wgt = ctx.weight
- inputs, output = ctx.spec.split('->')
- lhs, rhs = inputs.split(',')
+ inp, wgt = ctx.saved_tensors
+ with activate_tg_params(inp, wgt):
+ inputs, output = ctx.spec.split('->')
+ lhs, rhs = inputs.split(',')
+
+ d_wgt = torch.einsum(f'{lhs},{output}->{rhs}', inp, dy).contiguous()
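+            # The same contraction over squared inputs and squared output gradients yields the squared-gradient statistic accumulated into wgt.sum_grad_squared.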
+ add_or_set(wgt, torch.einsum(f'{lhs},{output}->{rhs}', inp.square(), dy.square()).contiguous())
+ d_inp = torch.einsum(f"{rhs},{output}->{lhs}", wgt, dy).contiguous()
- d_wgt = torch.einsum(f'{lhs},{output}->{rhs}', inp, dy)
- wgt.sum_grad_squared = torch.einsum(f'{lhs},{output}->{rhs}', inp.square(), dy.square() * inp.size(0))
- d_inp = torch.einsum(f"{rhs},{output}->{lhs}", wgt, dy)
return None, d_inp, d_wgt
class GatherFn(torch.autograd.Function):
@staticmethod
- def forward(ctx, inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
- if weight.requires_grad:
- ctx.save_for_backward(inp)
- ctx.weight = weight
+ def forward(ctx, inp: Tensor, weight: Tensor) -> Tensor:
+ with activate_tg_params(inp, weight):
+ if weight.requires_grad:
+ ctx.save_for_backward(inp)
+ ctx.weight = weight
return torch.gather(weight, 0, inp)
@staticmethod
- def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
+ def backward(ctx, dy: Tensor) -> Tuple[None, Tensor]:
if not ctx.saved_tensors:
return None, None
inp, = ctx.saved_tensors
wgt = ctx.weight
- wgt_grad = torch.zeros_like(wgt)
- wgt.sum_grad_squared = wgt_grad.scatter_add(0, inp, dy.square())
- wgt_grad.scatter_add_(0, inp, dy)
+ with activate_tg_params(inp, wgt):
+ wgt_grad = torch.zeros_like(wgt)
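+            # Scatter squared output gradients into the gathered rows to build sum_grad_squared, then scatter the raw gradients for the weight gradient.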
+ add_or_set(wgt, wgt_grad.scatter_add(0, inp, dy.square()))
+ wgt_grad.scatter_add_(0, inp, dy)
return None, wgt_grad
class ReshapeFn(torch.autograd.Function):
@staticmethod
- def forward(ctx, weight: torch.Tensor, new_shape: List[int]) -> torch.Tensor:
- out = TrueGradTensor(weight.reshape(new_shape).detach().requires_grad_(True))
- if weight.requires_grad:
- ctx.save_for_backward(weight)
- ctx.out = out
- ctx.original_shape = weight.size()
- return out
+ def forward(ctx, weight: Tensor, new_shape: List[int]) -> Tensor:
+ with activate_tg_params(weight):
+ out = TrueGradTensor(weight.reshape(new_shape).detach().requires_grad_(True))
+ if weight.requires_grad:
+ ctx.save_for_backward(weight, out)
+ return out
@staticmethod
- def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
+ def backward(ctx, dy: Tensor) -> Tuple[None, Tensor]:
if not ctx.saved_tensors:
return None, None
- wgt, = ctx.saved_tensors
- if ctx.out.sum_grad_squared is not None:
- wgt.sum_grad_squared = ctx.out.sum_grad_squared.reshape(ctx.original_shape)
- return dy.reshape(ctx.original_shape), None
+ wgt, out = ctx.saved_tensors
+ with activate_tg_params(wgt):
+ if valid_attr(out):
+                add_or_set(wgt, out.sum_grad_squared.reshape(wgt.size()))
+ return dy.reshape(wgt.size()), None
class TransposeFn(torch.autograd.Function):
@staticmethod
- def forward(ctx, weight: torch.Tensor, dims: typing.List[int]) -> torch.Tensor:
- out = TrueGradTensor(weight.transpose(*dims).detach().requires_grad_(True))
- if weight.requires_grad:
- ctx.save_for_backward(weight)
- ctx.out = out
- ctx.dims = dims
+ def forward(ctx, weight: Tensor, dims: typing.List[int]) -> Tensor:
+ with activate_tg_params(weight):
+ out = TrueGradTensor(weight.transpose(*dims).detach().requires_grad_(True))
+ if weight.requires_grad:
+ ctx.save_for_backward(weight, out)
+ ctx.dims = dims
return out
@staticmethod
- def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
+ def backward(ctx, dy: Tensor) -> Tuple[None, Tensor]:
if not ctx.saved_tensors:
return None, None
- wgt, = ctx.saved_tensors
- if ctx.out.sum_grad_squared is not None:
- wgt.sum_grad_squared = ctx.out.sum_grad_squared.transpose(*ctx.dims)
+ wgt, out = ctx.saved_tensors
+ with activate_tg_params(wgt):
+ if valid_attr(out):
+                add_or_set(wgt, out.sum_grad_squared.transpose(*ctx.dims))
return dy.transpose(*ctx.dims), None
class ChunkFn(torch.autograd.Function):
@staticmethod
- def forward(ctx, weight: torch.Tensor, chunks: int, dim: int):
- out = tuple(TrueGradTensor(c) for c in weight.chunk(chunks, dim))
- if weight.requires_grad:
- ctx.save_for_backward(weight)
- ctx.out = out
- ctx.dim = dim
+ def forward(ctx, weight: Tensor, chunks: int, dim: int):
+ with activate_tg_params(weight):
+ out = tuple(TrueGradTensor(c) for c in weight.chunk(chunks, dim))
+ if weight.requires_grad:
+                ctx.save_for_backward(weight, *out)
+ ctx.dim = dim
return out
@staticmethod
- def backward(ctx, *dy: torch.Tensor):
+ def backward(ctx, *dy: Tensor):
if not ctx.saved_tensors:
return None, None, None
- wgt, = ctx.saved_tensors
- wgt.sum_grad_squared = torch.cat([o.sum_grad_squared for o in ctx.out], dim=ctx.dim)
+        wgt, *out = ctx.saved_tensors
+ with activate_tg_params(wgt):
+ if all(valid_attr(o) for o in out):
+ add_or_set(wgt, torch.cat([o.sum_grad_squared for o in out], dim=ctx.dim))
return torch.cat(dy, dim=ctx.dim), None, None
class SplitFn(torch.autograd.Function):
@staticmethod
- def forward(ctx, weight: torch.Tensor, split_size: int, dim: int):
- out = tuple(TrueGradTensor(c) for c in weight.split(split_size, dim))
- if weight.requires_grad:
- ctx.save_for_backward(weight)
- ctx.out = out
- ctx.dim = dim
+ def forward(ctx, weight: Tensor, split_size: int, dim: int):
+ with activate_tg_params(weight):
+ out = tuple(TrueGradTensor(c) for c in weight.split(split_size, dim))
+ if weight.requires_grad:
+                ctx.save_for_backward(weight, *out)
+ ctx.dim = dim
return out
@staticmethod
- def backward(ctx, *dy: torch.Tensor):
+ def backward(ctx, *dy: Tensor):
if not ctx.saved_tensors:
return None, None, None
- wgt, = ctx.saved_tensors
- wgt.sum_grad_squared = torch.cat([o.sum_grad_squared for o in ctx.out], dim=ctx.dim)
+        wgt, *out = ctx.saved_tensors
+ with activate_tg_params(wgt):
+ if all(valid_attr(o) for o in out):
+ add_or_set(wgt, torch.cat([o.sum_grad_squared for o in out], dim=ctx.dim))
return torch.cat(dy, dim=ctx.dim), None, None
class ExpandFn(torch.autograd.Function):
@staticmethod
- def forward(ctx, weight: torch.Tensor, new_shape: List[int]) -> torch.Tensor:
- out = TrueGradTensor(weight.expand(new_shape))
- if weight.requires_grad:
- ctx.save_for_backward(weight)
- ctx.out = out
- ctx.summed = [i for i, d in enumerate(new_shape) if d != -1]
+ def forward(ctx, weight: Tensor, new_shape: List[int]) -> Tensor:
+ with activate_tg_params(weight):
+ out = TrueGradTensor(weight.expand(new_shape))
+ if weight.requires_grad:
+ ctx.save_for_backward(weight, out)
+ ctx.summed = [i for i, d in enumerate(new_shape) if d != -1]
return out
@staticmethod
- def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
+ def backward(ctx, dy: Tensor) -> Tuple[None, Tensor]:
if not ctx.saved_tensors:
return None, None
- wgt, = ctx.saved_tensors
- if ctx.out.sum_grad_squared is not None and ctx.summed:
- wgt.sum_grad_squared = ctx.out.sum_grad_squared.sum(ctx.summed)
- return dy.sum(ctx.summed)
+ wgt, out = ctx.saved_tensors
+ with activate_tg_params(wgt):
+ if valid_attr(out):
+ sum_grad_squared = out.sum_grad_squared
+ if ctx.summed:
+ sum_grad_squared = sum_grad_squared.sum(ctx.summed)
+ add_or_set(wgt, sum_grad_squared)
+ if ctx.summed:
+            return dy.sum(ctx.summed), None
+        return dy, None
class WrapFn(torch.autograd.Function):
@staticmethod
- def forward(ctx, fn, args, kwargs) -> torch.Tensor:
+ def forward(ctx, fn, out, args, kwargs) -> Tensor:
ctx.fn = fn
ctx.args = args
ctx.kwargs = kwargs
- return fn(*args, **kwargs)
+ return out
@staticmethod
- def backward(ctx, dy: torch.Tensor) -> Tuple[None, None, None, None]:
- def _backward(fn: Callable[[torch.Tensor], torch.Tensor], attr: str):
- def _fn(x: torch.Tensor):
- if isinstance(x, torch.nn.Parameter):
- x = x.data
- if not isinstance(x, torch.Tensor) or not torch.is_floating_point(x):
- return x
- x = fn(x.detach())
- x.requires_grad_(True)
+ def backward(ctx, dy: Tensor) -> Tuple[None, Tensor, None, None]:
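+        # Replay the wrapped call on squared, detached tensors and backpropagate the squared incoming gradient; the resulting grads are accumulated as sum_grad_squared on the original parameters/tensors.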
+ def _fn(x: Tensor):
+ if isinstance(x, nn.Parameter):
+ x = x.data
+ if not isinstance(x, Tensor) or not torch.is_floating_point(x):
return x
+ x = torch.square(x.detach()).detach()
+ x.requires_grad_(True)
+ return x
+ with activate_tg_params(*ctx.args, *ctx.kwargs.values()):
args = tree_map(_fn, ctx.args)
kwargs = tree_map(_fn, ctx.kwargs)
@@ -263,17 +360,46 @@ def _fn(x: torch.Tensor):
torch.autograd.backward(out, tree_map(_fn, dy))
for p, a in zip(list(ctx.args) + list(ctx.kwargs.values()), list(args) + list(kwargs.values())):
- if not isinstance(p, torch.nn.Parameter):
+ if not hasattr(a, "grad"):
continue
- if hasattr(p, attr) and getattr(p, attr) is not None:
- a.grad = getattr(p, attr) + a.grad
- setattr(p, attr, a.grad)
+ add_or_set(p, a.grad.contiguous())
- _backward(torch.square, "sum_grad_squared")
- _backward(lambda x: x, "grad")
+ return None, dy, None, None
- return None, None, None, None
+class ConvNdFn(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor], args) -> Tensor:
+ with activate_tg_params(input, weight, bias):
+ if weight.requires_grad:
+ ctx.save_for_backward(input, weight, bias)
+ ctx.args = args
+ dim = input.dim() - 2 # Batch, Feature, *Data
+ return getattr(F, f"conv{dim}d")(input, weight, bias, *args).contiguous()
+
+ @staticmethod
+ def backward(ctx, dy: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor], None]:
+ if not ctx.saved_tensors:
+ return None, None, None, None
+ inp, wgt, bias = ctx.saved_tensors
+ with activate_tg_params(inp, wgt, bias):
+ dim = inp.dim() - 2
+ summed = [0] + list(range(2, 2 + dim))
+
+ dx = getattr(grad, f"conv{dim}d_input")(inp.size(), wgt, dy, *ctx.args)
+ dw = getattr(grad, f"conv{dim}d_weight")(inp, wgt.size(), dy, *ctx.args)
+ db = None if bias is None else dy.sum(summed)
+
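+        # Convolving squared inputs with squared, batch-scaled output gradients yields sum_grad_squared for the weight; the bias statistic is the reduced squared gradient.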
+ if isinstance(wgt, nn.Parameter) or isinstance(bias, nn.Parameter):
+ dy_sq = dy.square() * dy.size(0)
+ if isinstance(wgt, nn.Parameter):
+ wgt.sum_grad_squared = getattr(grad, f"conv{dim}d_weight")(inp.square(), wgt.size(), dy_sq, *ctx.args)
+ if isinstance(bias, nn.Parameter):
+ bias.sum_grad_squared = dy_sq.sum(summed)
+ return dx, dw, db, None
+
+
+# "Normal" Functions
mul = MulFn.apply
add = AddFn.apply
@@ -285,8 +411,17 @@ def _fn(x: torch.Tensor):
split = SplitFn.apply
expand = ExpandFn.apply
wrap = WrapFn.apply
+convnd = ConvNdFn.apply
-def matmul(inp: torch.Tensor, wgt: torch.Tensor):
+def matmul(inp: Tensor, wgt: Tensor):
batch_dims = ''.join(chr(ord('a') + i) for i in range(inp.ndim - 1))
return einsum(f"{batch_dims}y,yz->{batch_dims}z", inp, wgt)
+
+
+def simple_wrap(function: Callable, out: torch.Tensor, *args, **kwargs):
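+    # Route an already-computed output through WrapFn so the backward pass can attach sum_grad_squared statistics to any parameters in args/kwargs.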
+ def _fn(a, k):
+ return function(*a, **k)
+
+    _fn.__name__ = function.__name__
+ return wrap(_fn, out, args, kwargs)
diff --git a/truegrad/nn/__init__.py b/truegrad/nn/__init__.py
index 39f9e50..42fb272 100644
--- a/truegrad/nn/__init__.py
+++ b/truegrad/nn/__init__.py
@@ -1,12 +1,13 @@
-from typing import Any, List
+from typing import List
import torch
import torch.nn as nn
-from torch.utils._pytree import tree_map
-from truegrad.functional import add, gather, mul, wrap
+from truegrad.functional import TrueGradParameter, add, gather, is_tgparam, mul
from truegrad.nn import functional
+TrueGradParameter = TrueGradParameter
+is_tgparam = is_tgparam
F = functional
@@ -114,14 +115,14 @@ def _apply_instance_norm(self, input):
class _LayerNorm(nn.Module):
- def __init__(self, dims: List[int], eps: float, broadcast: bool):
+ def __init__(self, normalized_shape: List[int], eps: float, broadcast: bool):
super(_LayerNorm, self).__init__()
- self.dims = dims
+ self.normalized_shape = normalized_shape
self.eps = eps
self.broadcast = broadcast
def forward(self, x):
- return F.layer_norm(x, self.dims, eps=self.eps, broadcast=self.broadcast)
+ return F.layer_norm(x, self.normalized_shape, eps=self.eps, broadcast=self.broadcast)
class LayerNorm(Normalization):
@@ -129,8 +130,9 @@ def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=
broadcast: bool = False):
if device is not None or dtype is not None:
raise ValueError("device and dtype are not supported. Ensure both are set to None.")
- super(LayerNorm, self).__init__(_LayerNorm([-i - 1 for i, dim in enumerate(normalized_shape) if dim != 1], eps,
- broadcast),
+ if isinstance(normalized_shape, int):
+ normalized_shape = [normalized_shape]
+ super(LayerNorm, self).__init__(_LayerNorm(normalized_shape, eps, broadcast),
normalized_shape, elementwise_affine)
@@ -159,15 +161,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)
-class Embedding(nn.Module):
- def __init__(self, num_embeddings, embedding_dim, **kwargs):
- if kwargs:
- raise ValueError(f"{kwargs} are not supported.")
- super(Embedding, self).__init__()
- self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
-
+class Embedding(nn.Embedding):
def forward(self, input: torch.Tensor) -> torch.Tensor:
- return gather(input, self.weight)
+ return F.embedding(input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq,
+ self.sparse)
class Conv1d(nn.Conv1d):
@@ -196,53 +193,3 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
modules = (Embedding, Linear, LayerNorm, LayerNorm1d, LayerNorm2d, LayerNorm3d, InstanceNorm1d, InstanceNorm2d,
InstanceNorm3d, BatchNorm1d, BatchNorm2d, BatchNorm3d)
-
-
-def is_tgparam(param: nn.Parameter):
- if isinstance(param, TrueGradParameter):
- return True
- if isinstance(param, nn.Parameter) and hasattr(param, "activated"):
- return True
- return False
-
-
-_base_torch_function = nn.Parameter.__torch_function__
-
-
-class TrueGradParameter(nn.Parameter):
- activated: bool
-
- @staticmethod
- def __new__(cls, data=None, requires_grad=True):
- if data is None:
- data = torch.zeros(())
- out = torch.nn.Parameter._make_subclass(cls, data, requires_grad)
- out.activated = False
- return out
-
- def __repr__(self):
- return f"TrueGradParameter({self.data})"
-
- @classmethod
- def __torch_function__(cls, func, types, args=(), kwargs=None):
- if kwargs is None:
- kwargs = {}
-
- def base(a, k):
- return _base_torch_function(func, types, a, k)
-
- if all(not is_tgparam(a) or a.activated for a in list(args) + list(kwargs.values())):
- return base(args, kwargs)
- out = base(tree_map(_unpack, args), tree_map(_unpack, kwargs))
- for a in list(args) + list(kwargs.values()):
- if is_tgparam(a):
- a.activated = False
- if not isinstance(out, torch.Tensor):
- return out
- return wrap(base, args, kwargs)
-
-
-def _unpack(x: Any) -> Any:
- if is_tgparam(x) and not x.activated:
- x.activated = True
- return x
diff --git a/truegrad/nn/functional.py b/truegrad/nn/functional.py
index 1fbf014..2e78877 100644
--- a/truegrad/nn/functional.py
+++ b/truegrad/nn/functional.py
@@ -1,43 +1,20 @@
-import functools
+import torch.autograd
+
import math
import typing
import warnings
from typing import Callable, List, Optional, Tuple, Union
import torch.autograd
from torch import Tensor, nn
-from torch.nn import functional as F, grad
-
-from truegrad.functional import add, chunk, einsum, matmul, mul, reshape, split, transpose
-
-_torch_functional = {k: getattr(F, k) for k in dir(F)}
-_torch = {k: getattr(torch, k) for k in dir(torch)}
-_inside_call = {}
-
+from torch.nn import functional as F
-def call_torch(fn: Callable, name: Optional[str] = None):
- if name is None:
- name = fn.__name__
- _inside_call[fn] = 0
-
- def _fn(*args, **kwargs):
- _inside_call[fn] += 1
- if _inside_call[fn] == 1:
- out = fn(*args, **kwargs)
- elif _inside_call[fn] == 2:
- out = _torch_functional[name](*args, **kwargs)
- elif _inside_call[fn] == 3:
- out = _torch[name](*args, **kwargs)
- else:
- raise ValueError
- _inside_call[fn] -= 1
- return out
-
- return _fn
+from truegrad.functional import (add, chunk, convnd, einsum, mul, reshape, simple_wrap,
+ split)
def no_parameter(fn: Callable):
- @functools.partial(call_torch, name=fn.__name__)
def _fn(*args, **kwargs):
for i, arg in enumerate(args):
if isinstance(arg, nn.Parameter):
@@ -208,7 +185,6 @@ def alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace
return F.alpha_dropout(input, p, training, inplace)
-@call_torch
def batch_norm(input: Tensor, running_mean: typing.Optional[Tensor],
running_var: typing.Optional[Tensor], weight: typing.Optional[Tensor] = None,
bias: typing.Optional[Tensor] = None, training: bool = False, momentum: float = 0.1,
@@ -221,7 +197,6 @@ def batch_norm(input: Tensor, running_mean: typing.Optional[Tensor],
return input
-@call_torch
def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: typing.Optional[Tensor] = None):
batch_dims = ''.join(chr(ord('a') + i) for i in range(input1.ndim - 1))
x = einsum(f'{batch_dims}x,{batch_dims}y,zxy->{batch_dims}z', input1, input2, weight)
@@ -260,58 +235,24 @@ def channel_shuffle(input: Tensor, groups: int):
return F.channel_shuffle(input, groups)
-class _ConvNdFn(torch.autograd.Function):
- @staticmethod
- def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor], args) -> torch.Tensor:
- if weight.requires_grad:
- ctx.save_for_backward(input, weight, bias)
- ctx.args = args
- dim = input.dim() - 2 # Batch, Feature, *Data
- return getattr(F, f"conv{dim}d")(input, weight, bias, *args)
-
- @staticmethod
- def backward(ctx, dy: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], None]:
- if not ctx.saved_tensors:
- return None, None, None, None
- inp, wgt, bias = ctx.saved_tensors
- dim = inp.dim() - 2
- summed = [0] + list(range(2, 2 + dim))
-
- dx = getattr(grad, f"conv{dim}d_input")(inp.size(), wgt, dy, *ctx.args)
- dw = getattr(grad, f"conv{dim}d_weight")(inp, wgt.size(), dy, *ctx.args)
- db = None if bias is None else dy.sum(summed)
-
- if isinstance(wgt, nn.Parameter) or isinstance(bias, nn.Parameter):
- dy_sq = dy.square() * dy.size(0)
- if isinstance(wgt, nn.Parameter):
- wgt.sum_grad_squared = getattr(grad, f"conv{dim}d_weight")(inp.square(), wgt.size(), dy_sq, *ctx.args)
- if isinstance(bias, nn.Parameter):
- bias.sum_grad_squared = dy_sq.sum(summed)
- return dx, dw, db, None
-
-
-@call_torch
def _convnd(input: Tensor, weight: Tensor, bias: Optional[Tensor], dim: int, *args):
if input.dim() != dim + 2:
raise ValueError(f"Input has {input.dim()} dimensions, but expected {dim + 2} dimensions for conv{dim}d.")
- return _ConvNdFn.apply(input, weight, bias, args)
+ return convnd(input, weight, bias, args)
-@call_torch
def conv1d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: int = 1,
padding: Union[str, int] = "valid",
dilation: int = 1, groups: int = 1):
return _convnd(input, weight, bias, 1, stride, padding, dilation, groups)
-@call_torch
def conv2d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: int = 1,
padding: Union[str, int] = "valid",
dilation: int = 1, groups: int = 1):
return _convnd(input, weight, bias, 2, stride, padding, dilation, groups)
-@call_torch
def conv3d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: int = 1,
padding: Union[str, int] = "valid",
dilation: int = 1, groups: int = 1):
@@ -391,11 +332,12 @@ def elu_(input: Tensor, alpha: float = 1.0):
return F.elu_(input, alpha)
-@no_parameter
def embedding(input: Tensor, weight: Tensor, padding_idx: typing.Optional[int] = None,
max_norm: typing.Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False,
sparse: bool = False):
- return F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
+ args = [input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse]
+ out = F.embedding(*args)
+ return simple_wrap(F.embedding, out, *args)
@no_parameter
@@ -440,7 +382,6 @@ def grid_sample(input: Tensor, grid: Tensor, mode: str = bilinear, padding_mode:
return F.grid_sample(input, grid, mode, padding_mode, align_corners)
-@call_torch
def group_norm(input: Tensor, num_groups: int, weight: typing.Optional[Tensor] = None,
bias: typing.Optional[Tensor] = None, eps: float = 1e-05):
x = F.group_norm(input, num_groups, None, None, eps)
@@ -493,7 +434,6 @@ def huber_loss(input: Tensor, target: Tensor, reduction: str = "mean", delta: fl
return F.huber_loss(input, target, reduction, delta)
-@call_torch
def instance_norm(input: Tensor, running_mean: typing.Optional[Tensor] = None,
running_var: typing.Optional[Tensor] = None, weight: typing.Optional[Tensor] = None,
bias: typing.Optional[Tensor] = None, use_input_stats: bool = True, momentum: float = 0.1,
@@ -526,7 +466,6 @@ def l1_loss(input: Tensor, target: Tensor, size_average: typing.Optional[bool] =
return F.l1_loss(input, target, size_average, reduce, reduction)
-@call_torch
def layer_norm(input: Tensor, normalized_shape: typing.List[int], weight: typing.Optional[Tensor] = None,
bias: typing.Optional[Tensor] = None, eps: float = 1e-05, broadcast: bool = True):
if broadcast:
@@ -549,9 +488,9 @@ def leaky_relu_(input: Tensor, negative_slope: float = 0.01):
return F.leaky_relu_(input, negative_slope)
-@call_torch
def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]):
- input = matmul(input, transpose(weight, (0, 1)))
+ batch_dims = ''.join(chr(ord('a') + i) for i in range(input.ndim - 1))
+ input = einsum(f"{batch_dims}y,zy->{batch_dims}z", input, weight)
if bias is None:
return input
return add(input, bias)
diff --git a/truegrad/optim.py b/truegrad/optim.py
index 28dffc0..b514289 100644
--- a/truegrad/optim.py
+++ b/truegrad/optim.py
@@ -80,7 +80,7 @@ def step(self, closure=None):
adam_update = exp_avg / (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(group['eps'])
if group["graft"] and not do_adam:
- alpha = alpha * adam_update.norm() / update.norm()
+ alpha = alpha * adam_update.norm() / update.norm().add_(group['eps'])
elif do_adam:
update = adam_update
diff --git a/truegrad/utils.py b/truegrad/utils.py
index 0466cda..48c25d8 100644
--- a/truegrad/utils.py
+++ b/truegrad/utils.py
@@ -1,4 +1,8 @@
+import collections
+import typing
+
import torch
+from torch import overrides
import truegrad
from truegrad.nn import TrueGradParameter
@@ -16,6 +20,26 @@ def _apply_fn(module: torch.nn.Module):
_apply_fn(mod)
+def from_x(name: str, fn: typing.Callable, module):
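+    # Try the truegrad implementation first; on re-entry (recursion) or failure, fall back to the original torch function.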
+ calls = [0]
+ original = getattr(module, name)
+
+ def _fn(*args, **kwargs):
+ calls[0] += 1
+ if calls[0] == 1:
+ try:
+ return fn(*args, **kwargs)
+            except Exception:
+ return original(*args, **kwargs)
+ finally:
+ calls[0] -= 1
+ out = original(*args, **kwargs)
+ calls[0] -= 1
+ return out
+
+ return _fn
+
+
def _patch(tg, th):
tg_dir = dir(tg)
for name in dir(th):
@@ -26,10 +50,11 @@ def _patch(tg, th):
continue
if item.__module__ != tg.__name__:
continue
- setattr(th, name, item)
+ setattr(th, name, from_x(name, item, th))
def patch_torch():
_patch(truegrad.nn.functional, torch.nn.functional)
_patch(truegrad.nn.functional, torch)
_patch(truegrad.nn, torch.nn)
+ overrides.has_torch_function_variadic = lambda *x: False