diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index 1b959823e4298..96b95d51f8b2a 100755
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -147,6 +147,10 @@ def __init__(
 
             configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True)
 
+        # Will be reset every time we re-initialize the graph builder.
+        # Note that we will never enable this feature for inference mode.
+        self._mem_efficient_grad_management_is_enabled = False
+
     def _get_torch_gpu_allocator_function_addresses(self):
         if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available():
             # CPP extension to get torch GPU allocator's alloc and free function addresses
@@ -497,16 +501,22 @@ def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfigurati
 
     def _initialize_graph_builder(self):
         """Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder"""
+        self._mem_efficient_grad_management_is_enabled = (
+            self._export_mode != torch.onnx.TrainingMode.EVAL
+            and self._runtime_options.enable_mem_efficient_grad_management
+        )
+
         # We post process the exported model because the trainable parameters might be changed, so this path is
         # re-triggered by reinitialize_graph_builder.
         exported_model = copy.deepcopy(self._onnx_models.exported_model)
         self._onnx_models.processed_exported_model = exported_model
-        if self._runtime_options.enable_mem_efficient_grad_management:
+
+        if self._mem_efficient_grad_management_is_enabled:
             from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training
 
             # Override the options if model is not modified.
             (
-                self._runtime_options.enable_mem_efficient_grad_management,
+                self._mem_efficient_grad_management_is_enabled,
                 exported_model,
             ) = post_processing_enable_mem_efficient_training(exported_model, self._flattened_module.named_parameters())
 
@@ -543,7 +553,7 @@ def _initialize_graph_builder(self):
             # Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph.
             input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)
 
-        if self._runtime_options.enable_mem_efficient_grad_management:
+        if self._mem_efficient_grad_management_is_enabled:
            from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME
 
             # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph.
@@ -635,14 +645,11 @@ def _enable_conditional_optimizations(
                 inputs, kwargs
             )
 
-            if (
-                self._runtime_options.enable_zero_stage3_support
-                or self._runtime_options.enable_mem_efficient_grad_management
-            ):
+            if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled:
                 self._append_pull_weight_trigger_as_input(kwargs, detected_device)
 
             param_to_append_as_onnx_graph_inputs = []
-            if self._runtime_options.enable_mem_efficient_grad_management:
+            if self._mem_efficient_grad_management_is_enabled:
                 from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger
 
                 param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
@@ -697,7 +704,7 @@ def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.devic
             device=device,
         ).requires_grad_()
 
-        if self._runtime_options.enable_mem_efficient_grad_management:
+        if self._mem_efficient_grad_management_is_enabled:
             from ._mem_efficient_grad_mgmt import (
                 MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
                 MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
index 64de8d929bc1a..cc533e549db92 100644
--- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
@@ -310,14 +310,11 @@ def forward(self, *inputs, **kwargs):
 
         self._gradient_accumulation_manager.maybe_update_cache_before_run()
 
-        if (
-            self._runtime_options.enable_zero_stage3_support
-            or self._runtime_options.enable_mem_efficient_grad_management
-        ):
+        if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled:
             self._append_pull_weight_trigger_as_input(kwargs, self._device)
 
         param_to_append_as_onnx_graph_inputs = []
-        if self._runtime_options.enable_mem_efficient_grad_management:
+        if self._mem_efficient_grad_management_is_enabled:
             from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger
 
             param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
@@ -506,7 +503,7 @@ def _reinitialize_graph_builder(self, input_info: _InputInfo):
             if param.requires_grad and name in self._graph_initializer_names
         }
 
-        if self._runtime_options.enable_mem_efficient_grad_management:
+        if self._mem_efficient_grad_management_is_enabled:
             from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME
 
             # Remove the inputs we added during model post-processing.