Commit

PT updater: rename update_params to step, more consistent
albertz committed Nov 24, 2023
1 parent 8714db3 commit 6ce57e1
Showing 2 changed files with 2 additions and 2 deletions.
returnn/torch/engine.py: 1 addition & 1 deletion

@@ -346,7 +346,7 @@ def train_epoch(self):

             # only update the weights when every gradient accumulation loop ends
             if (step_idx % self._accum_grad_multiple_step) == (self._accum_grad_multiple_step - 1):
-                self._updater.update_params(grad_scaler=self._grad_scaler)
+                self._updater.step(grad_scaler=self._grad_scaler)

             elapsed_computation_time += time.time() - step_begin_time

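For context, this call site sits inside RETURNN's gradient accumulation logic: backward() runs every step, but the optimizer only steps once per accumulation window. Below is a minimal standalone sketch of that pattern; the model, data, and loss are placeholders, not RETURNN code.

import torch

# Minimal sketch of gradient accumulation (placeholder model/data, not RETURNN code).
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]
accum_grad_multiple_step = 4  # mirrors self._accum_grad_multiple_step above

for step_idx, (inputs, targets) in enumerate(data_loader):
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    # Each backward() accumulates into param.grad; no optimizer step yet.
    loss.backward()
    # Only update the weights when the accumulation window ends,
    # the same condition as in train_epoch above.
    if (step_idx % accum_grad_multiple_step) == (accum_grad_multiple_step - 1):
        optimizer.step()
        optimizer.zero_grad()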
returnn/torch/updater.py: 1 addition & 1 deletion

@@ -177,7 +177,7 @@ def set_current_train_step(self, *, global_train_step: int, epoch: int):
         self._current_epoch = epoch
         self._update_effective_learning_rate()

-    def update_params(self, *, grad_scaler: Optional[torch.cuda.amp.GradScaler] = None):
+    def step(self, *, grad_scaler: Optional[torch.cuda.amp.GradScaler] = None):
         """
         Perform one step, i.e. update the parameters using the optimizer given the current calculated gradients.
         """
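The docstring is all the commit shows of the method body. As a rough sketch of how a step() with an optional GradScaler typically works, the body below is an assumption, not RETURNN's actual implementation:

from typing import Optional

import torch

class Updater:
    """Hypothetical minimal updater; only the renamed method is sketched."""

    def __init__(self, optimizer: torch.optim.Optimizer):
        self.optimizer = optimizer

    def step(self, *, grad_scaler: Optional[torch.cuda.amp.GradScaler] = None):
        """
        Perform one step, i.e. update the parameters using the optimizer
        given the current calculated gradients.
        (Body is an assumed sketch, not the actual RETURNN code.)
        """
        if grad_scaler is not None:
            # With AMP: unscale gradients, skip the step on inf/NaN grads,
            # then update the loss scale for the next iteration.
            grad_scaler.step(self.optimizer)
            grad_scaler.update()
        else:
            self.optimizer.step()

With the rename, the engine-side call self._updater.step(grad_scaler=self._grad_scaler) matches torch.optim.Optimizer.step() in name, which is presumably the consistency the commit title refers to.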
