### Skip module clone for preparing large model export (#18663)

When running LLAMA2 13B with LoRA and DeepSpeed stage 2 on 8 GPUs, the run
failed while preparing the outputs that will be used for torch.onnx.export.
The reason: we deep-copy all of the parameters, including both the large
frozen weights and the small set of LoRA trainable weights.

This PR first checks whether GPU memory is sufficient for a cloned module;
if not, the copy is skipped.

Copying the module guards against the forward run changing the weights,
though that case should be rare. For now, Not-Able-To-Run is worse than
Runnable-With-Slightly-Different-Initial-Weights, especially for large
models.
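To make the memory math concrete, here is a minimal sketch of the heuristic this change introduces (the 13B parameter count and fp16 element size are illustrative; the 2x-size check and 10% buffer come from the diff below):

```python
def can_clone(param_count: int, bytes_per_param: int, total_memory_bytes: int) -> bool:
    # Mirrors the new check: clone only if twice the parameter size
    # fits within 90% of the device's total memory (10% buffer).
    param_bytes = param_count * bytes_per_param
    return param_bytes * 2 < total_memory_bytes * 0.90

GiB = 1024**3
llama2_13b = 13_000_000_000  # illustrative parameter count

print(can_clone(llama2_13b, 2, 40 * GiB))  # fp16 on a 40 GiB GPU -> False, skip the copy
print(can_clone(llama2_13b, 2, 80 * GiB))  # fp16 on an 80 GiB GPU -> True, safe to clone
```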
pengwa authored Dec 5, 2023
1 parent 9aa7284 commit 4bfa844
Showing 4 changed files with 76 additions and 6 deletions.
11 changes: 11 additions & 0 deletions docs/ORTModule_Training_Guidelines.md
@@ -278,6 +278,17 @@ data sparsity based performance optimizations.
export ORTMODULE_USE_EFFICIENT_ATTENTION=1
```

#### ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT

- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is enabled. This env var can be used to enable or disable the module deep copy that happens while preparing the output data used by the ONNX export.
A typical reason to disable the deep copy: when the deep copy before module export causes a memory peak, disable it and try again (see the usage sketch below).

```bash
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=1 # Enable
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0 # Disable
```
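For example, a minimal sketch of the programmatic equivalent (the toy model is illustrative; ORTModule reads its options when the model is wrapped, so the variable must be set first):

```python
import os

# Must be set before the model is wrapped, since ORTModule reads
# its options at construction time.
os.environ["ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT"] = "0"

import torch
from onnxruntime.training.ortmodule import ORTModule

model = ORTModule(torch.nn.Linear(8, 4).cuda())   # stand-in for a large model
output = model(torch.randn(2, 8, device="cuda"))  # export skips the module clone
```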

### 2.2 Memory Optimization

Q: *Want to run a bigger batch size?*
20 changes: 19 additions & 1 deletion orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -327,12 +327,30 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu

    # Setup dynamic axes for onnx model
    self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs)
    need_deep_copy = self._runtime_options.deepcopy_before_model_export and _io.can_module_be_deep_cloned(
        self._original_module, self._device
    )
    if not need_deep_copy:
        if self._runtime_options.deepcopy_before_model_export:
            # The user asked for a deep copy, but GPU memory is insufficient.
            self._logger.warning(
                "Due to the limited GPU memory, the execution manager does not create a deep copy of this model. "
                "Therefore, the initial weights might be slightly altered during the forward run. "
                "This could result in a minor discrepancy between the ORTModule and the PyTorch run for the "
                "first iteration. The computation will continue as usual, but this should be noted."
            )
        else:
            # The user explicitly disabled the deep copy via the env var.
            self._logger.warning(
                "Since the user requested not to deep copy this model, "
                "the initial weights may not be preserved and could change slightly during the forward run. "
                "This could cause a minor difference between the ORTModule and the PyTorch run for the "
                "first iteration. The computation will proceed as normal, but this should be noted."
            )
    (
        output_names,
        output_dynamic_axes,
        self._module_output_schema,
    ) = _io.parse_outputs_for_onnx_export_and_extract_schema(
        self._original_module, inputs, kwargs, self._logger, self._device, need_deep_copy
    )
    self._input_info.dynamic_axes.update(output_dynamic_axes)

46 changes: 41 additions & 5 deletions orttraining/orttraining/python/training/ortmodule/_io.py
@@ -543,25 +543,61 @@ def _add_input(name, input_value, onnx_graph, onnx_graph_input_names):
)


def calculate_total_parameter_size_in_bytes(module: torch.nn.Module) -> int:
    """Calculate the total parameter size in bytes."""
    total_size = 0
    for p in module.parameters():
        total_size += p.numel() * p.element_size()
    return total_size


def can_module_be_deep_cloned(module: torch.nn.Module, device: Optional[torch.device]) -> bool:
    """Check whether the module can be deep-cloned.

    If two times the total module parameter size >= device memory, the module cannot be cloned:
    > Initially there is one set of parameters;
    > parse_outputs_for_onnx_export_and_extract_schema wants to clone the full module, including the frozen weights;
    > the PyTorch ONNX exporter will clone the trainable parameters.
    So as long as the module can be cloned in parse_outputs_for_onnx_export_and_extract_schema, it is safe
    to export the model without OOM. Here we return whether the module can be cloned in
    parse_outputs_for_onnx_export_and_extract_schema.

    Args:
        module: The module to be cloned.
        device: The device to be used for cloning.
    """
    if device is None or device.type != "cuda":
        return True

    total_size = calculate_total_parameter_size_in_bytes(module)
    return total_size * 2 < torch.cuda.get_device_properties(device).total_memory * 0.90  # give a 10% buffer
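
For intuition, a hedged usage sketch (toy module; inside ORTModule the wrapped module and its device are passed instead):

```python
import torch

# Non-CUDA devices skip the memory check entirely and always allow cloning.
assert can_module_be_deep_cloned(torch.nn.Linear(1024, 1024), device=None)

# On CUDA, the answer depends on 2x the parameter bytes vs. 90% of device memory:
# can_module_be_deep_cloned(large_model, torch.device("cuda:0"))
```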


def parse_outputs_for_onnx_export_and_extract_schema(
    module,
    args: Sequence[ORTModelInputOutputType],
    kwargs: Mapping[str, ORTModelInputOutputType],
    logger: Logger,
    device: Optional[torch.device],
    clone_module: bool,
):
    # Perform a forward call to grab outputs
    output_names = None
    output_dynamic_axes = None
    deep_copied = False
    logger.info("Running model forward to infer output schema and dynamic axes...")
    with torch.no_grad():
        # Deepcopy inputs, since input values may change after model run.
        sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*args, **kwargs)
        try:
            if clone_module:
                # Deepcopy model, in case model is stateful and changes after model run.
                model_copy = copy.deepcopy(module)
                deep_copied = True
            else:
                model_copy = module
        except Exception:
            model_copy = module
            logger.warning(
@@ -576,7 +612,7 @@ def parse_outputs_for_onnx_export_and_extract_schema(
        output_names, output_dynamic_axes = _parse_outputs_and_extract_names_and_dynamic_axes(sample_outputs)

    output_schema = _extract_schema(sample_outputs, device)
    if deep_copied:
        del model_copy
        gc.collect()
        if torch.cuda.is_available():
5 changes: 5 additions & 0 deletions orttraining/orttraining/python/training/ortmodule/options.py
@@ -286,6 +286,8 @@ def __init__(self, logger: Logger):
    # Experimental features.
    self.enable_zero_stage3_support = False  # Once enabled, cannot be disabled.

    self.deepcopy_before_model_export = True

    # Override the feature config if it exists in os env.
    self._override_from_env_vars()

@@ -367,3 +369,6 @@ def _override_from_env_vars(self):
    # Experimental features.
    if "ORTMODULE_ENABLE_ZERO_STAGE3" in os.environ and int(os.getenv("ORTMODULE_ENABLE_ZERO_STAGE3")) == 1:
        self.enable_zero_stage3_support = True

    if "ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT" in os.environ:
        self.deepcopy_before_model_export = int(os.getenv("ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT")) == 1
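
Note the parsing semantics of the override above: only the exact integer value 1 keeps the deep copy enabled. A short sketch (the values are illustrative):

```python
# Equivalent of the override above for various (hypothetical) values:
int("1") == 1  # True  -> deep copy stays enabled
int("0") == 1  # False -> deep copy disabled
int("2") == 1  # False -> any integer other than 1 also disables it
# A non-integer value such as "true" would raise ValueError in int().
```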
