dynamic_learning_rate, support epoch arg
albertz committed Nov 3, 2023
1 parent 4e76da6 commit 88cdd46
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions returnn/torch/engine.py
@@ -233,7 +233,7 @@ def init_train_epoch(self):
 
         # Update learning rate
         self._updater.set_learning_rate(self.learning_rate)
-        self._updater.set_current_train_step(self.global_train_step)
+        self._updater.set_current_train_step(global_train_step=self.global_train_step, epoch=self.epoch)
 
     def train_epoch(self):
         """
@@ -325,7 +325,7 @@ def train_epoch(self):
 
             step_idx += 1
             self.global_train_step += 1
-            self._updater.set_current_train_step(self.global_train_step)
+            self._updater.set_current_train_step(global_train_step=self.global_train_step, epoch=self.epoch)
 
         elapsed = time.time() - epoch_start_time
         elapsed_computation_percentage = elapsed_computation_time / elapsed
6 changes: 4 additions & 2 deletions returnn/torch/updater.py
@@ -94,6 +94,7 @@ def __init__(self, *, config, network, device, initial_learning_rate=1.0):
         self.network = network
         self._device = device
         self._current_train_step = 0
+        self._current_epoch = 0
 
         self.learning_rate_function = self.config.typed_value("dynamic_learning_rate", None)
         if self.learning_rate_function is not None:
@@ -158,17 +159,18 @@ def _update_effective_learning_rate(self):
         self._effective_learning_rate = self.learning_rate
         if self.learning_rate_function is not None:
             self._effective_learning_rate = self.learning_rate_function(
-                global_train_step=self._current_train_step, learning_rate=self.learning_rate
+                global_train_step=self._current_train_step, epoch=self._current_epoch, learning_rate=self.learning_rate
             )
         if self.optimizer:
             for param_group in self.optimizer.param_groups:
                 param_group["lr"] = self._effective_learning_rate
 
-    def set_current_train_step(self, global_train_step):
+    def set_current_train_step(self, *, global_train_step: int, epoch: int):
         """
         Obtains an updated learning rate for the current training step inside a (sub)epoch.
         """
         self._current_train_step = global_train_step
+        self._current_epoch = epoch
         self._update_effective_learning_rate()
 
     def update_params(self, *, grad_scaler: Optional[torch.cuda.amp.GradScaler] = None):
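For context, here is a minimal sketch of a config-level `dynamic_learning_rate` function that could consume the new `epoch` argument. The keyword names (`global_train_step`, `epoch`, `learning_rate`) match the call in `_update_effective_learning_rate` above; the warmup/decay schedule itself is purely illustrative and not part of this commit.

    # Illustrative only: the keyword arguments mirror the updated call in
    # returnn/torch/updater.py; the concrete schedule below is a made-up example.
    def dynamic_learning_rate(*, global_train_step: int, epoch: int, learning_rate: float, **_kwargs) -> float:
        """Linear warmup over the first 10k steps, then halve the LR every 20 (sub)epochs."""
        warmup_steps = 10_000  # assumed value, not from the commit
        warmup_factor = min(1.0, (global_train_step + 1) / warmup_steps)
        epoch_decay = 0.5 ** (epoch // 20)
        return learning_rate * warmup_factor * epoch_decay

With this commit, such a function receives `epoch` in addition to `global_train_step` and `learning_rate`, so a schedule can depend on the current (sub)epoch as well as the global step.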
