diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc
index d9c49dc6bea1d..8c08152986cf6 100644
--- a/onnxruntime/core/framework/execution_frame.cc
+++ b/onnxruntime/core/framework/execution_frame.cc
@@ -223,7 +223,8 @@ void IExecutionFrame::Init(gsl::span<const int> feed_mlvalue_idxs, gsl::span<const OrtValue> feeds,
                            const std::unordered_map<int, OrtValue>& initializers,
                            const std::function<bool(const std::string& name)>& is_initializer_sparse_func,
                            gsl::span<const OrtValue> fetches) {
-  ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size());
+  ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size(), "Got feed size: ", feeds.size(), " but expected feed size: ",
+              feed_mlvalue_idxs.size());
   ORT_ENFORCE(fetches.empty() || fetches.size() == fetch_mlvalue_idxs_.size());
 
   // Need this for sparse conversions in host memory
diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index a57bddc661459..1b959823e4298 100755
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -438,13 +438,6 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu
 
         exported_model = post_process_enabling_autograd_function(exported_model)
 
-        if self._runtime_options.enable_mem_efficient_grad_management:
-            from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training
-
-            exported_model = post_processing_enable_mem_efficient_training(
-                exported_model, self._flattened_module.named_parameters()
-            )
-
         if self._runtime_options.enable_zero_stage3_support:
             from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat
 
@@ -504,9 +497,24 @@ def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfigurati
 
     def _initialize_graph_builder(self):
         """Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder"""
+        # We post-process the exported model here because the set of trainable parameters might change,
+        # so this path is re-triggered by _reinitialize_graph_builder.
+        exported_model = copy.deepcopy(self._onnx_models.exported_model)
+        self._onnx_models.processed_exported_model = exported_model
+        if self._runtime_options.enable_mem_efficient_grad_management:
+            from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training
+
+            # Reset the option if the post-processing did not modify the model.
+            (
+                self._runtime_options.enable_mem_efficient_grad_management,
+                exported_model,
+            ) = post_processing_enable_mem_efficient_training(exported_model, self._flattened_module.named_parameters())
+
         # All initializer names along with user inputs are a part of the onnx graph inputs
         # since the onnx model was exported with the flag keep_initializers_as_inputs=True
-        onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input}
+        # The (possibly post-processed) exported model is used here since its graph inputs include both user
+        # inputs and parameters.
+        onnx_initializer_names = {p.name for p in exported_model.graph.input}
 
         # TODO: PyTorch exporter bug: changes the initializer order in ONNX model
         initializer_names = [
@@ -535,6 +543,7 @@ def _initialize_graph_builder(self):
             # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph.
             input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)
+
         grad_builder_config.input_names_require_grad = input_names_require_grad
         grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING
         grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization
@@ -546,12 +555,23 @@ def _initialize_graph_builder(self):
 
         # It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way
         # and are kept as they appear in the exported onnx model.
-        self._graph_builder.initialize(self._onnx_models.exported_model.SerializeToString(), grad_builder_config)
+        self._graph_builder.initialize(exported_model.SerializeToString(), grad_builder_config)
+
+        raw_onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input}
+
+        raw_initializer_names = [
+            name for name, _ in self._flattened_module.named_parameters() if name in raw_onnx_initializer_names
+        ]
+        raw_initializer_names_to_train = [
+            name
+            for name, param in self._flattened_module.named_parameters()
+            if param.requires_grad and name in raw_onnx_initializer_names
+        ]
 
         # TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train
         # a set (unordered_set in the backend) that does not require a copy on each reference.
-        self._graph_initializer_names = set(initializer_names)
-        self._graph_initializer_names_to_train = set(initializer_names_to_train)
+        self._graph_initializer_names = set(raw_initializer_names)
+        self._graph_initializer_names_to_train = set(raw_initializer_names_to_train)
 
         # Initializers can be cached and used since they are expected not to be re-instantiated
         # between forward calls.
@@ -602,7 +622,7 @@ def _enable_conditional_optimizations(
 
         # Enable data sparsity inspection if sparse optimizer is ON or user wants to print input density.
         if self._runtime_options.enable_sparse_optimizer or self._runtime_options.print_input_density:
             self._runtime_inspector.enable_input_inspector(
-                self._onnx_models.exported_model, self._graph_builder.get_graph_info().user_input_names
+                self._onnx_models.processed_exported_model, self._graph_builder.get_graph_info().user_input_names
             )
 
             if self._runtime_options.enable_sparse_optimizer:
@@ -621,7 +646,7 @@ def _enable_conditional_optimizations(
                     from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger
 
                     param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
-                        self._flattened_module.named_parameters()
+                        self._flattened_module.named_parameters(), self._onnx_models.exported_model
                     )
                 else:
                     param_to_append_as_onnx_graph_inputs = self._graph_initializers
diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py
index 5a6c1070b7f43..43161cb2e42ed 100644
--- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py
@@ -159,21 +159,10 @@ def forward(self, *inputs, **kwargs):
         # Assert that the input and model device match
         _utils._check_same_device(self._device, "Input argument to forward", *inputs)
 
-        if (
-            self._runtime_options.enable_zero_stage3_support
-            or self._runtime_options.enable_mem_efficient_grad_management
-        ):
+        if self._runtime_options.enable_zero_stage3_support:
             self._append_pull_weight_trigger_as_input(kwargs, self._device)
 
-        param_to_append_as_onnx_graph_inputs = []
-        if self._runtime_options.enable_mem_efficient_grad_management:
-            from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger
-
-            param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
-                self._flattened_module.named_parameters()
-            )
-        else:
-            param_to_append_as_onnx_graph_inputs = self._graph_initializers
+        param_to_append_as_onnx_graph_inputs = self._graph_initializers
 
         prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
             param_to_append_as_onnx_graph_inputs,
diff --git a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py
index 31bbda18c00cb..6779ceab52a60 100644
--- a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py
+++ b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py
@@ -3,9 +3,9 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 
+from __future__ import annotations
 
 import ctypes
-from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from onnx import ModelProto, NodeProto, TensorProto, helper
@@ -19,39 +19,44 @@
 
 MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE = [1]
 
 
-def get_params_connected_to_pull_param_trigger(named_params: Dict[str, torch.nn.parameter.Parameter]):
-    return {k: v for k, v in named_params if v.requires_grad}
+def get_params_connected_to_pull_param_trigger(
+    named_params: dict[str, torch.nn.parameter.Parameter], exported_model: ModelProto
+):
+    # Note: some parameters may not appear among the graph inputs because they are not used in forward, so we filter them out as well.
+    onnx_initializer_names = {p.name for p in exported_model.graph.input}
+    return {k: v for k, v in named_params if v.requires_grad and k in onnx_initializer_names}
 
 
-def get_params_not_connected_to_pull_param_trigger(named_params: Dict[str, torch.nn.parameter.Parameter]):
-    return [v for k, v in named_params if not v.requires_grad]
+def get_params_not_connected_to_pull_param_trigger(
+    named_params: dict[str, torch.nn.parameter.Parameter], exported_model: ModelProto
+):
+    # Note: some parameters may not appear among the graph inputs because they are not used in forward, so we filter them out as well.
+    onnx_initializer_names = {p.name for p in exported_model.graph.input}
+    return [v for k, v in named_params if not v.requires_grad and k in onnx_initializer_names]
 
 
 def post_processing_enable_mem_efficient_training(
     exported_model: ModelProto,
-    named_params: Dict[str, torch.nn.parameter.Parameter],
-) -> ModelProto:
-    """This function is used to enable zero stage3 compatibility.
+    named_params: dict[str, torch.nn.parameter.Parameter],
+) -> tuple[bool, ModelProto]:
+    """This function is used to enable memory-efficient gradient management.
 
     Args:
         exported_model (ModelProto): The exported model.
-        named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The full parameter map.
+        named_params (dict[str, torch.nn.parameter.Parameter]): The full parameter map.
+
+    Returns:
+        tuple[bool, ModelProto]: A flag indicating whether the model was modified, and the resulting model.
+
     """
-    trainable_named_params = get_params_connected_to_pull_param_trigger(named_params)
+    trainable_named_params = get_params_connected_to_pull_param_trigger(named_params, exported_model)
+    if len(trainable_named_params) == 0:
+        return False, exported_model
 
     # Create weight retrieving function using trainable_named_params.
     param_pull_trigger_func_class = _create_param_trigger_function(trainable_named_params)
     param_retrieve_func_class = _create_param_retrieval_function(trainable_named_params)
 
-    consumer_map = {}
-    for node in exported_model.graph.node:
-        for inp in node.input:
-            if inp not in consumer_map:
-                consumer_map[inp] = []
-
-            if node not in consumer_map[inp]:
-                consumer_map[inp].append(node)
-
     def _get_param_pull_trigger_name(param_name: str) -> str:
         return f"pull_{param_name}"
 
@@ -90,9 +95,6 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
 
         graph_inputs_to_remove.append(graph_input)
 
-        if graph_input.name not in consumer_map:
-            continue
-
         # Create the param retrieval function for this parameter.
         node_inputs = [
             helper.make_tensor_value_info(
@@ -123,6 +125,14 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
         input_offset += 1
 
     # Delete exported_model.graph.input
+
+    names_to_remove = [input.name for input in graph_inputs_to_remove]
+    value_infos_to_remove = [
+        value_info for value_info in exported_model.graph.value_info if value_info.name in names_to_remove
+    ]
+    for value_info in value_infos_to_remove:
+        exported_model.graph.value_info.remove(value_info)
+
     for input_to_remove in graph_inputs_to_remove:
         exported_model.graph.input.remove(input_to_remove)
 
@@ -135,13 +145,13 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
         exported_model.graph.input.insert(offset, inputs[0])
 
     exported_model.graph.node.insert(0, weight_pull_node)
-    return exported_model
+    return True, exported_model
 
 
 _PARAM_FUNCTION_INDEX = [0]
 
 
-def _create_param_trigger_function(trainable_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]]):
+def _create_param_trigger_function(trainable_named_params: dict[str, torch.nn.parameter.Parameter]):
     """This function is used to create a weight retrieving function using trainable_named_params."""
 
     @staticmethod
@@ -160,9 +170,9 @@ def backward(ctx, *grad_outputs):
     @staticmethod
     def infer_shape(
         node: NodeProto,
-        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
-        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
-    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
+        tensor_input_shapes: list[list[int | str] | None],
+        tensor_input_dtypes: list[torch.onnx.TensorProtoDataType],
+    ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]:
         param_count = len(trainable_named_params.values())
         tensor_output_shapes = [
             tensor_input_shapes[0],
@@ -186,7 +196,7 @@ def infer_shape(
         )
 
 
-def _create_param_retrieval_function(trainable_named_params: Dict[str, torch.nn.parameter.Parameter]):
+def _create_param_retrieval_function(trainable_named_params: dict[str, torch.nn.parameter.Parameter]):
     """This function is used to create a weight retrieving function using trainable_named_params."""
 
     @staticmethod
@@ -205,9 +215,9 @@ def backward(ctx, *grad_outputs):
     @staticmethod
     def infer_shape(
         node: NodeProto,
-        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
-        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
-    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
+        tensor_input_shapes: list[list[int | str] | None],
+        tensor_input_dtypes: list[torch.onnx.TensorProtoDataType],
+    ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]:
         input_pointer_scalars_attr_name = "input_pointer_scalars"
 
         found = [attr for attr in node.attribute if attr.name == input_pointer_scalars_attr_name]
diff --git a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py
index d687bc24384ed..a0001a2f201f1 100644
--- a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py
+++ b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py
@@ -33,6 +33,7 @@ class ONNXModels:
     """
 
     exported_model: Optional[onnx.ModelProto] = None
+    processed_exported_model: Optional[onnx.ModelProto] = None
     optimized_model: Optional[onnx.ModelProto] = None
 
     def save_exported_model(self, path, name_prefix, export_mode):
diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
index 35e01d8738049..64de8d929bc1a 100644
--- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
@@ -321,8 +321,9 @@ def forward(self, *inputs, **kwargs):
             from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger
 
             param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
-                self._flattened_module.named_parameters()
+                self._flattened_module.named_parameters(), self._onnx_models.exported_model
             )
+
         else:
             param_to_append_as_onnx_graph_inputs = self._graph_initializers
 
@@ -505,10 +506,20 @@ def _reinitialize_graph_builder(self, input_info: _InputInfo):
             if param.requires_grad and name in self._graph_initializer_names
         }
 
+        if self._runtime_options.enable_mem_efficient_grad_management:
+            from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME
+
+            # Remove the trigger input we added during model post-processing.
+            existing_require_grad_names = [
+                n for n in self._input_info.require_grad_names if n != MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME
+            ]
+        else:
+            existing_require_grad_names = self._input_info.require_grad_names
+
         # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder
         # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad
         if (
-            input_info.require_grad_names != self._input_info.require_grad_names
+            input_info.require_grad_names != existing_require_grad_names
             or initializer_names_to_train_set_user_model != self._graph_initializer_names_to_train
         ):
             self._input_info = input_info
diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py
index 4d4c6663adc27..feeba3d6672c2 100644
--- a/orttraining/orttraining/python/training/ortmodule/options.py
+++ b/orttraining/orttraining/python/training/ortmodule/options.py
@@ -400,13 +400,10 @@ def _override_from_env_vars(self):
         if "ORTMODULE_ENABLE_ZERO_STAGE3" in os.environ and int(os.getenv("ORTMODULE_ENABLE_ZERO_STAGE3")) == 1:
             self.enable_zero_stage3_support = True
 
-        if (
-            "ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT" in os.environ
-            and int(os.getenv("ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT")) == 1
-        ):
-            if self.enable_custom_autograd_function:
-                self.enable_mem_efficient_grad_management = True
-            else:
+        if "ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT" in os.environ:
+            enable_grad_mgmt = int(os.getenv("ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT"))
+            self.enable_mem_efficient_grad_management = enable_grad_mgmt == 1 and self.enable_custom_autograd_function
+            if not self.enable_custom_autograd_function and enable_grad_mgmt == 1:
                 self._logger.warning(
                     "ORTModule optimization for memory efficient gradient management cannot be enabled "
                     "because PyTorch custom autograd function support is disabled."
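
Editor's note: below is a minimal, self-contained sketch — not the PR's code — of the re-derivation flow that _initialize_graph_builder now follows: the pristine exported model stays in _onnx_models.exported_model, a deep copy is post-processed into processed_exported_model, and the post-processing entry point returns a (modified, model) pair so the option can be switched off when it had nothing to do. The helper name derive_processed_model is hypothetical.

    import copy

    def derive_processed_model(raw_model, enabled, post_process):
        # Re-derive the processed model from the pristine exported model.
        # Keeping `raw_model` untouched lets this run again whenever the set of
        # trainable parameters changes; `post_process` returns (modified, model).
        processed = copy.deepcopy(raw_model)
        if enabled:
            enabled, processed = post_process(processed)
        return enabled, processed

    # Usage: a post-process that finds nothing to do leaves the option disabled.
    enabled, model = derive_processed_model({"nodes": []}, True, lambda m: (False, m))
    assert enabled is False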
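
Editor's note: a second self-contained sketch (illustrative, not the PR's code) of the ONNX graph surgery post_processing_enable_mem_efficient_training performs: when graph inputs are dropped, any stale value_info entries carrying the same names are dropped too, so no outdated shape metadata lingers in the graph.

    import onnx
    from onnx import TensorProto, helper

    def remove_graph_inputs(model: onnx.ModelProto, names_to_remove) -> None:
        # Drop the named graph inputs plus their matching value_info entries, in place.
        for value_info in [vi for vi in model.graph.value_info if vi.name in names_to_remove]:
            model.graph.value_info.remove(value_info)
        for graph_input in [gi for gi in model.graph.input if gi.name in names_to_remove]:
            model.graph.input.remove(graph_input)

    # Toy model: an Identity node with two declared inputs, one of which we drop.
    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])
    w = helper.make_tensor_value_info("w", TensorProto.FLOAT, [1])
    y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])
    graph = helper.make_graph([helper.make_node("Identity", ["x"], ["y"])], "toy", [x, w], [y])
    model = helper.make_model(graph)

    remove_graph_inputs(model, {"w"})
    assert [gi.name for gi in model.graph.input] == ["x"]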