From c585bc8fbc7bd3801e2124298cb2dd1a35f6056a Mon Sep 17 00:00:00 2001 From: Peng Wang Date: Sun, 24 Dec 2023 22:06:35 -0800 Subject: [PATCH] remove stage3 related change --- .../ortmodule/_graph_execution_manager.py | 68 +++++++++++++------ .../training/ortmodule/_inference_manager.py | 14 +++- .../python/training/ortmodule/_io.py | 12 ++-- .../ortmodule/_mem_efficient_grad_mgmt.py | 16 ++++- .../training/ortmodule/_training_manager.py | 14 +++- 5 files changed, 92 insertions(+), 32 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 713ff3927681b..07aa53612d463 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -529,11 +529,13 @@ def _initialize_graph_builder(self): # Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph. input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) + if self._runtime_options.enable_mem_efficient_grad_management: from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME - # Add stage3 mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph. - input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph. + + input_names_require_grad.insert(0, MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) grad_builder_config.input_names_require_grad = input_names_require_grad grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization @@ -613,10 +615,20 @@ def _enable_conditional_optimizations( self._runtime_options.enable_zero_stage3_support or self._runtime_options.enable_mem_efficient_grad_management ): - self._append_pull_weight_trigger_as_input(kwargs, detected_device) + kwargs = self._append_pull_weight_trigger_as_input(kwargs, detected_device) + + param_to_append_as_onnx_graph_inputs = [] + if self._runtime_options.enable_mem_efficient_grad_management: + from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger + + param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger( + self._flattened_module.named_parameters() + ) + else: + param_to_append_as_onnx_graph_inputs = self._graph_initializers _, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers( - self._graph_initializers, + param_to_append_as_onnx_graph_inputs, self._graph_builder.get_graph_info().user_input_names, self._input_info, self._flattened_module.named_buffers(), @@ -648,25 +660,43 @@ def _enable_conditional_optimizations( self._runtime_inspector.disable_input_inspector() def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device): - from ._mem_efficient_grad_mgmt import ( - MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME, - MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, - MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, - ) + if self._runtime_options.enable_zero_stage3_support: + from ._zero_stage3_compatibility import ( + STAGE3_PULL_WEIGHT_TRIGGER_NAME, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + ) - new_kwargs = { - MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: torch.zeros( - MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, - dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE), + kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros( + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE), device=device, ).requires_grad_() - } - # Then the trigger input will be the first user input. - return { - **new_kwargs, - **kwargs, - } + return kwargs + + if self._runtime_options.enable_mem_efficient_grad_management: + from ._mem_efficient_grad_mgmt import ( + MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + ) + + new_kwargs = { + MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: torch.zeros( + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE), + device=device, + ).requires_grad_() + } + + # Then the trigger input will be the first user input. + return { + **new_kwargs, + **kwargs, + } + + return kwargs def _log_feature_stats(self): if get_rank() != 0: diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 3dd18dfdb2314..f4dc6b04de062 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -163,10 +163,20 @@ def forward(self, *inputs, **kwargs): self._runtime_options.enable_zero_stage3_support or self._runtime_options.enable_mem_efficient_grad_management ): - self._append_pull_weight_trigger_as_input(kwargs, self._device) + kwargs = self._append_pull_weight_trigger_as_input(kwargs, self._device) + + param_to_append_as_onnx_graph_inputs = [] + if self._runtime_options.enable_mem_efficient_grad_management: + from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger + + param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger( + self._flattened_module.named_parameters() + ) + else: + param_to_append_as_onnx_graph_inputs = self._graph_initializers prepared_input_list, _, _ = _io._combine_input_buffers_initializers( - self._graph_initializers, + param_to_append_as_onnx_graph_inputs, self._graph_info.user_input_names, self._input_info, self._flattened_module.named_buffers(), diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index e7ea65fa65326..7534cc46a21f1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -260,12 +260,12 @@ def _expand_inputs(current_input, non_none_inputs, name=""): ) # params is a list of all initializers known to the onnx graph - # if zero_stage3_offload_param_map: - # for p in params: - # if p not in zero_stage3_offload_param_map.values(): - # result.append(p) - # else: - # result.extend(params) + if zero_stage3_offload_param_map: + for p in params: + if p not in zero_stage3_offload_param_map.values(): + result.append(p) + else: + result.extend(params) if rt_inspector.memory_ob.is_enabled() and not rt_inspector.memory_ob.symbolic_dim_collecting_completed: rt_inspector.memory_ob.collect_symbolic_dim_values(input_info.dynamic_axes, onnx_input_to_value_map) diff --git a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py index b4b69e9182909..f0bbc625cf6a7 100644 --- a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py +++ b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py @@ -19,6 +19,14 @@ MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE = [1] +def get_params_connected_to_pull_param_trigger(named_params: Dict[str, torch.nn.parameter.Parameter]): + return {k: v for k, v in named_params if v.requires_grad} + + +def get_params_not_connected_to_pull_param_trigger(named_params: Dict[str, torch.nn.parameter.Parameter]): + return [v for k, v in named_params if not v.requires_grad] + + def post_processing_enable_mem_efficient_training( exported_model: ModelProto, named_params: Dict[str, torch.nn.parameter.Parameter], @@ -29,7 +37,7 @@ def post_processing_enable_mem_efficient_training( exported_model (ModelProto): The exported model. named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The full parameter map. """ - trainable_named_params = {k: v for k, v in named_params if v.requires_grad} + trainable_named_params = get_params_connected_to_pull_param_trigger(named_params) # Create weight retrieving function using trainable_named_params. param_pull_trigger_func_class = _create_param_trigger_function(trainable_named_params) @@ -75,7 +83,8 @@ def _get_param_pull_trigger_name(param_name: str) -> str: ) graph_inputs_to_remove = [] - for graph_input in reversed(exported_model.graph.input): + input_offset = 0 + for graph_input in exported_model.graph.input: if graph_input.name not in trainable_named_params: continue @@ -110,7 +119,8 @@ def _get_param_pull_trigger_name(param_name: str) -> str: training_mode=1, safe_run_mode=0, ) - exported_model.graph.node.insert(0, new_node) + exported_model.graph.node.insert(input_offset, new_node) + input_offset += 1 # Delete exported_model.graph.input for input_to_remove in graph_inputs_to_remove: diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 00aee0a5e8a75..bca6e88649dfc 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -314,10 +314,20 @@ def forward(self, *inputs, **kwargs): self._runtime_options.enable_zero_stage3_support or self._runtime_options.enable_mem_efficient_grad_management ): - self._append_pull_weight_trigger_as_input(kwargs, self._device) + kwargs = self._append_pull_weight_trigger_as_input(kwargs, self._device) + + param_to_append_as_onnx_graph_inputs = [] + if self._runtime_options.enable_mem_efficient_grad_management: + from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger + + param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger( + self._flattened_module.named_parameters() + ) + else: + param_to_append_as_onnx_graph_inputs = self._graph_initializers prepared_input_list, _, _ = _io._combine_input_buffers_initializers( - self._graph_initializers, + param_to_append_as_onnx_graph_inputs, self._graph_info.user_input_names, self._input_info, self._flattened_module.named_buffers(),