diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 6555d64833158..dfaac5f0fa836 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -428,8 +428,9 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu [name for name, _ in self._flattened_module.named_parameters()], ) - # Cannot append pull weight trigger name to input names here, otherwise, the later check find - # input info mismatch, will re-initialize the graph builder. + # Cannot append pull weight trigger name to input names as following, otherwise, the later check ( + # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18) + # find input info mismatch, will re-initialize the graph builder. # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) # Cache model for future runs