diff --git a/README.md b/README.md
index 852d48d..2c1ce20 100644
--- a/README.md
+++ b/README.md
@@ -12,13 +12,62 @@ python3 -m pip install truegrad
 
 ## Examples
 
+### BackPack
+
+The preferred method to integrate TrueGrad is using [BackPack](https://github.com/f-dangel/backpack). BackPack is a
+third-party library that automatically computes the sum of squared gradients and works for most models by implementing
+custom backward rules for many `torch.nn.Module`'s.
+
+```PYTHON
+import backpack
+import torch
+from torch.nn import CrossEntropyLoss
+from truegrad.optim import TGAdamW
+from torchvision.models import alexnet
+
+model = alexnet()
+optim = TGAdamW(model.parameters(), lr=1e-7, weight_decay=0)
+
+# backpack can't handle inplace ops like nn.ReLU(inplace=True) and `x += y`
+for mod in model.modules():
+    if hasattr(mod, "inplace"):
+        mod.inplace = False
+
+# backpack relies on module-level pytorch hooks
+model = backpack.extend(model)
+lossfunc = backpack.extend(CrossEntropyLoss())
+
+# constant input/output to overfit
+inp = torch.randn((2, 3, 224, 224))
+tgt = torch.randint(0, 1000, (2,))
+
+# standard training loop
+i = 0
+while True:
+    # "SumGradSquared" computes the sum of the squared gradients
+    with backpack.backpack(backpack.extensions.SumGradSquared()):
+        loss = lossfunc(model(inp), tgt)
+        loss.backward()
+    optim.step()
+    i += 1
+    if i % 5 == 0:
+        print(i, loss.item())
+```
+
+If you're using custom modules with self-defined parameters, this method will not work. Additionally, note that if
+your model has any layer called `.output` or you're using PyTorch >= 1.13, you will need to install
+[BackPack-HF](https://github.com/ClashLuke/backpack-hf) via
+`python3 -m pip install git+https://github.com/ClashLuke/backpack-hf`.
+
 ### Patch Custom Models
 
-The easiest way to integrate TrueGrad into existing models is to patch them using `truegrad.utils.patch_model()`.
+Another option to integrate TrueGrad into existing models is to patch them using `truegrad.utils.patch_model()`.
 `patch_model()` will go through all `torch.nn.Module`'s in a PyTorch model and convert their `torch.nn.Parameter`'s to
 `truegrad.nn.TrueGradParameter`'s. A `TrueGradParameter` acts largely the same as a `torch.nn.Parameter`, but adds
 required operations into the model's backward pass.\
-Patching an existing
+Importantly, be aware that this does not work for fused functions, such as `torch.nn.LayerNorm`
+and `torch.nn.MultiheadAttention`. However, unfused functions which directly access a parameter, such as
+multiplication, work well. Therefore, `torch.nn.Linear` and HuggingFace's attention work as expected.
 
 ```PYTHON
 import transformers
@@ -40,11 +89,13 @@ for sample in ["Hello", "World", "!"]:
 
 ### nn
 
-Patching existing PyTorch computation graphs on the fly might add unnecessary memory and computation. That's why a
-pre-patched alternative of `torch.nn` with hand-crafted gradients exists alongside the `truegrad.utils` module. Compared
-to `truegrad.utils.patch_model()`, `truegrad.nn` offers higher speeds and lower memory usage, although it might require
-code alterations and doesn't support all models. You cannot (currently) use `truegrad.nn` with `truegrad.utils`, as both
-use different ways to arrive at the same value.
+Patching existing PyTorch computation graphs on the fly might add unnecessary memory and computation or even fail
+unexpectedly. That's why a pre-patched alternative of `torch.nn` with hand-crafted gradients exists alongside the
+`truegrad.utils` module.
+Compared to `truegrad.utils.patch_model()`, `truegrad.nn` offers higher speeds and lower
+memory usage, although it might require code alterations and doesn't support all models. You cannot (currently) use
+`truegrad.nn` with `truegrad.utils`, as both use different ways to arrive at the same value. However, you can
+combine `torch.nn.Module`'s and `truegrad.nn` modules and use the TrueGrad information only where it is available
+(see [Partial TrueGrad](#Partial-TrueGrad)).
 
 ```PYTHON
 import torch
@@ -58,7 +109,7 @@ model = torch.nn.Sequential(nn.Linear(1, 10), nn.Linear(10, 1))
 optim = TGAdamW(model.parameters())  # truegrad.optim.TGAdamW instead of torch.optim.AdamW
 
-# training loop as normal
+# standard training loop
 while True:
     input = torch.randn((16, 1))
     model(input).mean().backward()
@@ -67,10 +118,11 @@ while True:
 
 ### Partial TrueGrad
 
-Unfortunately, it's not always sensible to apply TrueGrad, as some backward passes are too slow to do them twice.
-Therefore, it can be an option to use TGAdamW only on specific subsections of the model. To do so, you can either check
-which parameters are of type `truegrad.nn.TrueGradParameter` when using `truegrad.utils.patch_model()` or which
-parameters belong to a module listed in `truegrad.nn.modules`.
+Unfortunately, it's not always sensible to apply TrueGrad, as some backward passes are too slow, and sometimes it's
+impossible to avoid a fused function.
+Therefore, it can be an option to use TGAdamW only on specific subsections of the model. To do so, you can
+pass `default_to_adam=True` to TGAdamW. Adding this option allows TGAdamW to fall back to AdamW if there is
+no `sum_grad_squared` attribute available.
 
 For example, the code from [#nn](#nn) could be extended in the following way:
 
 ```PYTHON
@@ -83,26 +135,11 @@ model = torch.nn.Sequential(nn.Linear(1, 10),  # Weights coming from truegrad.nn
                             torch.nn.ReLU(),
                             torch.nn.Linear(10, 1))  # Weights coming from torch.nn
 
-truegrad_parameters = []
-normal_parameters = []
-
-
-def get_parameters(mod: torch.nn.Module):
-    if isinstance(mod, nn.modules):
-        truegrad_parameters.extend(list(mod.parameters(recurse=False)))
-    else:
-        # you could do truegrad.utils.patch_model(mod, recurse=False) here!
-        normal_parameters.extend(list(mod.parameters(recurse=False)))
-
-
-model = model.apply(get_parameters)
-
-optim0 = TGAdamW(truegrad_parameters)
-optim1 = torch.optim.AdamW(normal_parameters)
+optim = TGAdamW(model.parameters(), default_to_adam=True)
 
+# standard training loop
 while True:
     input = torch.randn((16, 1))
     model(input).mean().backward()
-    optim0.step()  # update both parameter sets separately
-    optim1.step()
+    optim.step()
 ```
\ No newline at end of file
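As a quick illustration of the Partial TrueGrad behaviour the README section above describes (a sketch, not part of the patch; it assumes the mixed `Sequential` from that example and the `from truegrad import nn` import its code presumably uses): parameters touched by TrueGrad's backward rules carry a `sum_grad_squared` attribute after `backward()`, while plain `torch.nn` parameters only have `.grad` and therefore hit the AdamW fallback when `default_to_adam=True`.

```PYTHON
import torch
from truegrad import nn
from truegrad.optim import TGAdamW

# mixed model: the first layer comes from truegrad.nn, the rest from torch.nn
model = torch.nn.Sequential(nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1))
optim = TGAdamW(model.parameters(), default_to_adam=True)

model(torch.randn((16, 1))).mean().backward()

# parameters written by truegrad's backward rules now carry `sum_grad_squared`;
# the others only have `.grad`, so TGAdamW will use its AdamW fallback for them
for name, param in model.named_parameters():
    has_stat = getattr(param, "sum_grad_squared", None) is not None
    print(f"{name}: {'TrueGrad statistics' if has_stat else 'AdamW fallback'}")
```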
diff --git a/setup.py b/setup.py
index 86d44d1..33c76d5 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
     name='truegrad',
     license='BSD',
     description='PyTorch interface for TrueGrad-AdamW',
-    version='0.0.9',
+    version='0.1.0',
     long_description=README,
     url='https://github.com/clashluke/truegrad',
     packages=setuptools.find_packages(),
diff --git a/truegrad/functional.py b/truegrad/functional.py
index 8e4fd72..7896c36 100644
--- a/truegrad/functional.py
+++ b/truegrad/functional.py
@@ -18,11 +18,11 @@ def backward(ctx, dy: torch.Tensor):
         diff = inp.ndim - weight.ndim
         summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1]
         weight_grad = dy * inp
-        weight.square_grad = weight_grad.square()
+        weight.sum_grad_squared = weight_grad.square()
         if summed:
             weight_grad = weight_grad.sum(summed)
-            weight.square_grad = weight.square_grad.sum(summed)
-        weight.square_grad = weight.square_grad.reshape(weight.size()) * dy.size(0)
+            weight.sum_grad_squared = weight.sum_grad_squared.sum(summed)
+        weight.sum_grad_squared = weight.sum_grad_squared.reshape(weight.size()) * dy.size(0)
         return dy * weight, weight_grad.reshape(weight.size())
 
@@ -42,11 +42,11 @@ def backward(ctx, dy: torch.Tensor):
             return None, None
         weight, = ctx.saved_tensors
         weight_grad = dy
-        weight.square_grad = dy.square()
+        weight.sum_grad_squared = dy.square()
         if ctx.summed:
             weight_grad = weight_grad.sum(ctx.summed)
-            weight.square_grad = weight.square_grad.sum(ctx.summed)
-        weight.square_grad = weight.square_grad.reshape(weight.size()) * dy.size(0)
+            weight.sum_grad_squared = weight.sum_grad_squared.sum(ctx.summed)
+        weight.sum_grad_squared = weight.sum_grad_squared.reshape(weight.size()) * dy.size(0)
         return dy, weight_grad.reshape(weight.size())
 
@@ -67,7 +67,7 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor, torch.Tensor]:
         lhs, rhs = inputs.split(',')
         d_wgt = torch.einsum(f'{lhs},{output}->{rhs}', inp, dy)
-        wgt.square_grad = torch.einsum(f'{lhs},{output}->{rhs}', inp.square(), dy.square() * inp.size(0))
+        wgt.sum_grad_squared = torch.einsum(f'{lhs},{output}->{rhs}', inp.square(), dy.square() * inp.size(0))
         d_inp = torch.einsum(f"{rhs},{output}->{lhs}", wgt, dy)
         return None, d_inp, d_wgt
 
@@ -85,7 +85,7 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
             return None, None
         inp, wgt = ctx.saved_tensors
         wgt_grad = torch.zeros_like(wgt)
-        wgt.square_grad = wgt_grad.scatter_add(0, inp, dy.square())
+        wgt.sum_grad_squared = wgt_grad.scatter_add(0, inp, dy.square())
         wgt_grad.scatter_add_(0, inp, dy)
         return None, wgt_grad
 
@@ -103,8 +103,8 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
         if not ctx.saved_tensors:
             return None
         wgt, = ctx.saved_tensors
-        if hasattr(wgt, "square_grad"):
-            wgt.square_grad = wgt.square_grad.reshape(ctx.original_shape)
+        if hasattr(wgt, "sum_grad_squared"):
+            wgt.sum_grad_squared = wgt.sum_grad_squared.reshape(ctx.original_shape)
         return dy.reshape(ctx.original_shape)
 
@@ -121,8 +121,8 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
         if not ctx.saved_tensors:
             return None
         wgt, = ctx.saved_tensors
-        if hasattr(wgt, "square_grad") and ctx.summed:
-            wgt.square_grad = wgt.square_grad.sum(ctx.summed)
+        if hasattr(wgt, "sum_grad_squared") and ctx.summed:
+            wgt.sum_grad_squared = wgt.sum_grad_squared.sum(ctx.summed)
         return dy.sum(ctx.summed)
 
diff --git a/truegrad/nn.py b/truegrad/nn.py
index 4b529e8..bb25f88 100644
--- a/truegrad/nn.py
+++ b/truegrad/nn.py
@@ -159,10 +159,10 @@ def _square(x: Union[torch.Tensor, None]):
         for p, a in zip(list(ctx.args) + list(ctx.kwargs.values()), list(args) + list(kwargs.values())):
             if not isinstance(p, torch.nn.Parameter):
                 continue
-            if hasattr(p, "square_grad") and p.square_grad is not None:
-                p.square_grad = p.square_grad + a.grad
+            if hasattr(p, "sum_grad_squared") and p.sum_grad_squared is not None:
+                p.sum_grad_squared = p.sum_grad_squared + a.grad
             else:
-                p.square_grad = a.grad
+                p.sum_grad_squared = a.grad
         return None, None, None, None
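The rename above reflects what the attribute actually holds: the element-wise squares of the per-sample gradients, summed over the batch (the rules additionally scale the result by the batch size via the `* inp.size(0)` / `* dy.size(0)` factors). For a plain matmul `y = x @ w` this batch sum has a closed form, `(x ** 2).T @ (dy ** 2)`, which is the identity the einsum rule in `truegrad/functional.py` exploits. Below is a small self-contained sanity check of that identity, with made-up shapes; it is an illustration, not code from the patch.

```PYTHON
import torch

batch, in_features, out_features = 4, 3, 2
x = torch.randn(batch, in_features)    # inputs, one row per sample
dy = torch.randn(batch, out_features)  # gradients of the loss w.r.t. y = x @ w

# explicit route: per-sample weight gradients outer(x_i, dy_i), squared, summed over the batch
per_sample_grads = torch.stack([torch.outer(x[i], dy[i]) for i in range(batch)])
sum_grad_squared = per_sample_grads.square().sum(0)

# closed form exploited by the einsum backward rule
assert torch.allclose(sum_grad_squared, x.square().t() @ dy.square(), atol=1e-6)
```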
diff --git a/truegrad/optim.py b/truegrad/optim.py
index 466addf..28dffc0 100644
--- a/truegrad/optim.py
+++ b/truegrad/optim.py
@@ -9,9 +9,10 @@ def __init__(self, params, lr: float = 1e-3,
                  eps: float = 1e-12,
                  weight_decay: float = 1e-2,
                  graft: bool = True,
-                 decay_to_init: bool = False):
+                 decay_to_init: bool = False,
+                 default_to_adam: bool = False):
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
-                        decay_to_init=decay_to_init)
+                        decay_to_init=decay_to_init, default_to_adam=default_to_adam)
         super(TGAdamW, self).__init__(params, defaults)
 
     @torch.no_grad()
@@ -31,18 +32,19 @@ def step(self, closure=None):
             for p in group['params']:
                 if p.grad is None:
                     continue
-                if not hasattr(p, "square_grad") or p.square_grad is None:
-                    raise ValueError(f"Parameter of shape {list(p.size())} doesn't have `square_grad` attribute. "
-                                     f"Make sure to use truegrad.utils.patch_model() or truegrad.nn for all optimized "
-                                     f"parameters.")
+                do_adam = not hasattr(p, "sum_grad_squared") or p.sum_grad_squared is None
+                if not group["default_to_adam"] and do_adam:
+                    raise ValueError(f"Parameter of shape {list(p.size())} doesn't have `sum_grad_squared` attribute. "
+                                     f"Make sure to use backpack.")
 
                 state = self.state[p]
 
                 if len(state) == 0:
                     state['step'] = torch.tensor(0.)
                     state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    state['exp_avg_true_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    if group["graft"]:
+                    if not do_adam:
+                        state['exp_avg_true_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if do_adam or group["graft"]:
                         state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                     if group["decay_to_init"]:
                         state["init"] = torch.clone(p.detach())
@@ -61,22 +63,26 @@ def step(self, closure=None):
                 else:
                     p.mul_(1 - decay)
 
-                # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
-                exp_avg_true_sq.mul_(beta3).add_(p.square_grad, alpha=1 - beta3)
-                p.square_grad = None
 
                 step = step_t.item()
-
-                denom = (exp_avg_true_sq / (1 - beta3 ** step)).sqrt().add_(group['eps'])
-                update = exp_avg / denom
                 alpha = -group['lr'] / (1 - beta1 ** step)
 
-                if group["graft"]:
+                if not do_adam:
+                    exp_avg_true_sq.mul_(beta3).add_(p.sum_grad_squared, alpha=1 - beta3)
+                    p.sum_grad_squared = None
+                    denom = (exp_avg_true_sq / (1 - beta3 ** step)).sqrt().add_(group['eps'])
+                    update = exp_avg / denom
+
+                if group["graft"] or do_adam:
                     exp_avg_sq = state['exp_avg_sq']
                     exp_avg_sq.mul_(beta2).add_(p.grad.square(), alpha=1 - beta2)
                     adam_update = exp_avg / (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(group['eps'])
+
+                if group["graft"] and not do_adam:
                     alpha = alpha * adam_update.norm() / update.norm()
+                elif do_adam:
+                    update = adam_update
 
                 p.add_(update, alpha=alpha)
         return loss
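With the restructured `step()` above, each parameter takes one of three paths: a pure TrueGrad update, a TrueGrad direction grafted onto the AdamW step size (`graft=True`), or a plain AdamW update when no `sum_grad_squared` is present and `default_to_adam=True`. The sketch below restates that per-parameter branching in a self-contained form; it leaves out weight decay, `decay_to_init` and the tensor step counter, assumes a three-element `betas` tuple matching the `beta1`/`beta2`/`beta3` usage in the diff, and is an illustration rather than a drop-in replacement for `TGAdamW.step()`.

```PYTHON
import torch


@torch.no_grad()
def tgadamw_update(p, grad, sum_grad_squared, state, lr=1e-3,
                   betas=(0.9, 0.999, 0.999), eps=1e-12,
                   graft=True, default_to_adam=False):
    """Condensed restatement of TGAdamW's new per-parameter logic (sketch only)."""
    beta1, beta2, beta3 = betas
    do_adam = sum_grad_squared is None
    if do_adam and not default_to_adam:
        raise ValueError("parameter has no `sum_grad_squared` statistic")

    state["step"] = state.get("step", 0) + 1
    step = state["step"]

    exp_avg = state.setdefault("exp_avg", torch.zeros_like(p))
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    alpha = -lr / (1 - beta1 ** step)

    if not do_adam:  # second moment from the sum of squared per-sample gradients
        exp_avg_true_sq = state.setdefault("exp_avg_true_sq", torch.zeros_like(p))
        exp_avg_true_sq.mul_(beta3).add_(sum_grad_squared, alpha=1 - beta3)
        update = exp_avg / (exp_avg_true_sq / (1 - beta3 ** step)).sqrt().add_(eps)
    if do_adam or graft:  # ordinary AdamW second moment
        exp_avg_sq = state.setdefault("exp_avg_sq", torch.zeros_like(p))
        exp_avg_sq.mul_(beta2).add_(grad.square(), alpha=1 - beta2)
        adam_update = exp_avg / (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(eps)
    if graft and not do_adam:  # TrueGrad direction, AdamW magnitude
        alpha = alpha * adam_update.norm() / update.norm()
    elif do_adam:
        update = adam_update
    p.add_(alpha * update)


# usage sketch: one parameter with TrueGrad statistics, one relying on the AdamW fallback
a, b = torch.randn(4), torch.randn(4)
tgadamw_update(a, grad=torch.randn(4), sum_grad_squared=torch.rand(4), state={})
tgadamw_update(b, grad=torch.randn(4), sum_grad_squared=None, state={}, default_to_adam=True)
```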