diff --git a/onmt/trainer.py b/onmt/trainer.py index 1b30d0e729..6916ec3ba9 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -401,6 +401,8 @@ def validate(self, valid_iter, moving_average=None): ) # Compute loss. + if self.zero_out_prompt_loss: + batch = self.valid_loss.ignore_prompt(batch) _, batch_stats = self.valid_loss(batch, model_out, attns) stats.update(batch_stats)