From 4bfa84487cc6fe992b18d69ccd5f0d54392b64f5 Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 6 Dec 2023 04:41:17 +0800 Subject: [PATCH] Skip module clone for preparing large model export (#18663) ### Skip module clone for preparing large model export For LLAMA2 13B, when running with Lora, DeepSpeed stage2 on 8 GPUs . It failed during preparing outputs which will be used for torch.onnx.export. The reason, we deep copy all the params including both big sizes of frozen weights, + a little bit of Lora trainable weight. This PR will firstly check whether the GPU memmory is enough for a cloned module, if not, skip the copy. Copying the module is to guarantee the fw path run may change the weight, while this case should be rare. But for now, Not-Able-To-Run is worse than Runnable-with-A-little-bit-different-initial-weight, especially for large models. --- docs/ORTModule_Training_Guidelines.md | 11 +++++ .../ortmodule/_graph_execution_manager.py | 20 +++++++- .../python/training/ortmodule/_io.py | 46 +++++++++++++++++-- .../python/training/ortmodule/options.py | 5 ++ 4 files changed, 76 insertions(+), 6 deletions(-) diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index d3ec61e86779b..a3cceb441a2a9 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -278,6 +278,17 @@ data sparsity based performance optimizations. export ORTMODULE_USE_EFFICIENT_ATTENTION=1 ``` +#### ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT + +- **Feature Area**: *ORTMODULE/Optimizations* +- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the module deep copy when preparing output data which will be used by ONNX export. +A classical usage of disabling the deep copy: when the deep copy before module export bring the memory peak, then we should disable it and have a try. + + ```bash + export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=1 # Enable + export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0 # Disable + ``` + ### 2.2 Memory Optimization Q: *Want to run a bigger batch size?* diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 5696bfead7b51..dd6d5a568cb18 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -327,12 +327,30 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu # Setup dynamic axes for onnx model self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs) + need_deep_copy = self._runtime_options.deepcopy_before_model_export and _io.can_module_be_deep_cloned( + self._original_module, self._device + ) + if not need_deep_copy: + if self._runtime_options.deepcopy_before_model_export: + self._logger.warning( + "Since the user requested not to deep copy this model, " + "the initial weights may not be preserved and could change slightly during the forward run. " + "This could cause a minor difference between the ORTModule and the PyTorch run for the " + "first iteration. The computation will proceed as normal, but this should be noted." + ) + else: + self._logger.warning( + "Due to the limited GPU memory execution manager does not create a deep copy of this model. " + "Therefore, the initial weights might be slightly altered during the forward run. " + "This could result in a minor discrepancy between the ORTModule and the PyTorch run for the " + "first iteration. The computation will continue as usual, but this should be noted." + ) ( output_names, output_dynamic_axes, self._module_output_schema, ) = _io.parse_outputs_for_onnx_export_and_extract_schema( - self._original_module, inputs, kwargs, self._logger, self._device + self._original_module, inputs, kwargs, self._logger, self._device, need_deep_copy ) self._input_info.dynamic_axes.update(output_dynamic_axes) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index f5fbd5093fca3..7534cc46a21f1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -543,25 +543,61 @@ def _add_input(name, input_value, onnx_graph, onnx_graph_input_names): ) +def calculate_total_parameter_size_in_bytes(module: torch.nn.Module) -> int: + """Calculate the total parameter size in bytes""" + total_size = 0 + for p in module.parameters(): + total_size += p.numel() * p.element_size() + return total_size + + +def can_module_be_deep_cloned(module: torch.nn.Module, device: Optional[torch.device]) -> bool: + """Check if the module can be cloned + + If the 2 times total module parameter size >= device memory, the module cannot be cloned. + > Initially there is one set of parameters; + > parse_outputs_for_onnx_export_and_extract_schema want to clone the full module including the frozen weight; + > PyTorch ONNX exporter will clone the trainable parameters; + + So as long as the module can be cloned in parse_outputs_for_onnx_export_and_extract_schema, it is safe + to export the model without OOM. Here we return whether can clone the module in + parse_outputs_for_onnx_export_and_extract_schema. + + Args: + module: The module to be cloned. + device: The device to be used for cloning. + """ + + if device is None or device.type != "cuda": + return True + + total_size = calculate_total_parameter_size_in_bytes(module) + return total_size * 2 < torch.cuda.get_device_properties(device).total_memory * 0.90 # give a 10% buffer + + def parse_outputs_for_onnx_export_and_extract_schema( module, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType], logger: Logger, device: Optional[torch.device], + clone_module: bool, ): # Perform a forward call to grab outputs output_names = None output_dynamic_axes = None - is_deepcopy = False + deep_copied = False logger.info("Running model forward to infer output schema and dynamic axes...") with torch.no_grad(): # Deepcopy inputs, since input values may change after model run. sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*args, **kwargs) try: - # Deepcopy model, in case model is stateful and changes after model run. - model_copy = copy.deepcopy(module) - is_deepcopy = True + if clone_module: + # Deepcopy model, in case model is stateful and changes after model run. + model_copy = copy.deepcopy(module) + deep_copied = True + else: + model_copy = module except Exception: model_copy = module logger.warning( @@ -576,7 +612,7 @@ def parse_outputs_for_onnx_export_and_extract_schema( output_names, output_dynamic_axes = _parse_outputs_and_extract_names_and_dynamic_axes(sample_outputs) output_schema = _extract_schema(sample_outputs, device) - if is_deepcopy: + if deep_copied: del model_copy gc.collect() if torch.cuda.is_available(): diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 77022f86d3ff3..ffa3f4afa7b30 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -286,6 +286,8 @@ def __init__(self, logger: Logger): # Experimental features. self.enable_zero_stage3_support = False # Once enabled, cannot be disabled. + self.deepcopy_before_model_export = True + # Override the feature config if it exists in os env. self._override_from_env_vars() @@ -367,3 +369,6 @@ def _override_from_env_vars(self): # Experimental features. if "ORTMODULE_ENABLE_ZERO_STAGE3" in os.environ and int(os.getenv("ORTMODULE_ENABLE_ZERO_STAGE3")) == 1: self.enable_zero_stage3_support = True + + if "ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT" in os.environ: + self.deepcopy_before_model_export = int(os.getenv("ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT")) == 1