From 6ce57e11c64e49812da207beeeccdd7a5c5050ed Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 24 Nov 2023 15:47:00 +0000 Subject: [PATCH] PT updater, rename step, more consistent --- returnn/torch/engine.py | 2 +- returnn/torch/updater.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py index 1af2e4a7ab..bc5c35b27c 100644 --- a/returnn/torch/engine.py +++ b/returnn/torch/engine.py @@ -346,7 +346,7 @@ def train_epoch(self): # only update the weights when every gradient accumulation loop ends if (step_idx % self._accum_grad_multiple_step) == (self._accum_grad_multiple_step - 1): - self._updater.update_params(grad_scaler=self._grad_scaler) + self._updater.step(grad_scaler=self._grad_scaler) elapsed_computation_time += time.time() - step_begin_time diff --git a/returnn/torch/updater.py b/returnn/torch/updater.py index dcf8e099f3..7bf1f4c442 100644 --- a/returnn/torch/updater.py +++ b/returnn/torch/updater.py @@ -177,7 +177,7 @@ def set_current_train_step(self, *, global_train_step: int, epoch: int): self._current_epoch = epoch self._update_effective_learning_rate() - def update_params(self, *, grad_scaler: Optional[torch.cuda.amp.GradScaler] = None): + def step(self, *, grad_scaler: Optional[torch.cuda.amp.GradScaler] = None): """ Perform one step, i.e. update the parameters using the optimizer given the current calculated gradients. """