From 4279d61f40e2348a5e622408ddb99e350502f69a Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Tue, 29 Nov 2022 08:27:46 +0100 Subject: [PATCH] feat(functional): improve grad accum, fix einsum backwd, allow full patching --- README.md | 126 ++++++++++-- setup.py | 2 +- truegrad/functional.py | 393 +++++++++++++++++++++++++------------- truegrad/nn/__init__.py | 79 ++------ truegrad/nn/functional.py | 85 ++------- truegrad/optim.py | 2 +- truegrad/utils.py | 27 ++- 7 files changed, 428 insertions(+), 286 deletions(-) diff --git a/README.md b/README.md index e782201..0cd4cb8 100644 --- a/README.md +++ b/README.md @@ -14,12 +14,13 @@ python3 -m pip install truegrad TrueGrad supports various backends, each with their own tradeoffs: -| Name | Advantages | Disadvantages | -|----------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------| -| [truegrad.nn](#nn) | * What you see is what you get - Modules not in truegrad.nn and truegrad.nn.functional are not supported
* Custom forward/backward for some fused functions
* Optimized backward passes | * Limited applicability - custom modules can't be used
* Requires code modification | -| [truegrad.utils.patch_torch](#patch-torch) | * Uses truegrad.nn under the hood
* Works for many (off-the-shelf!) torch models
* No code modification necessary | * Uncertainty if model is compatible | -| [backpack](#backpack) | * Highest stability
* Loud warnings and errors
* Battle-tested
* Simple to extend further | * High memory usage
* High compute usage
* Sparse support for torch operations | -| [truegrad.utils.patch_model](#patch-custom-models) | * Best compatibility | * Fails silently on fused functions
* More costly than truegrad.nn | +| Name | Advantages | Disadvantages | +|----------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------| +| [truegrad.nn](#nn) | * What you see is what you get - Modules not in truegrad.nn and truegrad.nn.functional are not supported
* Custom forward/backward for some fused functions
* Optimized backward passes | * Limited applicability - custom modules can't be used
* Requires code modification | +| [truegrad.utils.patch_torch](#patch-torch) | * Uses truegrad.nn under the hood
* Works for many (off-the-shelf!) torch models
* No code modification necessary                                                                                               | * No guarantee that a given model is compatible                                           | +| [backpack](#backpack)                              | * Highest stability
* Loud warnings and errors
* Battle-tested
* Simple to extend further | * High memory usage
* High compute usage
* Sparse support for torch operations | +| [truegrad.utils.patch_model](#patch-custom-models) | * Works with custom models | * Fails silently on fused functions
* ~50% to 100% slower than truegrad.nn                                                                                         | +| [patch_torch + patch_model](#full-patching)        | * Best compatibility
* Lower overhead than `patch_model` alone (falls back to the faster pre-patched `patch_torch` implementations where available) | * Fails silently on fused functions outside of torch.nn
* Slower than truegrad.nn when truegrad.nn would've been enough | Below, you'll find examples for each of these backends, as well as a [general strategy](#partial-truegrad) allowing partial application of TrueGrad. @@ -47,6 +48,7 @@ while True: input = torch.randn((16, 1)) model(input).mean().backward() optim.step() + optim.zero_grad() ``` ### Patch Torch @@ -77,11 +79,46 @@ while True: loss = torch.nn.functional.cross_entropy(model(inp), tgt) loss.backward() optim.step() + optim.zero_grad() i += 1 if i % 5 == 0: print(i, loss.item()) ``` +Similarly, most huggingface transformers work out of the box: + +```PYTHON +import torch +import transformers +from torch.nn import functional as F + +from truegrad.optim import TGAdamW +from truegrad.utils import patch_torch + +patch_torch() # only added line to get truegrad statistics for TGAdamW + +model = transformers.BertModel.from_pretrained("google/bert_uncased_L-2_H-128_A-2") # any existing model +tokenizer = transformers.BertTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2") + +optim = TGAdamW(model.parameters()) + +# constant input to overfit +input = tokenizer(["Hello World!"], return_tensors="pt") + +# training loop as normal +while True: + out = model(**input) + loss = F.l1_loss(out[0], torch.ones_like(out[0])) + loss.backward() + optim.step() + optim.zero_grad() + print(loss.item()) +``` + +Note that this works even though transformers have custom modules, which could cause issues. The key factor is that all +parameters come from `torch.nn.Module`'s, which are patched by `patch_torch()`. Therefore, truegrad handles all +parameter usages. Therefore, any composition of `torch.nn.Module`'s makes for a truegrad-compatible model. + ### BackPack The most stable although also memory hungry method to compute TrueGrad statistics is to use @@ -119,6 +156,7 @@ while True: loss = lossfunc(model(inp), tgt) loss.backward() optim.step() + optim.zero_grad() i += 1 if i % 5 == 0: print(i, loss.item()) @@ -141,21 +179,78 @@ and `torch.nn.MultiheadAttention`. However, unfused functions which directly acc work well. Therefore, torch.nn.Linear and HuggingFace's attention work as expected. 
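The example below trains AlexNet with `patch_model`. Because fused operations fail silently under `patch_model`, it can be worth checking after `loss.backward()` that every trainable parameter actually received the `sum_grad_squared` statistic that TGAdamW consumes. A minimal, illustrative sanity check (the check itself is not part of truegrad's API):

```PYTHON
import torch
from truegrad.utils import patch_model

# illustrative sketch, not truegrad API: verify that statistics were collected
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 1))
patch_model(model)  # replace torch.nn.Parameter with truegrad.nn.TrueGradParameter
model(torch.randn(8, 4)).mean().backward()

missing = [name for name, p in model.named_parameters()
           if p.requires_grad and getattr(p, "sum_grad_squared", None) is None]
print("parameters without TrueGrad statistics:", missing)  # expect [] if all statistics were collected
```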
```PYTHON -import transformers -from truegrad.utils import patch_model +import torch from truegrad.optim import TGAdamW +from truegrad.utils import patch_model +from torchvision.models import alexnet -model = transformers.BertModel.from_pretrained("google/bert_uncased_L-2_H-128_A-2") # any existing model -tokenizer = transformers.BertTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2") +model = alexnet() # patch_model can't handle fused ops like VGG's and ResNet's BatchNorm +optim = TGAdamW(model.parameters()) + +# replace inplace ops like nn.ReLU(inplace=True) where possible +for mod in model.modules(): + if hasattr(mod, "inplace"): + mod.inplace = False patch_model(model) # replace torch.nn.Parameter with truegrad.nn.Parameter -optim = TGAdamW(model.parameters()) # truegrad.optim.TGAdamW instead of torch.optim.AdamW -# training loop as normal -for sample in ["Hello", "World", "!"]: - out = model(**tokenizer([sample], return_tensors="pt")) - out[0].mean().backward() +# constant input/output to overfit +inp = torch.randn((2, 3, 224, 224)) +tgt = torch.randint(0, 1000, (2,)) + +# standard training loop +i = 0 +while True: + # "SumGradSquared" computes the sum of the squared gradient + loss = torch.nn.functional.cross_entropy(model(inp), tgt) + loss.backward() optim.step() + optim.zero_grad() + i += 1 + if i % 5 == 0: + print(i, loss.item()) +``` + +### Full Patching + +One way of avoiding [truegrad.utils.patch_model](#patch-custom-models)'s downsides when working with off-the-shelf +models containing custom parameters, such as [lucidrains' ViT's](https://github.com/lucidrains/vit-pytorch/) is to also +`patch_torch`. This takes care of many fused functions, such as LayerNorm, while still allowing full flexibility in +model design. + +```PYTHON +import torch +from vit_pytorch.levit import LeViT +from truegrad.utils import patch_torch, patch_model +from truegrad.optim import TGAdamW + +patch_torch() # before model instantiation + +levit = LeViT( + image_size=224, + num_classes=1000, + stages=3, # number of stages + dim=(256, 384, 512), # dimensions at each stage + depth=4, # transformer of depth 4 at each stage + heads=(4, 6, 8), # heads at each stage + mlp_mult=2, + dropout=0.1 + ) + +opt = TGAdamW(levit.parameters()) + +patch_model(levit) # replace torch.nn.Parameter with truegrad.nn.TrueGradParameter + +# constant input to overfit +img = torch.randn(1, 3, 224, 224) + +# standard training loop +while True: + loss = levit(img).square().mean() + loss.backward() + opt.step() + opt.zero_grad() + print(loss.item()) ``` ### Partial TrueGrad @@ -186,6 +281,7 @@ while True: loss = model(input).mean() loss.backward() optim.step() + optim.zero_grad() i += 1 if i % 5 == 0: print(i, loss.item()) diff --git a/setup.py b/setup.py index ebc72b5..01d6d7e 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ name='truegrad', license='BSD', description='PyTorch interface for TrueGrad-AdamW', - version='2.0.0', + version='2.1.0', long_description=README, url='https://github.com/clashluke/truegrad', packages=setuptools.find_packages(), diff --git a/truegrad/functional.py b/truegrad/functional.py index f2a2141..c28e751 100644 --- a/truegrad/functional.py +++ b/truegrad/functional.py @@ -1,31 +1,97 @@ +import contextlib import typing -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, List, Optional, Tuple import torch +from torch import Tensor, nn +from torch.nn import functional as F, grad from torch.utils._pytree import tree_map -def _unpack(x: Any) -> Any: +# 
TrueGradParameter + + +def is_tgparam(param: nn.Parameter): + if isinstance(param, TrueGradParameter): + return True + if isinstance(param, nn.Parameter) and hasattr(param, "activated"): + return True + return False + + +def unpack_tg_param(x: Any) -> Any: + if is_tgparam(x) and not x.activated: + x.activated = True + return x + + +@contextlib.contextmanager +def activate_tg_params(*tensors): + for t in tensors: + unpack_tg_param(t) + yield + for t in tensors: + if is_tgparam(t): + t.activated = False + + +_parameter_function = nn.Parameter.__torch_function__ + + +class TrueGradParameter(nn.Parameter): + activated: bool + + @staticmethod + def __new__(cls, data=None, requires_grad=True): + if data is None: + data = torch.zeros(()) + out = torch.nn.Parameter._make_subclass(cls, data, requires_grad) + out.activated = False + return out + + def __repr__(self): + return f"TrueGradParameter({self.data})" + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + def base(a, k): + return _parameter_function(func, types, a, k) + + if all(not is_tgparam(a) or a.activated for a in list(args) + list(kwargs.values())): + return base(args, kwargs) + with activate_tg_params(*args, *kwargs.values()): + out = base(tree_map(unpack_tg_param, args), tree_map(unpack_tg_param, kwargs)) + if not isinstance(out, Tensor): + return out + return wrap(base, out, args, kwargs) + + +# TrueGradTensor + +def unpack_tg_tensor(x: Any) -> Any: if isinstance(x, TrueGradTensor): return x.data return x -_base_torch_function = torch.Tensor.__torch_function__ +_tensor_function = Tensor.__torch_function__ -class TrueGradTensor(torch.Tensor): - sum_grad_squared: torch.Tensor - data: torch.Tensor +class TrueGradTensor(Tensor): + sum_grad_squared: Tensor + data: Tensor requires_grad: bool __slots__ = ['sum_grad_squared', "data", "requires_grad"] @staticmethod - def __new__(cls, data: torch.Tensor): + def __new__(cls, data: Tensor): meta = data.new_empty((0,)) meta.set_(meta.storage(), 0, data.size(), data.stride()) - r = torch.Tensor._make_subclass(cls, meta, data.requires_grad) + r = Tensor._make_subclass(cls, meta, data.requires_grad) r.data = data r.sum_grad_squared = None r.activated = False @@ -39,222 +105,253 @@ def __repr__(self): def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - out = _base_torch_function(func, [], tree_map(_unpack, args), tree_map(_unpack, kwargs)) - return out + return _tensor_function(func, [], tree_map(unpack_tg_tensor, args), tree_map(unpack_tg_tensor, kwargs)) +# utils + +def valid_attr(wgt: nn.Parameter, attr: str = "sum_grad_squared"): + return hasattr(wgt, attr) and getattr(wgt, attr) is not None + + +def add_or_set(wgt: nn.Parameter, new: torch.Tensor, attr: str = "sum_grad_squared"): + if hasattr(wgt, attr) and getattr(wgt, attr) is not None: + new = getattr(wgt, attr) + new + setattr(wgt, attr, new) + + +# Autograd Functions + class MulFn(torch.autograd.Function): @staticmethod - def forward(ctx, inp: torch.Tensor, weight: torch.Tensor): - if weight.requires_grad: - ctx.save_for_backward(inp) - ctx.weight = weight - return inp * weight + def forward(ctx, inp: Tensor, weight: Tensor): + with activate_tg_params(inp, weight): + if weight.requires_grad: + ctx.save_for_backward(inp) + ctx.weight = weight + return inp * weight @staticmethod - def backward(ctx, dy: torch.Tensor): + def backward(ctx, dy: Tensor): if not ctx.saved_tensors: return None, None inp, = ctx.saved_tensors weight = 
ctx.weight - diff = inp.ndim - weight.ndim - summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1] - weight_grad = dy * inp - weight.sum_grad_squared = weight_grad.square() - if summed: - weight_grad = weight_grad.sum(summed) - weight.sum_grad_squared = weight.sum_grad_squared.sum(summed) - weight.sum_grad_squared = weight.sum_grad_squared.reshape(weight.size()) * dy.size(0) + with activate_tg_params(inp, weight): + diff = inp.ndim - weight.ndim + summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1] + weight_grad = dy * inp + sum_grad_squared = weight_grad.square() + if summed: + weight_grad = weight_grad.sum(summed) + sum_grad_squared = sum_grad_squared.sum(summed) + add_or_set(weight, sum_grad_squared.reshape(weight.size()) * dy.size(0)) return dy * weight, weight_grad.reshape(weight.size()) class AddFn(torch.autograd.Function): @staticmethod - def forward(ctx, inp: torch.Tensor, weight: torch.Tensor): - if weight.requires_grad: - diff = inp.ndim - weight.ndim - ctx.summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1] - ctx.batch_size = inp.size(0) - ctx.weight = weight + def forward(ctx, inp: Tensor, weight: Tensor): + with activate_tg_params(inp, weight): + if weight.requires_grad: + diff = inp.ndim - weight.ndim + ctx.summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1] + ctx.batch_size = inp.size(0) + ctx.weight = weight return inp + weight @staticmethod - def backward(ctx, dy: torch.Tensor): + def backward(ctx, dy: Tensor): if not hasattr(ctx, "weight"): return None, None weight = ctx.weight - weight_grad = dy - weight.sum_grad_squared = dy.square() - if ctx.summed: - weight_grad = weight_grad.sum(ctx.summed) - weight.sum_grad_squared = weight.sum_grad_squared.sum(ctx.summed) - weight.sum_grad_squared = weight.sum_grad_squared.reshape(weight.size()) * dy.size(0) + with activate_tg_params(weight): + weight_grad = dy + sum_grad_squared = dy.square() + if ctx.summed: + weight_grad = weight_grad.sum(ctx.summed) + sum_grad_squared = sum_grad_squared.sum(ctx.summed) + add_or_set(weight, sum_grad_squared.reshape(weight.size()) * dy.size(0)) return dy, weight_grad.reshape(weight.size()) class EinsumFn(torch.autograd.Function): @staticmethod - def forward(ctx, spec: str, inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: - if weight.requires_grad: - ctx.save_for_backward(inp) - ctx.weight = weight - ctx.spec = spec - return torch.einsum(spec, inp, weight) + def forward(ctx, spec: str, inp: Tensor, weight: Tensor) -> Tensor: + with activate_tg_params(inp, weight): + if weight.requires_grad: + ctx.save_for_backward(inp, weight) + ctx.spec = spec + return torch.clone(torch.einsum(spec, inp, weight)) @staticmethod - def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor, torch.Tensor]: + def backward(ctx, dy: Tensor) -> Tuple[None, Tensor, Tensor]: if not ctx.saved_tensors: return None, None, None - inp, = ctx.saved_tensors - wgt = ctx.weight - inputs, output = ctx.spec.split('->') - lhs, rhs = inputs.split(',') + inp, wgt = ctx.saved_tensors + with activate_tg_params(inp, wgt): + inputs, output = ctx.spec.split('->') + lhs, rhs = inputs.split(',') + + d_wgt = torch.einsum(f'{lhs},{output}->{rhs}', inp, dy).contiguous() + add_or_set(wgt, torch.einsum(f'{lhs},{output}->{rhs}', inp.square(), dy.square()).contiguous()) + d_inp = torch.einsum(f"{rhs},{output}->{lhs}", wgt, dy).contiguous() - d_wgt = torch.einsum(f'{lhs},{output}->{rhs}', inp, 
dy) - wgt.sum_grad_squared = torch.einsum(f'{lhs},{output}->{rhs}', inp.square(), dy.square() * inp.size(0)) - d_inp = torch.einsum(f"{rhs},{output}->{lhs}", wgt, dy) return None, d_inp, d_wgt class GatherFn(torch.autograd.Function): @staticmethod - def forward(ctx, inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: - if weight.requires_grad: - ctx.save_for_backward(inp) - ctx.weight = weight + def forward(ctx, inp: Tensor, weight: Tensor) -> Tensor: + with activate_tg_params(inp, weight): + if weight.requires_grad: + ctx.save_for_backward(inp) + ctx.weight = weight return torch.gather(weight, 0, inp) @staticmethod - def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]: + def backward(ctx, dy: Tensor) -> Tuple[None, Tensor]: if not ctx.saved_tensors: return None, None inp, = ctx.saved_tensors wgt = ctx.weight - wgt_grad = torch.zeros_like(wgt) - wgt.sum_grad_squared = wgt_grad.scatter_add(0, inp, dy.square()) - wgt_grad.scatter_add_(0, inp, dy) + with activate_tg_params(inp, wgt): + wgt_grad = torch.zeros_like(wgt) + add_or_set(wgt, wgt_grad.scatter_add(0, inp, dy.square())) + wgt_grad.scatter_add_(0, inp, dy) return None, wgt_grad class ReshapeFn(torch.autograd.Function): @staticmethod - def forward(ctx, weight: torch.Tensor, new_shape: List[int]) -> torch.Tensor: - out = TrueGradTensor(weight.reshape(new_shape).detach().requires_grad_(True)) - if weight.requires_grad: - ctx.save_for_backward(weight) - ctx.out = out - ctx.original_shape = weight.size() - return out + def forward(ctx, weight: Tensor, new_shape: List[int]) -> Tensor: + with activate_tg_params(weight): + out = TrueGradTensor(weight.reshape(new_shape).detach().requires_grad_(True)) + if weight.requires_grad: + ctx.save_for_backward(weight, out) + return out @staticmethod - def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]: + def backward(ctx, dy: Tensor) -> Tuple[None, Tensor]: if not ctx.saved_tensors: return None, None - wgt, = ctx.saved_tensors - if ctx.out.sum_grad_squared is not None: - wgt.sum_grad_squared = ctx.out.sum_grad_squared.reshape(ctx.original_shape) - return dy.reshape(ctx.original_shape), None + wgt, out = ctx.saved_tensors + with activate_tg_params(wgt): + if valid_attr(out): + add_or_set(wgt, ctx.out.sum_grad_squared) + return dy.reshape(wgt.size()), None class TransposeFn(torch.autograd.Function): @staticmethod - def forward(ctx, weight: torch.Tensor, dims: typing.List[int]) -> torch.Tensor: - out = TrueGradTensor(weight.transpose(*dims).detach().requires_grad_(True)) - if weight.requires_grad: - ctx.save_for_backward(weight) - ctx.out = out - ctx.dims = dims + def forward(ctx, weight: Tensor, dims: typing.List[int]) -> Tensor: + with activate_tg_params(weight): + out = TrueGradTensor(weight.transpose(*dims).detach().requires_grad_(True)) + if weight.requires_grad: + ctx.save_for_backward(weight, out) + ctx.dims = dims return out @staticmethod - def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]: + def backward(ctx, dy: Tensor) -> Tuple[None, Tensor]: if not ctx.saved_tensors: return None, None - wgt, = ctx.saved_tensors - if ctx.out.sum_grad_squared is not None: - wgt.sum_grad_squared = ctx.out.sum_grad_squared.transpose(*ctx.dims) + wgt, out = ctx.saved_tensors + with activate_tg_params(wgt): + if valid_attr(out): + add_or_set(wgt, out.sum_grad_squared.transpose(*ctx.dims())) return dy.transpose(*ctx.dims), None class ChunkFn(torch.autograd.Function): @staticmethod - def forward(ctx, weight: torch.Tensor, chunks: int, dim: int): - out = 
tuple(TrueGradTensor(c) for c in weight.chunk(chunks, dim)) - if weight.requires_grad: - ctx.save_for_backward(weight) - ctx.out = out - ctx.dim = dim + def forward(ctx, weight: Tensor, chunks: int, dim: int): + with activate_tg_params(weight): + out = tuple(TrueGradTensor(c) for c in weight.chunk(chunks, dim)) + if weight.requires_grad: + ctx.save_for_backward(weight, out) + ctx.dim = dim return out @staticmethod - def backward(ctx, *dy: torch.Tensor): + def backward(ctx, *dy: Tensor): if not ctx.saved_tensors: return None, None, None - wgt, = ctx.saved_tensors - wgt.sum_grad_squared = torch.cat([o.sum_grad_squared for o in ctx.out], dim=ctx.dim) + wgt, out = ctx.saved_tensors + with activate_tg_params(wgt): + if all(valid_attr(o) for o in out): + add_or_set(wgt, torch.cat([o.sum_grad_squared for o in out], dim=ctx.dim)) return torch.cat(dy, dim=ctx.dim), None, None class SplitFn(torch.autograd.Function): @staticmethod - def forward(ctx, weight: torch.Tensor, split_size: int, dim: int): - out = tuple(TrueGradTensor(c) for c in weight.split(split_size, dim)) - if weight.requires_grad: - ctx.save_for_backward(weight) - ctx.out = out - ctx.dim = dim + def forward(ctx, weight: Tensor, split_size: int, dim: int): + with activate_tg_params(weight): + out = tuple(TrueGradTensor(c) for c in weight.split(split_size, dim)) + if weight.requires_grad: + ctx.save_for_backward(weight, out) + ctx.dim = dim return out @staticmethod - def backward(ctx, *dy: torch.Tensor): + def backward(ctx, *dy: Tensor): if not ctx.saved_tensors: return None, None, None - wgt, = ctx.saved_tensors - wgt.sum_grad_squared = torch.cat([o.sum_grad_squared for o in ctx.out], dim=ctx.dim) + wgt, out = ctx.saved_tensors + with activate_tg_params(wgt): + if all(valid_attr(o) for o in out): + add_or_set(wgt, torch.cat([o.sum_grad_squared for o in out], dim=ctx.dim)) return torch.cat(dy, dim=ctx.dim), None, None class ExpandFn(torch.autograd.Function): @staticmethod - def forward(ctx, weight: torch.Tensor, new_shape: List[int]) -> torch.Tensor: - out = TrueGradTensor(weight.expand(new_shape)) - if weight.requires_grad: - ctx.save_for_backward(weight) - ctx.out = out - ctx.summed = [i for i, d in enumerate(new_shape) if d != -1] + def forward(ctx, weight: Tensor, new_shape: List[int]) -> Tensor: + with activate_tg_params(weight): + out = TrueGradTensor(weight.expand(new_shape)) + if weight.requires_grad: + ctx.save_for_backward(weight, out) + ctx.summed = [i for i, d in enumerate(new_shape) if d != -1] return out @staticmethod - def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]: + def backward(ctx, dy: Tensor) -> Tuple[None, Tensor]: if not ctx.saved_tensors: return None, None - wgt, = ctx.saved_tensors - if ctx.out.sum_grad_squared is not None and ctx.summed: - wgt.sum_grad_squared = ctx.out.sum_grad_squared.sum(ctx.summed) - return dy.sum(ctx.summed) + wgt, out = ctx.saved_tensors + with activate_tg_params(wgt): + if valid_attr(out): + sum_grad_squared = out.sum_grad_squared + if ctx.summed: + sum_grad_squared = sum_grad_squared.sum(ctx.summed) + add_or_set(wgt, sum_grad_squared) + if ctx.summed: + return dy.sum(ctx.summed) + return dy class WrapFn(torch.autograd.Function): @staticmethod - def forward(ctx, fn, args, kwargs) -> torch.Tensor: + def forward(ctx, fn, out, args, kwargs) -> Tensor: ctx.fn = fn ctx.args = args ctx.kwargs = kwargs - return fn(*args, **kwargs) + return out @staticmethod - def backward(ctx, dy: torch.Tensor) -> Tuple[None, None, None, None]: - def _backward(fn: Callable[[torch.Tensor], 
torch.Tensor], attr: str): - def _fn(x: torch.Tensor): - if isinstance(x, torch.nn.Parameter): - x = x.data - if not isinstance(x, torch.Tensor) or not torch.is_floating_point(x): - return x - x = fn(x.detach()) - x.requires_grad_(True) + def backward(ctx, dy: Tensor) -> Tuple[None, Tensor, None, None]: + def _fn(x: Tensor): + if isinstance(x, nn.Parameter): + x = x.data + if not isinstance(x, Tensor) or not torch.is_floating_point(x): return x + x = torch.square(x.detach()).detach() + x.requires_grad_(True) + return x + with activate_tg_params(*ctx.args, *ctx.kwargs.values()): args = tree_map(_fn, ctx.args) kwargs = tree_map(_fn, ctx.kwargs) @@ -263,17 +360,46 @@ def _fn(x: torch.Tensor): torch.autograd.backward(out, tree_map(_fn, dy)) for p, a in zip(list(ctx.args) + list(ctx.kwargs.values()), list(args) + list(kwargs.values())): - if not isinstance(p, torch.nn.Parameter): + if not hasattr(a, "grad"): continue - if hasattr(p, attr) and getattr(p, attr) is not None: - a.grad = getattr(p, attr) + a.grad - setattr(p, attr, a.grad) + add_or_set(p, a.grad.contiguous()) - _backward(torch.square, "sum_grad_squared") - _backward(lambda x: x, "grad") + return None, dy, None, None - return None, None, None, None +class ConvNdFn(torch.autograd.Function): + @staticmethod + def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor], args) -> Tensor: + with activate_tg_params(input, weight, bias): + if weight.requires_grad: + ctx.save_for_backward(input, weight, bias) + ctx.args = args + dim = input.dim() - 2 # Batch, Feature, *Data + return getattr(F, f"conv{dim}d")(input, weight, bias, *args).contiguous() + + @staticmethod + def backward(ctx, dy: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor], None]: + if not ctx.saved_tensors: + return None, None, None, None + inp, wgt, bias = ctx.saved_tensors + with activate_tg_params(inp, wgt, bias): + dim = inp.dim() - 2 + summed = [0] + list(range(2, 2 + dim)) + + dx = getattr(grad, f"conv{dim}d_input")(inp.size(), wgt, dy, *ctx.args) + dw = getattr(grad, f"conv{dim}d_weight")(inp, wgt.size(), dy, *ctx.args) + db = None if bias is None else dy.sum(summed) + + if isinstance(wgt, nn.Parameter) or isinstance(bias, nn.Parameter): + dy_sq = dy.square() * dy.size(0) + if isinstance(wgt, nn.Parameter): + wgt.sum_grad_squared = getattr(grad, f"conv{dim}d_weight")(inp.square(), wgt.size(), dy_sq, *ctx.args) + if isinstance(bias, nn.Parameter): + bias.sum_grad_squared = dy_sq.sum(summed) + return dx, dw, db, None + + +# "Normal" Functions mul = MulFn.apply add = AddFn.apply @@ -285,8 +411,17 @@ def _fn(x: torch.Tensor): split = SplitFn.apply expand = ExpandFn.apply wrap = WrapFn.apply +convnd = ConvNdFn.apply -def matmul(inp: torch.Tensor, wgt: torch.Tensor): +def matmul(inp: Tensor, wgt: Tensor): batch_dims = ''.join(chr(ord('a') + i) for i in range(inp.ndim - 1)) return einsum(f"{batch_dims}y,yz->{batch_dims}z", inp, wgt) + + +def simple_wrap(function: Callable, out: torch.Tensor, *args, **kwargs): + def _fn(a, k): + return function(*a, **k) + + _fn.abc = function.__name__ + return wrap(_fn, out, args, kwargs) diff --git a/truegrad/nn/__init__.py b/truegrad/nn/__init__.py index 39f9e50..42fb272 100644 --- a/truegrad/nn/__init__.py +++ b/truegrad/nn/__init__.py @@ -1,12 +1,13 @@ -from typing import Any, List +from typing import List import torch import torch.nn as nn -from torch.utils._pytree import tree_map -from truegrad.functional import add, gather, mul, wrap +from truegrad.functional import TrueGradParameter, add, gather, is_tgparam, mul from 
truegrad.nn import functional +TrueGradParameter = TrueGradParameter +is_tgparam = is_tgparam F = functional @@ -114,14 +115,14 @@ def _apply_instance_norm(self, input): class _LayerNorm(nn.Module): - def __init__(self, dims: List[int], eps: float, broadcast: bool): + def __init__(self, normalized_shape: List[int], eps: float, broadcast: bool): super(_LayerNorm, self).__init__() - self.dims = dims + self.normalized_shape = normalized_shape self.eps = eps self.broadcast = broadcast def forward(self, x): - return F.layer_norm(x, self.dims, eps=self.eps, broadcast=self.broadcast) + return F.layer_norm(x, self.normalized_shape, eps=self.eps, broadcast=self.broadcast) class LayerNorm(Normalization): @@ -129,8 +130,9 @@ def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device= broadcast: bool = False): if device is not None or dtype is not None: raise ValueError("device and dtype are not supported. Ensure both are set to None.") - super(LayerNorm, self).__init__(_LayerNorm([-i - 1 for i, dim in enumerate(normalized_shape) if dim != 1], eps, - broadcast), + if isinstance(normalized_shape, int): + normalized_shape = [normalized_shape] + super(LayerNorm, self).__init__(_LayerNorm(normalized_shape, eps, broadcast), normalized_shape, elementwise_affine) @@ -159,15 +161,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return F.linear(x, self.weight, self.bias) -class Embedding(nn.Module): - def __init__(self, num_embeddings, embedding_dim, **kwargs): - if kwargs: - raise ValueError(f"{kwargs} are not supported.") - super(Embedding, self).__init__() - self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim)) - +class Embedding(nn.Embedding): def forward(self, input: torch.Tensor) -> torch.Tensor: - return gather(input, self.weight) + return F.embedding(input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, + self.sparse) class Conv1d(nn.Conv1d): @@ -196,53 +193,3 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: modules = (Embedding, Linear, LayerNorm, LayerNorm1d, LayerNorm2d, LayerNorm3d, InstanceNorm1d, InstanceNorm2d, InstanceNorm3d, BatchNorm1d, BatchNorm2d, BatchNorm3d) - - -def is_tgparam(param: nn.Parameter): - if isinstance(param, TrueGradParameter): - return True - if isinstance(param, nn.Parameter) and hasattr(param, "activated"): - return True - return False - - -_base_torch_function = nn.Parameter.__torch_function__ - - -class TrueGradParameter(nn.Parameter): - activated: bool - - @staticmethod - def __new__(cls, data=None, requires_grad=True): - if data is None: - data = torch.zeros(()) - out = torch.nn.Parameter._make_subclass(cls, data, requires_grad) - out.activated = False - return out - - def __repr__(self): - return f"TrueGradParameter({self.data})" - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - - def base(a, k): - return _base_torch_function(func, types, a, k) - - if all(not is_tgparam(a) or a.activated for a in list(args) + list(kwargs.values())): - return base(args, kwargs) - out = base(tree_map(_unpack, args), tree_map(_unpack, kwargs)) - for a in list(args) + list(kwargs.values()): - if is_tgparam(a): - a.activated = False - if not isinstance(out, torch.Tensor): - return out - return wrap(base, args, kwargs) - - -def _unpack(x: Any) -> Any: - if is_tgparam(x) and not x.activated: - x.activated = True - return x diff --git a/truegrad/nn/functional.py b/truegrad/nn/functional.py index 1fbf014..2e78877 100644 --- 
a/truegrad/nn/functional.py +++ b/truegrad/nn/functional.py @@ -1,43 +1,20 @@ -import functools +import torch.autograd + import math import typing import warnings from typing import Callable, List, Optional, Tuple, Union +import torch.autograd import torch.autograd from torch import Tensor, nn -from torch.nn import functional as F, grad - -from truegrad.functional import add, chunk, einsum, matmul, mul, reshape, split, transpose - -_torch_functional = {k: getattr(F, k) for k in dir(F)} -_torch = {k: getattr(torch, k) for k in dir(torch)} -_inside_call = {} - +from torch.nn import functional as F -def call_torch(fn: Callable, name: Optional[str] = None): - if name is None: - name = fn.__name__ - _inside_call[fn] = 0 - - def _fn(*args, **kwargs): - _inside_call[fn] += 1 - if _inside_call[fn] == 1: - out = fn(*args, **kwargs) - elif _inside_call[fn] == 2: - out = _torch_functional[name](*args, **kwargs) - elif _inside_call[fn] == 3: - out = _torch[name](*args, **kwargs) - else: - raise ValueError - _inside_call[fn] -= 1 - return out - - return _fn +from truegrad.functional import (add, chunk, convnd, einsum, mul, reshape, simple_wrap, + split) def no_parameter(fn: Callable): - @functools.partial(call_torch, name=fn.__name__) def _fn(*args, **kwargs): for i, arg in enumerate(args): if isinstance(arg, nn.Parameter): @@ -208,7 +185,6 @@ def alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace return F.alpha_dropout(input, p, training, inplace) -@call_torch def batch_norm(input: Tensor, running_mean: typing.Optional[Tensor], running_var: typing.Optional[Tensor], weight: typing.Optional[Tensor] = None, bias: typing.Optional[Tensor] = None, training: bool = False, momentum: float = 0.1, @@ -221,7 +197,6 @@ def batch_norm(input: Tensor, running_mean: typing.Optional[Tensor], return input -@call_torch def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: typing.Optional[Tensor] = None): batch_dims = ''.join(chr(ord('a') + i) for i in range(input1.ndim - 1)) x = einsum(f'{batch_dims}x,{batch_dims}y,zxy->{batch_dims}z', input1, input2, weight) @@ -260,58 +235,24 @@ def channel_shuffle(input: Tensor, groups: int): return F.channel_shuffle(input, groups) -class _ConvNdFn(torch.autograd.Function): - @staticmethod - def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor], args) -> torch.Tensor: - if weight.requires_grad: - ctx.save_for_backward(input, weight, bias) - ctx.args = args - dim = input.dim() - 2 # Batch, Feature, *Data - return getattr(F, f"conv{dim}d")(input, weight, bias, *args) - - @staticmethod - def backward(ctx, dy: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], None]: - if not ctx.saved_tensors: - return None, None, None, None - inp, wgt, bias = ctx.saved_tensors - dim = inp.dim() - 2 - summed = [0] + list(range(2, 2 + dim)) - - dx = getattr(grad, f"conv{dim}d_input")(inp.size(), wgt, dy, *ctx.args) - dw = getattr(grad, f"conv{dim}d_weight")(inp, wgt.size(), dy, *ctx.args) - db = None if bias is None else dy.sum(summed) - - if isinstance(wgt, nn.Parameter) or isinstance(bias, nn.Parameter): - dy_sq = dy.square() * dy.size(0) - if isinstance(wgt, nn.Parameter): - wgt.sum_grad_squared = getattr(grad, f"conv{dim}d_weight")(inp.square(), wgt.size(), dy_sq, *ctx.args) - if isinstance(bias, nn.Parameter): - bias.sum_grad_squared = dy_sq.sum(summed) - return dx, dw, db, None - - -@call_torch def _convnd(input: Tensor, weight: Tensor, bias: Optional[Tensor], dim: int, *args): if input.dim() != dim + 2: raise 
ValueError(f"Input has {input.dim()} dimensions, but expected {dim + 2} dimensions for conv{dim}d.") - return _ConvNdFn.apply(input, weight, bias, args) + return convnd(input, weight, bias, args) -@call_torch def conv1d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: int = 1, padding: Union[str, int] = "valid", dilation: int = 1, groups: int = 1): return _convnd(input, weight, bias, 1, stride, padding, dilation, groups) -@call_torch def conv2d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: int = 1, padding: Union[str, int] = "valid", dilation: int = 1, groups: int = 1): return _convnd(input, weight, bias, 2, stride, padding, dilation, groups) -@call_torch def conv3d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: int = 1, padding: Union[str, int] = "valid", dilation: int = 1, groups: int = 1): @@ -391,11 +332,12 @@ def elu_(input: Tensor, alpha: float = 1.0): return F.elu_(input, alpha) -@no_parameter def embedding(input: Tensor, weight: Tensor, padding_idx: typing.Optional[int] = None, max_norm: typing.Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, sparse: bool = False): - return F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse) + args = [input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse] + out = F.embedding(*args) + return simple_wrap(F.embedding, out, *args) @no_parameter @@ -440,7 +382,6 @@ def grid_sample(input: Tensor, grid: Tensor, mode: str = bilinear, padding_mode: return F.grid_sample(input, grid, mode, padding_mode, align_corners) -@call_torch def group_norm(input: Tensor, num_groups: int, weight: typing.Optional[Tensor] = None, bias: typing.Optional[Tensor] = None, eps: float = 1e-05): x = F.group_norm(input, num_groups, None, None, eps) @@ -493,7 +434,6 @@ def huber_loss(input: Tensor, target: Tensor, reduction: str = "mean", delta: fl return F.huber_loss(input, target, reduction, delta) -@call_torch def instance_norm(input: Tensor, running_mean: typing.Optional[Tensor] = None, running_var: typing.Optional[Tensor] = None, weight: typing.Optional[Tensor] = None, bias: typing.Optional[Tensor] = None, use_input_stats: bool = True, momentum: float = 0.1, @@ -526,7 +466,6 @@ def l1_loss(input: Tensor, target: Tensor, size_average: typing.Optional[bool] = return F.l1_loss(input, target, size_average, reduce, reduction) -@call_torch def layer_norm(input: Tensor, normalized_shape: typing.List[int], weight: typing.Optional[Tensor] = None, bias: typing.Optional[Tensor] = None, eps: float = 1e-05, broadcast: bool = True): if broadcast: @@ -549,9 +488,9 @@ def leaky_relu_(input: Tensor, negative_slope: float = 0.01): return F.leaky_relu_(input, negative_slope) -@call_torch def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]): - input = matmul(input, transpose(weight, (0, 1))) + batch_dims = ''.join(chr(ord('a') + i) for i in range(input.ndim - 1)) + input = einsum(f"{batch_dims}y,zy->{batch_dims}z", input, weight) if bias is None: return input return add(input, bias) diff --git a/truegrad/optim.py b/truegrad/optim.py index 28dffc0..b514289 100644 --- a/truegrad/optim.py +++ b/truegrad/optim.py @@ -80,7 +80,7 @@ def step(self, closure=None): adam_update = exp_avg / (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(group['eps']) if group["graft"] and not do_adam: - alpha = alpha * adam_update.norm() / update.norm() + alpha = alpha * adam_update.norm() / update.norm().add_(group['eps']) elif do_adam: update = 
adam_update diff --git a/truegrad/utils.py b/truegrad/utils.py index 0466cda..48c25d8 100644 --- a/truegrad/utils.py +++ b/truegrad/utils.py @@ -1,4 +1,8 @@ +import collections +import typing + import torch +from torch import overrides import truegrad from truegrad.nn import TrueGradParameter @@ -16,6 +20,26 @@ def _apply_fn(module: torch.nn.Module): _apply_fn(mod) +def from_x(name: str, fn: typing.Callable, module): + calls = [0] + original = getattr(module, name) + + def _fn(*args, **kwargs): + calls[0] += 1 + if calls[0] == 1: + try: + return fn(*args, **kwargs) + except: + return original(*args, **kwargs) + finally: + calls[0] -= 1 + out = original(*args, **kwargs) + calls[0] -= 1 + return out + + return _fn + + def _patch(tg, th): tg_dir = dir(tg) for name in dir(th): @@ -26,10 +50,11 @@ def _patch(tg, th): continue if item.__module__ != tg.__name__: continue - setattr(th, name, item) + setattr(th, name, from_x(name, item, th)) def patch_torch(): _patch(truegrad.nn.functional, torch.nn.functional) _patch(truegrad.nn.functional, torch) _patch(truegrad.nn, torch.nn) + overrides.has_torch_function_variadic = lambda *x: False
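A note on the `truegrad/utils.py` change above: `from_x` wraps every patched attribute with a call-depth guard, so the outermost call runs the truegrad implementation (falling back to the original torch function if it raises), while nested calls (for example, when the truegrad implementation itself invokes the patched torch API) go straight to the original, avoiding infinite recursion. A standalone sketch of the same pattern (names and the toy usage are illustrative, not the patch's exact code):

```PYTHON
import typing


def reentrant_fallback(patched: typing.Callable, original: typing.Callable) -> typing.Callable:
    depth = [0]  # mutable cell so the closure can track call depth

    def _fn(*args, **kwargs):
        depth[0] += 1
        try:
            if depth[0] == 1:
                try:
                    return patched(*args, **kwargs)   # outermost call: try the patched version
                except Exception:
                    return original(*args, **kwargs)  # fall back if the patched path fails
            return original(*args, **kwargs)          # nested call: use the original, no recursion
        finally:
            depth[0] -= 1

    return _fn


# toy usage: the patched path runs for the outermost call only
doubled = reentrant_fallback(lambda x: x * 2, lambda x: x)
print(doubled(3))  # 6
```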