diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index d5e0e6b78364c..bc5b79eb8d183 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -141,7 +141,7 @@ def build(self, *inputs_to_loss): return self._loss(*inputs_to_loss) - training_block = _TrainingBlock(loss_block) + training_block = _TrainingBlock(loss_block, loss_input_names) if requires_grad is not None and frozen_params is not None and set(requires_grad).intersection(set(frozen_params)): raise RuntimeError(