diff --git a/mingpt/trainer.py b/mingpt/trainer.py
index c0d08521..27385eee 100644
--- a/mingpt/trainer.py
+++ b/mingpt/trainer.py
@@ -93,10 +93,10 @@ def run(self):
             logits, self.loss = model(x, y)
 
             # backprop and update the parameters
-            model.zero_grad(set_to_none=True)
             self.loss.backward()
             torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
             self.optimizer.step()
+            model.zero_grad(set_to_none=True)
 
             self.trigger_callbacks('on_batch_end')
             self.iter_num += 1
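
For context, a minimal standalone sketch of the training step after this patch. The wrapper function training_step and its parameters are hypothetical (they are not part of the patch); only model, optimizer, and config.grad_norm_clip come from the trainer shown above. The observable gradients are identical either way, since zero_grad before backward and zero_grad after step both guarantee backward accumulates into empty gradients; zeroing after the step with set_to_none=True simply releases the gradient tensors as early as possible.

    import torch

    def training_step(model, optimizer, x, y, grad_norm_clip=1.0):
        # forward pass: minGPT models return (logits, loss) when targets are given
        logits, loss = model(x, y)

        # backprop and update the parameters
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
        optimizer.step()
        # zeroing after the step (set_to_none=True) frees the gradient tensors
        # immediately, so that memory is available during the next forward pass
        model.zero_grad(set_to_none=True)

        return loss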