From 95478f8fcdccea49a0417fe5c001b360b55222ef Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 24 Nov 2023 16:40:18 +0000 Subject: [PATCH] PT gradient_noise --- returnn/torch/updater.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/returnn/torch/updater.py b/returnn/torch/updater.py index 7bf1f4c442..36eb7fa866 100644 --- a/returnn/torch/updater.py +++ b/returnn/torch/updater.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Optional, Union, Any, Type, Sequence, Set, Dict, List, Tuple +from typing import Optional, Union, Any, Type, Sequence, Iterable, Set, Dict, List, Tuple import os import gc import torch @@ -118,11 +118,11 @@ def __init__(self, *, config, network, device, initial_learning_rate=1.0): self.optimizer = None # type: typing.Optional[torch.optim.Optimizer] self._grad_clip_global_norm = self.config.float("gradient_clip_global_norm", 0.0) + self._grad_noise = self.config.float("gradient_noise", 0.0) # Check other options we have in TF updater, which we might support here later as well, # but currently do not support. for opt_name in [ - "gradient_noise", "gradient_clip", "gradient_clip_norm", "gradient_clip_avg_norm", @@ -183,6 +183,9 @@ def step(self, *, grad_scaler: Optional[torch.cuda.amp.GradScaler] = None): """ if self._grad_clip_global_norm: torch.nn.utils.clip_grad_norm_(self.network.parameters(), self._grad_clip_global_norm) + if self._grad_noise: + gradient_noise_(self.network.parameters(), self._grad_noise) + if grad_scaler is not None: grad_scaler.step(self.optimizer) grad_scaler.update() @@ -532,3 +535,15 @@ def _wrap_user_blacklist_wd_modules( assert issubclass(mod, (rf.Module, torch.nn.Module)), f"invalid blacklist_weight_decay_modules {mods!r}" res.append(mod) return tuple(res) + + +def gradient_noise_(params: Iterable[torch.nn.Parameter], std: float): + """ + Add gradient noise to parameters, using a truncated normal distribution. + """ + a, b = -2 * std, 2 * std + for param in params: + if param.requires_grad and param.grad is not None: + noise = torch.empty_like(param.grad) + torch.nn.init.trunc_normal_(noise, std=std, a=a, b=b) + param.grad += noise