From 4e59594d28ccff1de5fca5d02e2f532ab867473b Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Thu, 24 Aug 2023 16:39:24 +0000 Subject: [PATCH 1/9] rebase main --- .../_custom_autograd_function_exporter.py | 22 +- .../ortmodule/_graph_execution_manager.py | 49 +++- .../training/ortmodule/_inference_manager.py | 4 + .../python/training/ortmodule/_io.py | 8 +- .../training/ortmodule/_training_manager.py | 4 + .../ortmodule/_zero_stage3_compatibility.py | 272 ++++++++++++++++++ .../torch_custom_function_kernel_base.cc | 6 +- 7 files changed, 344 insertions(+), 21 deletions(-) create mode 100644 orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 4c72b6d98a088..31bf1c60d0515 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -28,7 +28,8 @@ class PythonOpShapeInferStore: @classmethod def register(cls, kclass: torch.autograd.Function) -> None: - """Register a shape inference function for a torch.autograd.Function if there is staticmethod "infer_shape" defined. + """Register a shape inference function for a torch.autograd.Function if there is staticmethod + "infer_shape" defined. The signature of the shape inference function should be: @staticmethod @@ -51,6 +52,11 @@ def infer_shape( if hasattr(kclass, "infer_shape") and kclass_name not in cls._CLASS_MAP: cls._CLASS_MAP[kclass_name] = kclass.infer_shape + @classmethod + def register_func(cls, name: str, func: Callable) -> None: + """Register a shape inference function for a torch.autograd.Function by name.""" + cls._CLASS_MAP[name] = func + @classmethod def get_shape_infer(cls, name: str) -> Optional[Callable]: return cls._CLASS_MAP.get(name, None) @@ -307,14 +313,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): _export = wrap_custom_export_function(_export_pt_1_10) -def _post_process_after_export(exported_model: ModelProto, enable_custom_autograd_function: bool) -> ModelProto: - """Post process the exported model.""" - if enable_custom_autograd_function: - exported_model = _post_process_enabling_autograd_function(exported_model) - return exported_model - - -def _post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto: +def post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto: # Loop all PythonOp, append "_ctx" as the first output. index = 0 for node in exported_model.graph.node: @@ -330,8 +329,7 @@ def _post_process_enabling_autograd_function(exported_model: ModelProto) -> Mode op_name_prefix = kclass_name break - if not node.name: - node.name = f"{op_name_prefix}_id_{index}" - index += 1 + node.name = f"{op_name_prefix}_id_{index}" + index += 1 return exported_model diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 2227b630aee23..f961df5e53b12 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -23,7 +23,6 @@ from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . 
import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
-from ._custom_autograd_function_exporter import _post_process_after_export
 from ._fallback import (
     ORTModuleDeviceException,
     ORTModuleONNXModelException,
@@ -141,8 +140,13 @@ def __init__(
 
         register_triton_op_executor()
 
+        self._zero_stage3_param_map = None
         if self._runtime_options.enable_zero_stage3_support:
             # Cannot toggle feature enabling/disabling after the first time enabled.
+            from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_offloaded_params
+
+            self._zero_stage3_param_map = _get_all_offloaded_params(self._flattened_module)
+
             configure_ort_compatible_zero_stage3()
 
     def _get_torch_gpu_allocator_function_addresses(self):
@@ -345,7 +349,8 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu
             )
             if os.path.exists(cache_dir) and os.path.isfile(filename):
                 self._logger.info(
-                    f"Cached model detected! Cached model will be used to save export and initialization time. If you want the model to be re-exported then DELETE {filename}."
+                    f"Cached model detected! Cached model will be used to save export and initialization time. "
+                    f"If you want the model to be re-exported, then DELETE {filename}."
                 )
                 exported_model = onnx.load(filename)
                 return exported_model
@@ -409,9 +414,23 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu
         )
         exported_model = onnx.load_model_from_string(f.getvalue())
 
-        exported_model = _post_process_after_export(
-            exported_model, self._runtime_options.enable_custom_autograd_function
-        )
+        if self._runtime_options.enable_custom_autograd_function:
+            from ._custom_autograd_function_exporter import post_process_enabling_autograd_function
+
+            exported_model = post_process_enabling_autograd_function(exported_model)
+
+        if self._runtime_options.enable_zero_stage3_support:
+            from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat
+
+            exported_model = post_processing_enable_zero_stage3_compat(
+                exported_model,
+                self._zero_stage3_param_map,
+                [name for name, _ in self._flattened_module.named_parameters()],
+            )
+
+            # Cannot append the pull weight trigger name to the input names here; otherwise, the later check
+            # finds an input info mismatch and re-initializes the graph builder.
+            # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)
 
         # Cache model for future runs
         if cache_dir:
@@ -477,7 +496,14 @@ def _initialize_graph_builder(self):
         grad_builder_config = C.OrtModuleGraphBuilderConfiguration()
         grad_builder_config.initializer_names = initializer_names
         grad_builder_config.initializer_names_to_train = initializer_names_to_train
-        grad_builder_config.input_names_require_grad = self._input_info.require_grad_names
+
+        input_names_require_grad = self._input_info.require_grad_names
+        if self._runtime_options.enable_zero_stage3_support:
+            from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME
+
+            # Add the stage3 pull weight trigger name to require_grad_names so that it is included in the gradient graph.
+ input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) + grad_builder_config.input_names_require_grad = input_names_require_grad grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel( @@ -553,6 +579,9 @@ def _enable_conditional_optimizations( inputs, kwargs ) + if self._runtime_options.enable_zero_stage3_support: + self._append_pull_weight_trigger_as_input(kwargs, detected_device) + _, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers( self._graph_initializers, self._graph_builder.get_graph_info().user_input_names, @@ -562,6 +591,7 @@ def _enable_conditional_optimizations( kwargs, detected_device, self._runtime_inspector, + self._zero_stage3_param_map, ) # Enable sparsity-based optimization when applicable. @@ -587,6 +617,13 @@ def _enable_conditional_optimizations( if self._runtime_options.print_memory_stat: self._runtime_inspector.enable_memory_inspector(self._original_module) + def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device): + from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME + + kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros([1], dtype=torch.float32, device=device).requires_grad_() + + return kwargs + def _log_feature_stats(self): if get_rank() != 0: return diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index b7c01a1f5baf9..8d8be81c549d1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -159,6 +159,9 @@ def forward(self, *inputs, **kwargs): # Assert that the input and model device match _utils._check_same_device(self._device, "Input argument to forward", *inputs) + if self._runtime_options.enable_zero_stage3_support: + self._append_pull_weight_trigger_as_input(kwargs, self._device) + prepared_input_list, _, _ = _io._combine_input_buffers_initializers( self._graph_initializers, self._graph_info.user_input_names, @@ -168,6 +171,7 @@ def forward(self, *inputs, **kwargs): kwargs, self._device, self._runtime_inspector, + self._zero_stage3_param_map, ) user_outputs, _ = InferenceManager.execution_session_run_forward( diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 18b965c549645..e7c1b30daae0d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -168,6 +168,7 @@ def _combine_input_buffers_initializers( kwargs: Mapping[str, ORTModelInputOutputType], device: torch.device, rt_inspector: RuntimeInspector, + zero_stage3_offload_param_map: Optional[Dict[str, torch.nn.parameter.Parameter]], ): """Creates forward `*inputs` list from user input and PyTorch initializers @@ -254,7 +255,12 @@ def _expand_inputs(current_input, non_none_inputs, name=""): ) # params is a list of all initializers known to the onnx graph - result.extend(params) + if zero_stage3_offload_param_map: + for p in params: + if p not in zero_stage3_offload_param_map.values(): + result.append(p) + else: + result.extend(params) return result, embed_sparsity_results, label_sparsity_results diff --git 
a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 3be4c05797978..19effe2086e0a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -311,6 +311,9 @@ def forward(self, *inputs, **kwargs): self._gradient_accumulation_manager.maybe_update_cache_before_run() + if self._runtime_options.enable_zero_stage3_support: + self._append_pull_weight_trigger_as_input(kwargs, self._device) + prepared_input_list, _, _ = _io._combine_input_buffers_initializers( self._graph_initializers, self._graph_info.user_input_names, @@ -320,6 +323,7 @@ def forward(self, *inputs, **kwargs): kwargs, self._device, self._runtime_inspector, + self._zero_stage3_param_map, ) outputs = unflatten_user_output( diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py new file mode 100644 index 0000000000000..c4ae6b7adb31f --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -0,0 +1,272 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from typing import Dict, List, Optional, Tuple, Union + +import torch +from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper + +from onnxruntime.capi._pybind_state import register_torch_autograd_function + +from ._custom_autograd_function_exporter import PythonOpShapeInferStore +from ._utils import get_fully_qualified_class_name + +STAGE3_PULL_WEIGHT_TRIGGER_NAME = "pull_weight_trigger" + + +def post_processing_enable_zero_stage3_compat( + exported_model: ModelProto, + offload_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]], + param_names: List[str], +) -> ModelProto: + """This function is used to enable zero stage3 compatibility. + + Args: + exported_model (ModelProto): The exported model. + offload_named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The offload named parameters. + param_names (List[str]): All parameter names. + """ + + # Register symbolic shape inference functions for PythonOp used in DeepSpeed ZeRO stage3. + _zero_stage3_register_symbolic_functions() + + # Create weight retrieving function using offload_named_params. + func_full_qual_name = _create_weight_retrieval_function(offload_named_params) + + consumer_map = {} + for node in exported_model.graph.node: + for inp in node.input: + if inp not in consumer_map: + consumer_map[inp] = [] + + if node not in consumer_map[inp]: + consumer_map[inp].append(node) + + def _get_param_pull_trigger_name(param_name: str) -> str: + return f"pull_{param_name}" + + def _get_func_name(node: NodeProto) -> Optional[str]: + for attr in node.attribute: + if attr.name == "func_name": + return attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s + return None + + trigger_data_type = TensorProto.FLOAT + trigger_data_dims = [1] + + # Create weight retrieving PythonOp. 
+    new_input, weight_pull_node = _create_weight_retrieval_pythonop(
+        offload_named_params,
+        func_full_qual_name,
+        STAGE3_PULL_WEIGHT_TRIGGER_NAME,
+        [_get_param_pull_trigger_name(pname) for pname in offload_named_params],
+        trigger_data_type,
+        trigger_data_dims,
+    )
+
+    # Connect weight consumers to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction.
+    for graph_input in exported_model.graph.input:
+        if graph_input.name not in offload_named_params:
+            continue
+
+        if graph_input.name not in consumer_map:
+            continue
+
+        consumers = consumer_map[graph_input.name]
+        pre_forward_pythonop_node = None
+
+        input_tensor_ranks = []
+        input_tensor_dtypes = []
+        rank_attr = None
+        dtype_attr = None
+        for c in consumers:
+            if c.op_type != "PythonOp":
+                continue
+
+            func_name = _get_func_name(c)
+
+            if (
+                func_name
+                == "onnxruntime.training.utils.hooks._zero_offload_subscriber.ORTZeROOffloadPreForwardFunction"
+            ):
+                assert (
+                    pre_forward_pythonop_node is None
+                ), "Multiple ORTZeROOffloadPreForwardFunction nodes found, this should not happen"
+                pre_forward_pythonop_node = c
+                for attr in c.attribute:
+                    if attr.name == "input_tensor_ranks":
+                        input_tensor_ranks = attr.ints
+                        rank_attr = attr
+                    if attr.name == "input_tensor_types":
+                        input_tensor_dtypes = attr.ints
+                        dtype_attr = attr
+
+        if pre_forward_pythonop_node is None:
+            raise RuntimeError(
+                "Failed to find ORTZeROOffloadPreForwardFunction for partitioned param: " + graph_input.name
+            )
+
+        index_offset_on_python_op_input = []
+        for i, input_name in enumerate(pre_forward_pythonop_node.input):
+            if input_name == graph_input.name:
+                index_offset_on_python_op_input.append(i)
+
+        assert (
+            len(index_offset_on_python_op_input) == 1
+        ), f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input}"
+
+        reverse_index_among_inputs = index_offset_on_python_op_input[0] - len(pre_forward_pythonop_node.input)
+        pre_forward_pythonop_node.input[index_offset_on_python_op_input[0]] = _get_param_pull_trigger_name(
+            graph_input.name
+        )
+
+        # For PythonOp, we need to update some of its attributes.
+        input_tensor_ranks[index_offset_on_python_op_input[0]] = len(trigger_data_dims)
+        input_tensor_dtypes[index_offset_on_python_op_input[0]] = trigger_data_type
+        pre_forward_pythonop_node.attribute.remove(rank_attr)
+        pre_forward_pythonop_node.attribute.remove(dtype_attr)
+        pre_forward_pythonop_node.attribute.append(helper.make_attribute("input_tensor_ranks", input_tensor_ranks))
+        pre_forward_pythonop_node.attribute.append(helper.make_attribute("input_tensor_types", input_tensor_dtypes))
+
+        output_index = reverse_index_among_inputs + len(pre_forward_pythonop_node.output)
+        pre_forward_pythonop_node.output[output_index] = graph_input.name
+
+    # Delete exported_model.graph.input
+    graph_inputs_to_remove = [
+        graph_input for graph_input in exported_model.graph.input if graph_input.name in offload_named_params
+    ]
+    for input_to_remove in graph_inputs_to_remove:
+        exported_model.graph.input.remove(input_to_remove)
+
+    # Re-order graph input to make sure the weight pull trigger is before all parameter inputs.
+ offset = 0 + for graph_input in exported_model.graph.input: + if graph_input.name in param_names: + break + offset += 1 + + exported_model.graph.input.insert(offset, new_input) + exported_model.graph.node.insert(0, weight_pull_node) + + return exported_model + + +def _create_weight_retrieval_function(offload_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]]) -> str: + """This function is used to create a weight retrieving function using offload_named_params.""" + + class WeightRetrievalFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, weight_in_trigger): + params = list(offload_named_params.values()) + ctx.params = params + ctx.dtype = weight_in_trigger.dtype + ctx.device = weight_in_trigger.device + ctx.shape = weight_in_trigger.shape + return (torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype),) * len(params) + + @staticmethod + def backward(ctx, *grad_outputs): + return torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype) + + @staticmethod + def infer_shape( + node: NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + param_count = len(offload_named_params.values()) + tensor_output_shapes = [ + tensor_input_shapes[0], + ] * param_count + tensor_output_dtypes = [ + tensor_input_dtypes[0], + ] * param_count + return tensor_output_shapes, tensor_output_dtypes + + func_full_qual_name = get_fully_qualified_class_name(WeightRetrievalFunction) + register_torch_autograd_function(func_full_qual_name, WeightRetrievalFunction) + PythonOpShapeInferStore.register(WeightRetrievalFunction) + + return func_full_qual_name + + +def _zero_stage3_register_symbolic_functions(): + """This function is used to register symbolic shape inference functions for PythonOp used in + DeepSpeed ZeRO stage3.""" + + def _simple_pass_through_infer_shape( + node: NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + return tensor_input_shapes, tensor_input_dtypes + + PythonOpShapeInferStore.register_func( + "deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _simple_pass_through_infer_shape + ) + PythonOpShapeInferStore.register_func( + "deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _simple_pass_through_infer_shape + ) + + def _linear_infer_shape( + node: NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + # output = input.matmul(weight.t()) + tensor_input_shapes[0] # input + shape2 = tensor_input_shapes[1] # weight + output_shape = tensor_input_shapes[0] + output_shape[-1] = shape2[-2] + return [output_shape], [tensor_input_dtypes[0]] + + PythonOpShapeInferStore.register_func( + "deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3", _linear_infer_shape + ) + + +def _create_weight_retrieval_pythonop( + offload_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]], + func_full_qual_name: str, + input_name: str, + output_names: List[str], + trigger_data_type, + trigger_data_dims: List[int], +) -> Tuple[ValueInfoProto, NodeProto]: + """This function is used to create a weight retrieving PythonOp.""" + 
offload_param_count = 0 if offload_named_params is None else len(offload_named_params) + new_input = helper.make_tensor_value_info(input_name, trigger_data_type, trigger_data_dims) + output_rank_for_pull_weight_trigger = len(trigger_data_dims) + output_dtype_for_pull_weight_trigger = trigger_data_type + output_tensor_ranks = [ + output_rank_for_pull_weight_trigger, + ] * offload_param_count + output_tensor_types = [ + output_dtype_for_pull_weight_trigger, + ] * offload_param_count + + node_attributes = { + "comment": "", + "inplace": 0, + "input_convention": "d", + "input_tensor_ranks": [len(trigger_data_dims)], + "input_tensor_types": [trigger_data_type], + "output_tensor_ranks": output_tensor_ranks, + "output_tensor_types": output_tensor_types, + "training_mode": 1, + "func_name": func_full_qual_name, + } + + weight_pull_node = helper.make_node( + "PythonOp", + [input_name], + ["pull_weight_trigger_ctx", *output_names], + "pull_weight_trigger", # node name + "PythonOp for weight retrieving.", + "com.microsoft", + **node_attributes, + ) + + return new_input, weight_pull_node diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc index 4e7fcbc95bb1d..0a3f8139eb020 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc @@ -153,8 +153,10 @@ void PythonOpBase::RunForward(OpKernelContext* context, inplace_ != 0, kernel_invoke_id_); - ORT_ENFORCE(1 + returned_ortvalues.size() == static_cast(context->OutputCount()), - "Output count mismatch for PythonOp run"); + const size_t returned_output_count = 1 + returned_ortvalues.size(); + const size_t kernel_output_count = static_cast(context->OutputCount()); + ORT_ENFORCE(returned_output_count == kernel_output_count, "Output count mismatch for PythonOp run, ", + "returned_output_count: ", returned_output_count, ", expected kernel_output_count: ", kernel_output_count); } void PythonOpBase::SetOutputs(OpKernelContext* context, void* diff_ctx, std::vector& returned_args) const { From 5f5beec6d576653489979d84380828f2f546ad55 Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Fri, 25 Aug 2023 12:40:42 +0000 Subject: [PATCH 2/9] Make PyTorch and ORT both work with ZeRoOffloadSubscriber --- .../ortmodule/_graph_execution_manager.py | 4 +- .../ortmodule/_zero_stage3_compatibility.py | 34 ++++++ .../utils/hooks/_subscriber_manager.py | 4 +- .../utils/hooks/_zero_offload_subscriber.py | 101 ++++++++++++------ 4 files changed, 106 insertions(+), 37 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index f961df5e53b12..5cdea97eea202 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -143,9 +143,9 @@ def __init__( self._zero_stage3_param_map = None if self._runtime_options.enable_zero_stage3_support: # Cannot toggle feature enabling/disabling after the first time enabled. 
- from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_offloaded_params + from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_zero_stage3_params - self._zero_stage3_param_map = _get_all_offloaded_params(self._flattened_module) + self._zero_stage3_param_map = _get_all_zero_stage3_params(self._flattened_module) configure_ort_compatible_zero_stage3() diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index c4ae6b7adb31f..7c6116da439fa 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -9,6 +9,7 @@ from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper from onnxruntime.capi._pybind_state import register_torch_autograd_function +from onnxruntime.training.utils import pytorch_dtype_to_onnx from ._custom_autograd_function_exporter import PythonOpShapeInferStore from ._utils import get_fully_qualified_class_name @@ -133,6 +134,39 @@ def _get_func_name(node: NodeProto) -> Optional[str]: output_index = reverse_index_among_inputs + len(pre_forward_pythonop_node.output) pre_forward_pythonop_node.output[output_index] = graph_input.name + # If the consumer of original `graph_input.name` is PythonOp, we need also update its attributes because now + # `graph_input.name` as output of pre_forward_pythonop_node, is full-sized parameter, the rank might differ + # from the original one. + for c in consumers: + if c == pre_forward_pythonop_node or c.op_type != "PythonOp": + continue + + input_tensor_ranks = [] + input_tensor_dtypes = [] + rank_attr = None + dtype_attr = None + for attr in c.attribute: + if attr.name == "input_tensor_ranks": + input_tensor_ranks = attr.ints + rank_attr = attr + if attr.name == "input_tensor_types": + input_tensor_dtypes = attr.ints + dtype_attr = attr + + index_offset_on_python_op_input = [] + for i, input_name in enumerate(c.input): + if input_name == graph_input.name: + index_offset_on_python_op_input.append(i) + + for index in index_offset_on_python_op_input: + input_tensor_ranks[index] = len(offload_named_params[graph_input.name].ds_shape) + input_tensor_dtypes[index] = pytorch_dtype_to_onnx(offload_named_params[graph_input.name].dtype) + + c.attribute.remove(rank_attr) + c.attribute.remove(dtype_attr) + c.attribute.append(helper.make_attribute("input_tensor_ranks", input_tensor_ranks)) + c.attribute.append(helper.make_attribute("input_tensor_types", input_tensor_dtypes)) + # Delete exported_model.graph.input graph_inputs_to_remove = [ graph_input for graph_input in exported_model.graph.input if graph_input.name in offload_named_params diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py index db38f58d8f324..5814448960091 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py @@ -191,7 +191,7 @@ def _reset_recursively(module: torch.nn.Module, depth: int, next_module_index: L next_module_index: list of int, carrying a global unique module index that can be used next. 
""" module_index = next_module_index[0] - module.id = module_index # STAGE3WARN: needed by DeepSpeed + module.id = module_index # STAGE3WARN#1: needed by DeepSpeed self._run_ctx.global_states.module_index_to_depth[module_index] = depth self._run_ctx.global_states.module_to_module_index[module] = module_index @@ -217,7 +217,7 @@ def _register_hooks_recursively(self, module: torch.nn.Module, depth: int, next_ next_module_index: list of int, carrying a global unique module index that can be used next. """ module_index = next_module_index[0] - module.id = module_index # STAGE3WARN: needed by DeepSpeed + module.id = module_index # STAGE3WARN#2: needed by DeepSpeed self._run_ctx.global_states.module_index_to_depth[module_index] = depth self._run_ctx.global_states.module_to_module_index[module] = module_index diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index 3d42e172eea82..5fb72216f33f0 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -115,13 +115,13 @@ def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.par """ from deepspeed.runtime.zero.partitioned_param_coordinator import iter_params - # Retrive the parameters that are not available for this module. + # Retrieve all parameters for this module. partitioned_params = [param for param in iter_params(module)] return partitioned_params -def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: +def _get_all_zero_stage3_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: """Retrieve all the parameters that are offloaded.""" from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus @@ -134,16 +134,13 @@ def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.par class ORTZeROOffloadPreForwardFunction(torch.autograd.Function): - """This function is a common bridge to call original PyTorch's - pre_forward_function and post_backward_function. - """ + """This function is a common bridge to call original PyTorch's pre_forward_function""" @staticmethod def forward( ctx, module, pre_forward_with_kwargs_function, - post_backward_function, args_schema, kwargs_schema, args_tensor_count, @@ -155,7 +152,6 @@ def forward( ctx: context object module: the module to be called pre_forward_with_kwargs_function: the function to be called before forward (PyTorch's pre_forward_function) - post_backward_function: the function to be called after backward (PyTorch's post_backward_function) args_schema: the schema of the args, used to reconstruct the args in original form in PyTorch's pre_forward_function's inputs. kwargs_schema: the schema of the kwargs, used to reconstruct the kwargs in original form in @@ -168,6 +164,17 @@ def forward( args_tensors = tensor_list[:args_tensor_count] kwargs_tensors = tensor_list[args_tensor_count : args_tensor_count + kwargs_tensor_count] + # For PyTorch runs, the sizes are all 0, it does not need a gradient because + # param._detach().requires_grad_(False) is called. + # But for ORT runs, the sizes are all [1], as output of weight retrieval function. + # So we keep track of the shapes and dtypes of the passed in tensors, then generate the grads in backward. 
+ # While for both PyTorch and ORT runs, the grad is not important because they are not param grads + # any more, they are only used for completing the full backward propagation. + passed_in_param_tensors = tensor_list[args_tensor_count + kwargs_tensor_count :] + ctx.shapes = [p.shape for p in passed_in_param_tensors] + ctx.dtypes = [p.dtype for p in passed_in_param_tensors] + ctx.devices = [p.device for p in passed_in_param_tensors] + args = unflatten_data_using_schema(args_tensors, args_schema) kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema) @@ -179,6 +186,8 @@ def forward( partitioned_params = _get_params_for_current_module(module) ctx.partitioned_params = partitioned_params + assert len(partitioned_params) == len(passed_in_param_tensors) + f_ret = pre_forward_with_kwargs_function(module, args, kwargs) if f_ret is None: @@ -188,7 +197,6 @@ def forward( updated_args, updated_kwargs = f_ret ctx.module = module - ctx.post_backward_function = post_backward_function updated_args_tensors, _ = extract_data_and_schema(updated_args) updated_kwargs_tensors, _ = extract_data_and_schema(updated_kwargs) @@ -203,17 +211,33 @@ def forward( @staticmethod def backward(ctx, *grads): updated_grads = grads - if ctx.post_backward_function is not None: - ret = ctx.post_backward_function(ctx.module, grads) - if ret is not None: - updated_grads = ret - # TODO(pengwa) Update grad for partitioned parameters. input_count = len(updated_grads) - len(ctx.partitioned_params) - zeros = [torch.zeros(0, dtype=p.dtype, device=p.device) for p in ctx.partitioned_params] - zero_grads = updated_grads[:input_count] + tuple(zeros) - - return (None, None, None, None, None, None, None, *zero_grads) + param_start_offset = input_count + + # Only need to accumulate grad explicitly for ORT run (e.g. ctx.shapes[0] == (1,)); + # In the PyTorch run, the accumulation happens automatically. + need_manual_grad_acc = len(ctx.shapes) > 0 and ctx.shapes[0] == (1,) + if need_manual_grad_acc: + for param_index, p in enumerate(ctx.partitioned_params): + g = updated_grads[param_index + param_start_offset] + if g is None: + raise RuntimeError(f"param {p} has no grad, this should not happen.") + # Param gradient accumulation is triggered here, along with the attached hooks, done by PyTorch. + assert p.shape == g.shape, f"param_index: {param_index} - param shape {p.shape} != grad shape {g.shape}" + p.backward(updated_grads[param_index + param_start_offset]) + + # At this point, the **real** param grads are already updated, the following grads are only used for + # completing the full backward propagation, will not affect parameter updates. + passed_in_param_grad = [ + torch.zeros(shape, dtype=dtype, device=device) + for shape, dtype, device in zip(ctx.shapes, ctx.dtypes, ctx.devices) + ] + + # nones = (None, ) * len(ctx.partitioned_params) + zero_grads = updated_grads[:input_count] + tuple(passed_in_param_grad) + + return (None, None, None, None, None, None, *zero_grads) @staticmethod def infer_shape( @@ -258,14 +282,14 @@ def forward( module: the module to be called post_forward_function: the function to be called after forward (PyTorch's post_forward_function) pre_backward_function: the function to be called before backward (PyTorch's pre_backward_function) - output_schema: the schema of the output, used to reconstruct the output in original form in + output_schema: the schema of the output, used to reconstruct the output in its original form in PyTorch's post_forward_function's inputs. output_tensors: the list of tensors. 
""" outputs = unflatten_data_using_schema(output_tensors, output_schema) - # STAGE3WARN: _post_forward_module_hook's second argument `input is not used, so we just pass a None here. + # STAGE3WARN#3: _post_forward_module_hook's second argument `input is not used, so we just pass a None here. updated_outputs = post_forward_function(module, None, outputs) if updated_outputs is None: @@ -341,11 +365,19 @@ def pre_forward_module_apply_impl( input and output for torch.autograd.Function, so we do flatten and unflatten here. """ + ## Handle `_post_backward_module_hook` - args_tensors, args_schema = extract_data_and_schema(args) - kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs) + # Put `_post_backward_module_hook` first because in backward, it is responsible for unloading parameters, + # we want ORTZeROOffloadPreForwardFunction's backward still be able to access the full sized parameters. + _post_backward_module_hook = self._functions.get("_post_backward_module_hook") + # STAGE3WARN#4: most logic in _post_backward_module_hook can be traced correctly so we don't need to + # wrap with PythonOp. For those cannot be traced, we handle them in STAGE3WARN#5. + updated_args = _post_backward_module_hook(module, args) - partitioned_params = _get_params_for_current_module(module) + ## Handle `_pre_forward_module_hook` + + args_tensors, args_schema = extract_data_and_schema(updated_args) + kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs) _pre_forward_module_hook = self._functions.get("_pre_forward_module_hook") @@ -358,18 +390,26 @@ def _wrap_pre_forward_module_hook(module, args, kwargs): if rets is not None: updated_args = rets - # STAGE3WARN: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration. + # STAGE3WARN#5: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration. module.ds_grads_remaining = 0 + return updated_args, updated_kwargs - all_tensors = args_tensors + kwargs_tensors + partitioned_params + # Need to pass the parameters as input to let the exporter trace the related weights for + # current ORTZeROOffloadPreForwardFunction + partitioned_params = _get_params_for_current_module(module) + # Don't require grad for passed-in parameter, otherwise it will be treated as a leaf node, in backward + # returned 0-sized grad did not match the param's gradient accumulator function's input shape metadata, + # PyTorch run will fail during backward. + detached_partitioned_params = [p.detach().requires_grad_(False) for p in partitioned_params] + + all_tensors = args_tensors + kwargs_tensors + detached_partitioned_params self._check_all_tensor(all_tensors, module, "pre_forward_module_apply_impl input check") rets = ORTZeROOffloadPreForwardFunction.apply( module, _wrap_pre_forward_module_hook, - None, args_schema, kwargs_schema, args_tensor_count, @@ -385,11 +425,6 @@ def _wrap_pre_forward_module_hook(module, args, kwargs): updated_args = unflatten_data_using_schema(updated_args_tensors, args_schema) updated_kwargs = unflatten_data_using_schema(updated_kwargs_tensors, kwargs_schema) - _post_backward_module_hook = self._functions.get("_post_backward_module_hook") - # STAGE3WARN: Other part of _post_backward_module_hook can be traced correctly so we don't need to - # wrap with PythonOp. 
-        updated_args = _post_backward_module_hook(module, updated_args)
-
         return updated_args, updated_kwargs
 
     def post_forward_module_apply_impl(
@@ -411,7 +446,7 @@ def post_forward_module_apply_impl(
         _post_forward_module_hook = self._functions.get("_post_forward_module_hook")
 
         def _wrap_post_forward_module_hook(module, input, outputs):
-            # STAGE3WARN: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here.
+            # STAGE3WARN#6: _post_forward_module_hook applies this to each tensor output, so we do a simple wrap here.
            from deepspeed.runtime.zero.partition_parameters import is_zero_param
 
             updated_outputs = _post_forward_module_hook(module, input, outputs)
@@ -438,8 +473,8 @@ def _wrap_post_forward_module_hook(module, input, outputs):
         updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema)
 
         _pre_backward_module_hook = self._functions.get("_pre_backward_module_hook")
-        # STAGE3WARN: _pre_backward_module_hook's second argument `input is not used, so we just pass a None here.
-        # STAGE3WARN: part of the original _pre_backward_module_hook can be traced correctly so we moved them into
+        # STAGE3WARN#7: _pre_backward_module_hook's second argument `input` is not used, so we just pass a None here.
+        # STAGE3WARN#8: parts of the original _pre_backward_module_hook can be traced correctly, so we moved them into
         # _wrap_post_forward_module_hook above.
         updated_outputs = _pre_backward_module_hook(module, None, updated_outputs)
 

From 6aa0b4d59752c6b9a073ee1cd57135dc8d4011a3 Mon Sep 17 00:00:00 2001
From: "pengwa@microsoft.com" 
Date: Fri, 25 Aug 2023 13:23:23 +0000
Subject: [PATCH 3/9] fixes

---
 .../ortmodule/_graph_execution_manager.py     |   2 +-
 .../ortmodule/_zero_stage3_compatibility.py   | 134 +++++++++---------
 2 files changed, 70 insertions(+), 66 deletions(-)

diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index 5cdea97eea202..f1572c8c40b38 100755
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -140,7 +140,7 @@ def __init__(
 
         register_triton_op_executor()
 
-        self._zero_stage3_param_map = None
+        self._zero_stage3_param_map = {}
         if self._runtime_options.enable_zero_stage3_support:
             # Cannot toggle feature enabling/disabling after the first time enabled.
             from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_zero_stage3_params
diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
index 7c6116da439fa..aa37f1fde2186 100644
--- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
+++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
@@ -19,22 +19,22 @@
 def post_processing_enable_zero_stage3_compat(
     exported_model: ModelProto,
-    offload_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]],
-    param_names: List[str],
+    zero_stage3_named_params: Dict[str, torch.nn.parameter.Parameter],
+    all_param_names: List[str],
 ) -> ModelProto:
     """This function is used to enable zero stage3 compatibility.
 
     Args:
         exported_model (ModelProto): The exported model.
-        offload_named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The offload named parameters.
-        param_names (List[str]): All parameter names.
+        zero_stage3_named_params (Dict[str, torch.nn.parameter.Parameter]): The offloaded named parameters.
+        all_param_names (List[str]): All parameter names.
     """
 
     # Register symbolic shape inference functions for PythonOp used in DeepSpeed ZeRO stage3.
     _zero_stage3_register_symbolic_functions()
 
-    # Create weight retrieving function using offload_named_params.
-    func_full_qual_name = _create_weight_retrieval_function(offload_named_params)
+    # Create weight retrieving function using zero_stage3_named_params.
+    func_full_qual_name = _create_weight_retrieval_function(zero_stage3_named_params)
 
     consumer_map = {}
@@ -59,17 +59,17 @@ def _get_func_name(node: NodeProto) -> Optional[str]:
 
     # Create weight retrieving PythonOp.
     new_input, weight_pull_node = _create_weight_retrieval_pythonop(
-        offload_named_params,
+        zero_stage3_named_params,
         func_full_qual_name,
         STAGE3_PULL_WEIGHT_TRIGGER_NAME,
-        [_get_param_pull_trigger_name(pname) for pname in offload_named_params],
+        [_get_param_pull_trigger_name(pname) for pname in zero_stage3_named_params],
         trigger_data_type,
         trigger_data_dims,
     )
 
     # Connect weight consumers to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction.
     for graph_input in exported_model.graph.input:
-        if graph_input.name not in offload_named_params:
+        if graph_input.name not in zero_stage3_named_params:
             continue
 
         if graph_input.name not in consumer_map:
@@ -78,16 +78,11 @@ def _get_func_name(node: NodeProto) -> Optional[str]:
         consumers = consumer_map[graph_input.name]
         pre_forward_pythonop_node = None
 
-        input_tensor_ranks = []
-        input_tensor_dtypes = []
-        rank_attr = None
-        dtype_attr = None
         for c in consumers:
             if c.op_type != "PythonOp":
                 continue
 
             func_name = _get_func_name(c)
-
             if (
                 func_name
                 == "onnxruntime.training.utils.hooks._zero_offload_subscriber.ORTZeROOffloadPreForwardFunction"
             ):
                 assert (
                     pre_forward_pythonop_node is None
                 ), "Multiple ORTZeROOffloadPreForwardFunction nodes found, this should not happen"
                 pre_forward_pythonop_node = c
-                for attr in c.attribute:
-                    if attr.name == "input_tensor_ranks":
-                        input_tensor_ranks = attr.ints
-                        rank_attr = attr
-                    if attr.name == "input_tensor_types":
-                        input_tensor_dtypes = attr.ints
-                        dtype_attr = attr
 
         if pre_forward_pythonop_node is None:
             raise RuntimeError(
                 "Failed to find ORTZeROOffloadPreForwardFunction for partitioned param: " + graph_input.name
             )
@@ -119,18 +107,16 @@ def _get_func_name(node: NodeProto) -> Optional[str]:
 
         reverse_index_among_inputs = index_offset_on_python_op_input[0] - len(pre_forward_pythonop_node.input)
-        pre_forward_pythonop_node.input[index_offset_on_python_op_input[0]] = _get_param_pull_trigger_name(
-            graph_input.name
+        new_input_name = _get_param_pull_trigger_name(graph_input.name)
+        pre_forward_pythonop_node.input[index_offset_on_python_op_input[0]] = new_input_name
+
+        _update_python_op_input_related_attributes(
+            pre_forward_pythonop_node,
+            new_input_name,
+            len(trigger_data_dims),  # new rank
+            trigger_data_type,  # new data type
         )
 
-        # For PythonOp, we need to update some of its attributes.
- input_tensor_ranks[index_offset_on_python_op_input[0]] = len(trigger_data_dims) - input_tensor_dtypes[index_offset_on_python_op_input[0]] = trigger_data_type - pre_forward_pythonop_node.attribute.remove(rank_attr) - pre_forward_pythonop_node.attribute.remove(dtype_attr) - pre_forward_pythonop_node.attribute.append(helper.make_attribute("input_tensor_ranks", input_tensor_ranks)) - pre_forward_pythonop_node.attribute.append(helper.make_attribute("input_tensor_types", input_tensor_dtypes)) - output_index = reverse_index_among_inputs + len(pre_forward_pythonop_node.output) pre_forward_pythonop_node.output[output_index] = graph_input.name @@ -140,36 +126,16 @@ def _get_func_name(node: NodeProto) -> Optional[str]: for c in consumers: if c == pre_forward_pythonop_node or c.op_type != "PythonOp": continue - - input_tensor_ranks = [] - input_tensor_dtypes = [] - rank_attr = None - dtype_attr = None - for attr in c.attribute: - if attr.name == "input_tensor_ranks": - input_tensor_ranks = attr.ints - rank_attr = attr - if attr.name == "input_tensor_types": - input_tensor_dtypes = attr.ints - dtype_attr = attr - - index_offset_on_python_op_input = [] - for i, input_name in enumerate(c.input): - if input_name == graph_input.name: - index_offset_on_python_op_input.append(i) - - for index in index_offset_on_python_op_input: - input_tensor_ranks[index] = len(offload_named_params[graph_input.name].ds_shape) - input_tensor_dtypes[index] = pytorch_dtype_to_onnx(offload_named_params[graph_input.name].dtype) - - c.attribute.remove(rank_attr) - c.attribute.remove(dtype_attr) - c.attribute.append(helper.make_attribute("input_tensor_ranks", input_tensor_ranks)) - c.attribute.append(helper.make_attribute("input_tensor_types", input_tensor_dtypes)) + _update_python_op_input_related_attributes( + c, + graph_input.name, + len(zero_stage3_named_params[graph_input.name].ds_shape), # new rank + pytorch_dtype_to_onnx(zero_stage3_named_params[graph_input.name].dtype), # new data type + ) # Delete exported_model.graph.input graph_inputs_to_remove = [ - graph_input for graph_input in exported_model.graph.input if graph_input.name in offload_named_params + graph_input for graph_input in exported_model.graph.input if graph_input.name in zero_stage3_named_params ] for input_to_remove in graph_inputs_to_remove: exported_model.graph.input.remove(input_to_remove) @@ -177,7 +143,7 @@ def _get_func_name(node: NodeProto) -> Optional[str]: # Re-order graph input to make sure the weight pull trigger is before all parameter inputs. 
offset = 0 for graph_input in exported_model.graph.input: - if graph_input.name in param_names: + if graph_input.name in all_param_names: break offset += 1 @@ -187,13 +153,15 @@ def _get_func_name(node: NodeProto) -> Optional[str]: return exported_model -def _create_weight_retrieval_function(offload_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]]) -> str: - """This function is used to create a weight retrieving function using offload_named_params.""" +def _create_weight_retrieval_function( + zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]] +) -> str: + """This function is used to create a weight retrieving function using zero_stage3_named_params.""" class WeightRetrievalFunction(torch.autograd.Function): @staticmethod def forward(ctx, weight_in_trigger): - params = list(offload_named_params.values()) + params = list(zero_stage3_named_params.values()) ctx.params = params ctx.dtype = weight_in_trigger.dtype ctx.device = weight_in_trigger.device @@ -210,7 +178,7 @@ def infer_shape( tensor_input_shapes: List[Optional[List[Union[int, str]]]], tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: - param_count = len(offload_named_params.values()) + param_count = len(zero_stage3_named_params.values()) tensor_output_shapes = [ tensor_input_shapes[0], ] * param_count @@ -262,7 +230,7 @@ def _linear_infer_shape( def _create_weight_retrieval_pythonop( - offload_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]], + zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]], func_full_qual_name: str, input_name: str, output_names: List[str], @@ -270,7 +238,7 @@ def _create_weight_retrieval_pythonop( trigger_data_dims: List[int], ) -> Tuple[ValueInfoProto, NodeProto]: """This function is used to create a weight retrieving PythonOp.""" - offload_param_count = 0 if offload_named_params is None else len(offload_named_params) + offload_param_count = 0 if zero_stage3_named_params is None else len(zero_stage3_named_params) new_input = helper.make_tensor_value_info(input_name, trigger_data_type, trigger_data_dims) output_rank_for_pull_weight_trigger = len(trigger_data_dims) output_dtype_for_pull_weight_trigger = trigger_data_type @@ -304,3 +272,39 @@ def _create_weight_retrieval_pythonop( ) return new_input, weight_pull_node + + +def _update_python_op_input_related_attributes(node: NodeProto, input_name: str, new_rank: int, new_dtype: int): + """This function is used to update PythonOp's input related attributes, e.g. + input_tensor_ranks and input_tensor_types. + + Args: + node (NodeProto): The PythonOp node. + input_name (str): The input name to be updated. + new_rank (int): The new rank of the input, to be used in input_tensor_ranks. + new_dtype (int): The new data type of the input, to be used in input_tensor_types. 
+ """ + input_tensor_ranks = None + input_tensor_dtypes = None + rank_attr = None + dtype_attr = None + for attr in node.attribute: + if attr.name == "input_tensor_ranks": + input_tensor_ranks = attr.ints + rank_attr = attr + if attr.name == "input_tensor_types": + input_tensor_dtypes = attr.ints + dtype_attr = attr + + assert input_tensor_ranks is not None, "input_tensor_ranks is None" + assert input_tensor_dtypes is not None, "input_tensor_dtypes is None" + + for index, node_input_name in enumerate(node.input): + if node_input_name == input_name: + input_tensor_ranks[index] = new_rank + input_tensor_dtypes[index] = new_dtype + + node.attribute.remove(rank_attr) + node.attribute.remove(dtype_attr) + node.attribute.append(helper.make_attribute("input_tensor_ranks", input_tensor_ranks)) + node.attribute.append(helper.make_attribute("input_tensor_types", input_tensor_dtypes)) From 08e8302123bdaac582a6dd50e673c23d342059de Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Fri, 25 Aug 2023 13:30:09 +0000 Subject: [PATCH 4/9] refinement --- .../ortmodule/_graph_execution_manager.py | 10 ++++-- .../ortmodule/_zero_stage3_compatibility.py | 33 ++++++++++--------- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index f1572c8c40b38..0a341164a9642 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -618,9 +618,15 @@ def _enable_conditional_optimizations( self._runtime_inspector.enable_memory_inspector(self._original_module) def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device): - from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME + from ._zero_stage3_compatibility import ( + STAGE3_PULL_WEIGHT_TRIGGER_NAME, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + ) - kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros([1], dtype=torch.float32, device=device).requires_grad_() + kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros( + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, dtype=STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, device=device + ).requires_grad_() return kwargs diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index aa37f1fde2186..3fc3e2cf8f4fc 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -15,6 +15,8 @@ from ._utils import get_fully_qualified_class_name STAGE3_PULL_WEIGHT_TRIGGER_NAME = "pull_weight_trigger" +STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE = TensorProto.FLOAT +STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE = [1] def post_processing_enable_zero_stage3_compat( @@ -31,7 +33,7 @@ def post_processing_enable_zero_stage3_compat( """ # Register symbolic shape inference functions for PythonOp used in DeepSpeed ZeRO stage3. - _zero_stage3_register_symbolic_functions() + _register_symbolic_shape_infer_functions() # Create weight retrieving function using zero_stage3_named_params. 
func_full_qual_name = _create_weight_retrieval_function(zero_stage3_named_params) @@ -54,17 +56,14 @@ def _get_func_name(node: NodeProto) -> Optional[str]: return attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s return None - trigger_data_type = TensorProto.FLOAT - trigger_data_dims = [1] - # Create weight retrieving PythonOp. new_input, weight_pull_node = _create_weight_retrieval_pythonop( zero_stage3_named_params, func_full_qual_name, STAGE3_PULL_WEIGHT_TRIGGER_NAME, [_get_param_pull_trigger_name(pname) for pname in zero_stage3_named_params], - trigger_data_type, - trigger_data_dims, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, ) # Connect weight consumers to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction. @@ -113,8 +112,8 @@ def _get_func_name(node: NodeProto) -> Optional[str]: _update_python_op_input_related_attributes( pre_forward_pythonop_node, new_input_name, - len(trigger_data_dims), # new rank - trigger_data_type, # new data type + len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE), # new rank + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, # new data type ) output_index = reverse_index_among_inputs + len(pre_forward_pythonop_node.output) @@ -194,7 +193,7 @@ def infer_shape( return func_full_qual_name -def _zero_stage3_register_symbolic_functions(): +def _register_symbolic_shape_infer_functions(): """This function is used to register symbolic shape inference functions for PythonOp used in DeepSpeed ZeRO stage3.""" @@ -234,14 +233,16 @@ def _create_weight_retrieval_pythonop( func_full_qual_name: str, input_name: str, output_names: List[str], - trigger_data_type, - trigger_data_dims: List[int], + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE: List[int], ) -> Tuple[ValueInfoProto, NodeProto]: """This function is used to create a weight retrieving PythonOp.""" offload_param_count = 0 if zero_stage3_named_params is None else len(zero_stage3_named_params) - new_input = helper.make_tensor_value_info(input_name, trigger_data_type, trigger_data_dims) - output_rank_for_pull_weight_trigger = len(trigger_data_dims) - output_dtype_for_pull_weight_trigger = trigger_data_type + new_input = helper.make_tensor_value_info( + input_name, STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE + ) + output_rank_for_pull_weight_trigger = len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE) + output_dtype_for_pull_weight_trigger = STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE output_tensor_ranks = [ output_rank_for_pull_weight_trigger, ] * offload_param_count @@ -253,8 +254,8 @@ def _create_weight_retrieval_pythonop( "comment": "", "inplace": 0, "input_convention": "d", - "input_tensor_ranks": [len(trigger_data_dims)], - "input_tensor_types": [trigger_data_type], + "input_tensor_ranks": [len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE)], + "input_tensor_types": [STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE], "output_tensor_ranks": output_tensor_ranks, "output_tensor_types": output_tensor_types, "training_mode": 1, From a25856d0b42b5825ee35a3a8bb5bc099e6667126 Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Fri, 25 Aug 2023 13:44:07 +0000 Subject: [PATCH 5/9] minors --- .../training/utils/hooks/_zero_offload_subscriber.py | 10 ++++++---- .../cpu/torch/torch_custom_function_kernel_base.cc | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py 
b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
index 5fb72216f33f0..3689037738475 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
@@ -167,9 +167,9 @@ def forward(
         # For PyTorch runs, the sizes are all 0, it does not need a gradient because
         # param._detach().requires_grad_(False) is called.
         # But for ORT runs, the sizes are all [1], as output of weight retrieval function.
-        # So we keep track of the shapes and dtypes of the passed in tensors, then generate the grads in backward.
+        # So we keep track of the shapes and dtypes of the passed-in tensors, then generate the grads in backward.
         # While for both PyTorch and ORT runs, the grad is not important because they are not param grads
-        # any more, they are only used for completing the full backward propagation.
+        # anymore, they are only used for completing the full backward propagation.
         passed_in_param_tensors = tensor_list[args_tensor_count + kwargs_tensor_count :]
         ctx.shapes = [p.shape for p in passed_in_param_tensors]
         ctx.dtypes = [p.dtype for p in passed_in_param_tensors]
@@ -225,7 +225,7 @@ def backward(ctx, *grads):
                     raise RuntimeError(f"param {p} has no grad, this should not happen.")
                 # Param gradient accumulation is triggered here, along with the attached hooks, done by PyTorch.
                 assert p.shape == g.shape, f"param_index: {param_index} - param shape {p.shape} != grad shape {g.shape}"
-                p.backward(updated_grads[param_index + param_start_offset])
+                p.backward(g)
 
         # At this point, the **real** param grads are already updated, the following grads are only used for
         # completing the full backward propagation, will not affect parameter updates.
@@ -234,7 +234,6 @@ def backward(ctx, *grads):
             for shape, dtype, device in zip(ctx.shapes, ctx.dtypes, ctx.devices)
         ]
 
-        # nones = (None, ) * len(ctx.partitioned_params)
         zero_grads = updated_grads[:input_count] + tuple(passed_in_param_grad)
 
         return (None, None, None, None, None, None, *zero_grads)
@@ -401,6 +400,9 @@ def _wrap_pre_forward_module_hook(module, args, kwargs):
         # Don't require grad for the passed-in parameters; otherwise each is treated as a leaf node, and in backward
         # the returned 0-sized grad would not match the param's gradient accumulator function's input shape metadata,
         # so the PyTorch run would fail during backward.
+        # This will not harm parameter gradient building in either ORT or PyTorch: the weights are used by the
+        # computation anyway, so their gradients will still be built. This hook only references the parameters; it
+        # won't generate a gradient path for them.
        detached_partitioned_params = [p.detach().requires_grad_(False) for p in partitioned_params]
 
         all_tensors = args_tensors + kwargs_tensors + detached_partitioned_params
 
diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc
index 0a3f8139eb020..e1d4be24861f5 100644
--- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc
+++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc
@@ -156,7 +156,8 @@ void PythonOpBase::RunForward(OpKernelContext* context,
   const size_t returned_output_count = 1 + returned_ortvalues.size();
   const size_t kernel_output_count = static_cast<size_t>(context->OutputCount());
   ORT_ENFORCE(returned_output_count == kernel_output_count, "Output count mismatch for PythonOp run, ",
-              "returned_output_count: ", returned_output_count, ", expected kernel_output_count: ", kernel_output_count);
+              "returned_output_count: ", returned_output_count, ", expected kernel_output_count: ",
+              kernel_output_count);
 }
 
 void PythonOpBase::SetOutputs(OpKernelContext* context, void* diff_ctx, std::vector<OrtValue>& returned_args) const {

From 5e8330a396869562ff79409a65d9534979372f1f Mon Sep 17 00:00:00 2001
From: "pengwa@microsoft.com"
Date: Wed, 30 Aug 2023 08:28:02 +0000
Subject: [PATCH 6/9] add convergence debug switches

---
 .../_custom_autograd_function_exporter.py   |   6 +-
 .../ortmodule/_graph_execution_manager.py   |   8 +-
 .../ortmodule/_zero_stage3_compatibility.py |   9 +-
 .../python/training/utils/__init__.py       |   3 +-
 .../utils/hooks/_statistics_subscriber.py   | 171 ++++++++++--------
 .../utils/hooks/_subscriber_manager.py      |   8 -
 .../utils/hooks/_zero_offload_subscriber.py |  52 ++++--
 .../python/training/utils/torch_type_map.py |   9 +
 8 files changed, 148 insertions(+), 118 deletions(-)

diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
index 31bf1c60d0515..f75d553a5f460 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
@@ -234,9 +234,9 @@ def _export_pt_1_10(g, n, *args, **kwargs):
                 input_float_tuples.extend(list(arg))
                 continue
 
-        is_inspect_activation = (
-            func_full_qual_name == "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation"
-        )
+        from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation
+
+        is_inspect_activation = func_full_qual_name == get_fully_qualified_class_name(_InspectActivation)
         # _InspectActivation is a special case where the first argument is a string
         # that is used to determine the activation name to be inspected.
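The hunk above swaps a hard-coded class-path string for get_fully_qualified_class_name(_InspectActivation), so the exporter's special-case check keeps working after _InspectActivation moved from _subscriber_manager to _statistics_subscriber. A minimal sketch of how such a helper can behave follows; the body of onnxruntime's get_fully_qualified_class_name is not shown in this patch series, so the helper below is an assumption, not the library's implementation.

import torch


def fully_qualified_class_name(klass) -> str:
    # Assumed behavior: join the defining module path with the qualified class
    # name, e.g.
    # "onnxruntime.training.utils.hooks._statistics_subscriber._InspectActivation".
    return f"{klass.__module__}.{klass.__qualname__}"


class _Demo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


# Deriving the name from the class object tracks refactors automatically;
# a hard-coded string would silently keep pointing at the old module path.
print(fully_qualified_class_name(_Demo))  # "__main__._Demo" when run as a script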
diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index 0a341164a9642..6555d64833158 100755
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -19,7 +19,7 @@
 import onnxruntime
 from onnxruntime.capi import _pybind_state as C
 from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
-from onnxruntime.training.utils import ORTModelInputOutputSchemaType
+from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch
 from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
 
 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
@@ -147,7 +147,7 @@ def __init__(
 
             self._zero_stage3_param_map = _get_all_zero_stage3_params(self._flattened_module)
 
-            configure_ort_compatible_zero_stage3()
+            configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True)
 
     def _get_torch_gpu_allocator_function_addresses(self):
         if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available():
@@ -625,7 +625,9 @@ def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.devic
         )
 
         kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros(
-            STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, dtype=STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, device=device
+            STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
+            dtype=onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE),
+            device=device,
         ).requires_grad_()
 
         return kwargs
diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
index 3fc3e2cf8f4fc..17756600d601e 100644
--- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
+++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
@@ -66,6 +66,10 @@ def _get_func_name(node: NodeProto) -> Optional[str]:
         STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
     )
 
+    from onnxruntime.training.utils.hooks._zero_offload_subscriber import ORTZeROOffloadPreForwardFunction
+
+    preforward_function_name = get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction)
+
     # Connect weight consumers to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction.
    for graph_input in exported_model.graph.input:
         if graph_input.name not in zero_stage3_named_params:
@@ -82,10 +86,7 @@ def _get_func_name(node: NodeProto) -> Optional[str]:
                 continue
 
             func_name = _get_func_name(c)
-            if (
-                func_name
-                == "onnxruntime.training.utils.hooks._zero_offload_subscriber.ORTZeROOffloadPreForwardFunction"
-            ):
+            if func_name == preforward_function_name:
                 assert (
                     pre_forward_pythonop_node is None
                 ), "Multiple ORTZeROOffloadPreForwardFunction nodes found, it should not happen"
diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py
index acf2698d55eaf..fa7c9f2750cdd 100644
--- a/orttraining/orttraining/python/training/utils/__init__.py
+++ b/orttraining/orttraining/python/training/utils/__init__.py
@@ -9,7 +9,7 @@
     extract_data_and_schema,
     unflatten_data_using_schema,
 )
-from onnxruntime.training.utils.torch_type_map import pytorch_dtype_to_onnx
+from onnxruntime.training.utils.torch_type_map import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx
 
 __all__ = [
     "PrimitiveType",
@@ -18,4 +18,5 @@
     "extract_data_and_schema",
     "unflatten_data_using_schema",
     "pytorch_dtype_to_onnx",
+    "onnx_dtype_to_pytorch",
 ]
diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py
index 6c8027b2fefaa..db1c69cf95ba4 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py
@@ -6,6 +6,7 @@
 import os
 import shutil
 import warnings
+from io import TextIOWrapper
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
@@ -178,87 +179,97 @@ def _summarize_activations(self, tensor: torch.Tensor, depth: int, name: str, st
         order_file_path = step_path / "order.txt"
         tensor_file_path = step_path / output_file_name
 
-        # This is to try the best effort to align the count of numbers per line for easier comparison in diff views,
-        # though it does not always guarantee to do this way.
-        torch.set_printoptions(precision=6, linewidth=128)
-
-        tensor_shape = tensor.shape
-        tensor_dtype = tensor.dtype
-        flatten_array = tensor.flatten().view(-1)
-
-        if self._run_on_cpu:
-            flatten_array = flatten_array.to("cpu")
-
-        if self._run_on_cpu:
-            num_nan = torch.isnan(flatten_array).sum()
-            num_inf = torch.isinf(flatten_array).sum()
-            num_neg = (flatten_array < 0).sum()
-            num_pos = (flatten_array > 0).sum()
-            num_zero = (flatten_array == 0).sum()
-            min_value = flatten_array.min()
-            max_value = flatten_array.max()
-            mean_value = flatten_array.mean()
-            std_value = flatten_array.std()
-        else:
-            # Split the calculation for each bucket, then do another round of calculation on the bucket results.
-            # This can at the best effort reduce the peak memory impact.
-            bucket_size = self._bucket_size
-            element_count = flatten_array.numel()
-            ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size)
-            nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
-            inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
-            neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
-            pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
-            zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
-            min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
-            max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
-            mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
-            std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
-
-            # Summary for each bucket
-            element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
-            for i in range(ceil_bucket_count):
-                end = min((i + 1) * bucket_size, element_count)
-                bucket = flatten_array[i * bucket_size : end]
-                element_count_per_bucket[i] = bucket.numel()
-
-                nan_buckets[i] = torch.isnan(bucket).sum()
-                inf_buckets[i] = torch.isinf(bucket).sum()
-                neg_buckets[i] = (bucket < 0).sum()
-                pos_buckets[i] = (bucket > 0).sum()
-                zero_buckets[i] = (bucket == 0).sum()
-                min_buckets[i] = bucket.min()
-                max_buckets[i] = bucket.max()
-                mean_buckets[i] = bucket.sum()
-                std_buckets[i] = bucket.std()
-
-            # Reduction across all buckets
-            num_nan = nan_buckets.sum()
-            num_inf = inf_buckets.sum()
-            num_neg = neg_buckets.sum()
-            num_pos = pos_buckets.sum()
-            num_zero = zero_buckets.sum()
-            min_value = min_buckets.min()
-            max_value = max_buckets.max()
-            mean_value = float(mean_buckets.sum()) / float(element_count)
-            # Here we refer to
-            # https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
-            # to calculate the combined standard deviation of all buckets.
-            s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * (
-                (mean_buckets - mean_value) ** 2
-            )
-            std_value = torch.sqrt(s.sum() / (element_count - 1))
-
         with order_file_path.open(mode="a", encoding="utf-8") as f:
             f.write(f"{output_file_name}\n")
 
         with tensor_file_path.open(mode="w", encoding="utf-8") as f:
-            f.write(
-                f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n"
-                f"min: {min_value} max: {max_value}, mean: {mean_value}, "
-                f"std: {std_value} \n"
-                f"nan: {num_nan}, inf: {num_inf}\n"
-            )
-            f.write(f"samples(top 128): {flatten_array[:128]}\n")
-            f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n")
-            f.write(f"{'='*16}\n")
+            _summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
+
+
+def _summarize_tensor(
+    display_name: str,
+    tensor: torch.Tensor,
+    f: TextIOWrapper,
+    depth: int = 0,
+    run_on_cpu: bool = False,
+    bucket_size: int = 1024 * 1024 * 1024 // 2,
+):
+    # Try to keep the number of values printed per line consistent, for easier comparison in diff views;
+    # this is best effort and is not always guaranteed.
+    torch.set_printoptions(precision=6, linewidth=128)
+
+    tensor_shape = tensor.shape
+    tensor_dtype = tensor.dtype
+    flatten_array = tensor.flatten().view(-1)
+
+    if run_on_cpu:
+        flatten_array = flatten_array.to("cpu")
+
+    if run_on_cpu:
+        num_nan = torch.isnan(flatten_array).sum()
+        num_inf = torch.isinf(flatten_array).sum()
+        num_neg = (flatten_array < 0).sum()
+        num_pos = (flatten_array > 0).sum()
+        num_zero = (flatten_array == 0).sum()
+        min_value = flatten_array.min()
+        max_value = flatten_array.max()
+        mean_value = flatten_array.mean()
+        std_value = flatten_array.std()
+    else:
+        # Compute per-bucket statistics first, then combine the bucket results in a second pass.
+        # This best-effort approach reduces the peak memory impact.
+        element_count = flatten_array.numel()
+        ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size)
+        nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
+        inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
+        neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
+        pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
+        zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
+        min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
+        max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
+        mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
+        std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
+
+        # Summary for each bucket
+        element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
+        for i in range(ceil_bucket_count):
+            end = min((i + 1) * bucket_size, element_count)
+            bucket = flatten_array[i * bucket_size : end]
+            element_count_per_bucket[i] = bucket.numel()
+
+            nan_buckets[i] = torch.isnan(bucket).sum()
+            inf_buckets[i] = torch.isinf(bucket).sum()
+            neg_buckets[i] = (bucket < 0).sum()
+            pos_buckets[i] = (bucket > 0).sum()
+            zero_buckets[i] = (bucket == 0).sum()
+            min_buckets[i] = bucket.min()
+            max_buckets[i] = bucket.max()
+            mean_buckets[i] = bucket.mean()  # Per-bucket mean; needed below for the combined std.
+            std_buckets[i] = bucket.std()
+
+        # Reduction across all buckets
+        num_nan = nan_buckets.sum()
+        num_inf = inf_buckets.sum()
+        num_neg = neg_buckets.sum()
+        num_pos = pos_buckets.sum()
+        num_zero = zero_buckets.sum()
+        min_value = min_buckets.min()
+        max_value = max_buckets.max()
+        mean_value = float((mean_buckets * element_count_per_bucket).sum()) / float(element_count)  # Weighted, as the last bucket may be smaller.
+        # Here we refer to
+        # https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
+        # to calculate the combined standard deviation of all buckets.
+        s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * (
+            (mean_buckets - mean_value) ** 2
+        )
+        std_value = torch.sqrt(s.sum() / (element_count - 1))
+
+    f.write(
+        f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n"
+        f"min: {min_value} max: {max_value}, mean: {mean_value}, "
+        f"std: {std_value} \n"
+        f"nan: {num_nan}, inf: {num_inf}\n"
+    )
+    f.write(f"samples(top 128): {flatten_array[:128]}\n")
+    f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n")
+    f.write(f"{'='*16}\n")
diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py
index 5814448960091..6c851035dd554 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py
@@ -29,14 +29,6 @@ def no_increase_global_step():
         finally:
             ORT_NO_INCREASE_GLOBAL_STEP[0] = False
 
-    @staticmethod
-    def infer_shape(
-        node: onnx.NodeProto,
-        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
-        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
-    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
-        return tensor_input_shapes, tensor_input_dtypes
-
 
 class _IncrementStep(torch.autograd.Function):
     """This class is used to manage the global execution step, e.g.
diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
index 3689037738475..ad1297962db71 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
@@ -23,25 +23,37 @@
 from ._subscriber_base import RuntimeStates, SubscriberBase
 
-# Used to monkey patch the original function
-# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333
-def _setup_zero_stage3_ort_compatible_hooks(self):
-    self.hierarchy = 0
+def _get_ort_compatible_zero_stage3_hook_function(debug, stats_output_dir, stats_overwrite):
+    """Create an ORT-compatible hook function for DeepSpeed ZeRO stage3.
 
-    from onnxruntime.training.utils.hooks import SubscriberManager, ZeROOffloadSubscriber
-    from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer
+    Args:
+        debug: whether to enable convergence debugging.
+        stats_output_dir: the directory to store convergence stats.
+        stats_overwrite: whether to overwrite the stats file if it already exists.
+    """
+
+    # Used to monkey patch the original function
+    # Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333
+    def _setup_zero_stage3_ort_compatible_hooks(self):
+        self.hierarchy = 0
+
+        from onnxruntime.training.utils.hooks import StatisticsSubscriber, SubscriberManager, ZeROOffloadSubscriber
+        from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer
 
-    # Each DeepSpeed engine has a separate subscriber manager.
-    self._offload_subscriber_manager = SubscriberManager()
-    self._offload_subscriber_manager.subscribe(
-        self.module, [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)]
-    )
-    self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks)
-    self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks)
+        subscribers = [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)]
+        if debug:
+            subscribers.append(StatisticsSubscriber(output_dir=stats_output_dir, override_output_dir=stats_overwrite))
+        # Each DeepSpeed engine has a separate subscriber manager.
+        self._offload_subscriber_manager = SubscriberManager()
+        self._offload_subscriber_manager.subscribe(self.module, subscribers)
+        self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks)
+        self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks)
 
-    # Add top module to stack trace
-    global FWD_MODULE_STACK  # noqa: PLW0602
-    FWD_MODULE_STACK.append(self.module)
+        # Add top module to stack trace
+        global FWD_MODULE_STACK  # noqa: PLW0602
+        FWD_MODULE_STACK.append(self.module)
+
+    return _setup_zero_stage3_ort_compatible_hooks
 
 
 # Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/linear.py#L104
@@ -86,14 +98,16 @@ def collect_code(self, function: Callable):
     _zero_offload_one_time_initializer.collect_code(DeepSpeedZeRoOffload.setup_zero_stage3_hooks)
 
     # This is the function to enable ORT ZeRO offload.
-    def configure_ort_compatible_zero_stage3():
+    def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", stats_overwrite=False):
         """Configure ZeRO stage3 to be ORT compatible.
 
         This function will overwrite the original DeepSpeed ZeRO stage3 hooks to make it ORT compatible.
         """
 
         # Only done once no matter how many times this function is called for different modules.
-        DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _setup_zero_stage3_ort_compatible_hooks
+        DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _get_ort_compatible_zero_stage3_hook_function(
+            debug, stats_output_dir, stats_overwrite
+        )
 
         from deepspeed.runtime.zero.linear import zero3_linear_wrap
 
@@ -103,7 +117,7 @@ def configure_ort_compatible_zero_stage3():
 except ImportError as e:
     warnings.warn(f"DeepSpeed import error {e}")
 
-    def configure_ort_compatible_zero_stage3():
+    def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir=None, stats_overwrite=False):
         raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.")
 
diff --git a/orttraining/orttraining/python/training/utils/torch_type_map.py b/orttraining/orttraining/python/training/utils/torch_type_map.py
index 699747723f457..bdacab8ad04fe 100644
--- a/orttraining/orttraining/python/training/utils/torch_type_map.py
+++ b/orttraining/orttraining/python/training/utils/torch_type_map.py
@@ -33,6 +33,8 @@
 
 _DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for k, (onnx_dtype, torch_dtype) in _CAST_PYTORCH_TO_ONNX.items()}
 
+_ONNX_TO_DTYPE = {onnx_dtype: torch_dtype for torch_dtype, onnx_dtype in _DTYPE_TO_ONNX.items()}
+
 
 def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType:
     """Converts a pytorch dtype or scalar type string to an onnx dtype."""
@@ -45,3 +47,10 @@ def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torc
     if dtype not in _DTYPE_TO_ONNX:
         raise RuntimeError(f"Unsupported dtype {dtype}")
     return _DTYPE_TO_ONNX[dtype]
+
+
+def onnx_dtype_to_pytorch(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype:
+    """Converts an onnx dtype to a pytorch dtype."""
+    if dtype not in _ONNX_TO_DTYPE:
+        raise RuntimeError(f"Unsupported dtype {dtype}")
+    return _ONNX_TO_DTYPE[dtype]

From 57300d97202de5e3da67bd86398e7baf03854170 Mon Sep 17 00:00:00 2001
From: "pengwa@microsoft.com"
Date: Wed, 30 Aug 2023 14:57:46 +0000
Subject: [PATCH 7/9] comment debug log

---
 .../python/training/utils/hooks/_subscriber_manager.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py
index 6c851035dd554..b2bc64be42fc1 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py
@@ -47,8 +47,9 @@ def forward(ctx, run_ctx: RuntimeStates, *input_tensor_list: Tuple[torch.Tensor,
         ctx.current_step = run_ctx.global_states.execution_step
         ctx.run_ctx = run_ctx
 
-        if ctx.current_step >= 0:
-            print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}")
+        # Uncomment the following lines for debugging purposes.
+        # if ctx.current_step >= 0:
+        #     print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}")
 
         if ORT_NO_INCREASE_GLOBAL_STEP[0] is False:
             ctx.run_ctx.global_states.execution_step += 1

From f7941b2782207113ba9f21684efd38987dadb86f Mon Sep 17 00:00:00 2001
From: "pengwa@microsoft.com"
Date: Mon, 18 Sep 2023 06:48:19 +0000
Subject: [PATCH 8/9] fix bug

---
 .../ortmodule/_custom_autograd_function_runner.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py
index 845c7d83c2e7b..a5b96c4e37140 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py
@@ -376,6 +376,16 @@ def wrap_all_outputs(result):
         result = backward_function(*wrapped_args)
 
         # Extract results as DLPack tensor list.
+        if isinstance(result, torch.Tensor):
+            result = [result]
+        elif isinstance(result, (tuple, list)):
+            result = list(result)
+        else:
+            raise wrap_exception(
+                ORTModuleIOError,
+                TypeError(f"ORTModule does not support the following model output type {type(result)}."),
+            )
+
         wrapped_returned_args = wrap_all_outputs(result)
 
         torch_interop_utils.unregister_grad_fn(id(ctx))

From dd021960d8eee5fc1f10b194f06caa1ae47f0fe1 Mon Sep 17 00:00:00 2001
From: "pengwa@microsoft.com"
Date: Tue, 19 Sep 2023 05:39:00 +0000
Subject: [PATCH 9/9] refine one comment

---
 .../python/training/ortmodule/_graph_execution_manager.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index 6555d64833158..dfaac5f0fa836 100755
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -428,8 +428,9 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu
                 [name for name, _ in self._flattened_module.named_parameters()],
             )
 
-            # Cannot append pull weight trigger name to input names here, otherwise, the later check find
-            # input info mismatch, will re-initialize the graph builder.
+            # Cannot append the pull weight trigger name to the input names as below; otherwise the later check (
+            # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18)
+            # finds an input info mismatch and re-initializes the graph builder.
            # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)

            # Cache model for future runs
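Taken together, the torch_type_map.py additions from patch 6 form a two-way mapping between torch dtypes and ONNX TensorProto data types, which is what _append_pull_weight_trigger_as_input relies on when it materializes the pull-weight trigger tensor. A small usage sketch follows, under the assumption that the trigger shape is [1] (matching the trigger_data_dims value the stage3 constants replaced); only the two public helpers shown in the diff above are used.

import torch

from onnxruntime.training.utils import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx

# Round-trip a dtype through the two helpers exported in patch 6.
onnx_dtype = pytorch_dtype_to_onnx(torch.float32)
assert onnx_dtype_to_pytorch(onnx_dtype) == torch.float32

# Mirrors how the pull weight trigger tensor is built: the dtype constant is
# stored as an ONNX data type and must be mapped back to a torch.dtype before
# torch.zeros can consume it. The [1] shape is an assumption (see above).
trigger = torch.zeros([1], dtype=onnx_dtype_to_pytorch(onnx_dtype)).requires_grad_()
print(trigger.dtype, trigger.shape, trigger.requires_grad)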