Skip to content

Commit

Permalink
PT gradient_noise
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 24, 2023
1 parent 6ce57e1 commit 95478f8
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions returnn/torch/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from __future__ import annotations

from typing import Optional, Union, Any, Type, Sequence, Set, Dict, List, Tuple
from typing import Optional, Union, Any, Type, Sequence, Iterable, Set, Dict, List, Tuple
import os
import gc
import torch
Expand Down Expand Up @@ -118,11 +118,11 @@ def __init__(self, *, config, network, device, initial_learning_rate=1.0):
self.optimizer = None # type: typing.Optional[torch.optim.Optimizer]

self._grad_clip_global_norm = self.config.float("gradient_clip_global_norm", 0.0)
self._grad_noise = self.config.float("gradient_noise", 0.0)

# Check other options we have in TF updater, which we might support here later as well,
# but currently do not support.
for opt_name in [
"gradient_noise",
"gradient_clip",
"gradient_clip_norm",
"gradient_clip_avg_norm",
Expand Down Expand Up @@ -183,6 +183,9 @@ def step(self, *, grad_scaler: Optional[torch.cuda.amp.GradScaler] = None):
"""
if self._grad_clip_global_norm:
torch.nn.utils.clip_grad_norm_(self.network.parameters(), self._grad_clip_global_norm)
if self._grad_noise:
gradient_noise_(self.network.parameters(), self._grad_noise)

if grad_scaler is not None:
grad_scaler.step(self.optimizer)
grad_scaler.update()
Expand Down Expand Up @@ -532,3 +535,15 @@ def _wrap_user_blacklist_wd_modules(
assert issubclass(mod, (rf.Module, torch.nn.Module)), f"invalid blacklist_weight_decay_modules {mods!r}"
res.append(mod)
return tuple(res)


def gradient_noise_(params: Iterable[torch.nn.Parameter], std: float):
"""
Add gradient noise to parameters, using a truncated normal distribution.
"""
a, b = -2 * std, 2 * std
for param in params:
if param.requires_grad and param.grad is not None:
noise = torch.empty_like(param.grad)
torch.nn.init.trunc_normal_(noise, std=std, a=a, b=b)
param.grad += noise

0 comments on commit 95478f8

Please sign in to comment.