diff --git a/returnn/torch/updater.py b/returnn/torch/updater.py
index d3d0cf71d4..ad68fedeb8 100644
--- a/returnn/torch/updater.py
+++ b/returnn/torch/updater.py
@@ -106,10 +106,11 @@ def __init__(self, *, config, network, device, initial_learning_rate=1.0):
                 ), "please specify **kwargs in dynamic_learning_rate for future compatibility"
             else:
                 raise NotImplementedError("not implemented for not callable dynamic_learning_rate")
 
-        self._update_effective_learning_rate()
         self.optimizer = None  # type: typing.Optional[torch.optim.Optimizer]
 
+        self._update_effective_learning_rate()
+
     def set_learning_rate(self, value):
         """
         Updates the learning rate of the optimizer at each (sub)epoch.
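
For context, a minimal sketch of why the reordering matters, assuming (the diff itself does not show the method body) that `_update_effective_learning_rate()` ends up reading `self.optimizer`: in the old order the attribute did not exist yet when the method ran, so `__init__` would raise `AttributeError`. The `UpdaterSketch` class below is a hypothetical stand-in, not RETURNN's actual implementation.

```python
import typing

import torch


class UpdaterSketch:
    """Hypothetical stand-in for returnn.torch.updater.Updater."""

    def __init__(self, initial_learning_rate: float = 1.0):
        self.learning_rate = initial_learning_rate

        # Old order: calling self._update_effective_learning_rate() here
        # would raise AttributeError, because self.optimizer is assigned
        # only on the next line.
        self.optimizer = None  # type: typing.Optional[torch.optim.Optimizer]

        # New order: safe, the attribute exists (it is just still None).
        self._update_effective_learning_rate()

    def _update_effective_learning_rate(self):
        # Assumption: the real method eventually touches self.optimizer,
        # e.g. to push the current learning rate into its param groups.
        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = self.learning_rate
```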