Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip module clone for preparing large model export #18663

Merged
merged 6 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/ORTModule_Training_Guidelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,17 @@ data sparsity based performance optimizations.
export ORTMODULE_USE_EFFICIENT_ATTENTION=1
```

#### ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT

- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the module deep copy when preparing output data which will be used by ONNX export.
A typical reason to disable the deep copy: if the deep copy performed before module export causes a memory peak, disable it and retry the run.

```bash
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=1 # Enable
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0 # Disable
```

### 2.2 Memory Optimization

Q: *Want to run a bigger batch size?*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,12 +327,30 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu

# Setup dynamic axes for onnx model
self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs)
need_deep_copy = self._runtime_options.deepcopy_before_model_export and _io.can_module_be_deep_cloned(
self._original_module, self._device
)
if not need_deep_copy:
if self._runtime_options.deepcopy_before_model_export:
self._logger.warning(
"Since the user requested not to deep copy this model, "
"the initial weights may not be preserved and could change slightly during the forward run. "
"This could cause a minor difference between the ORTModule and the PyTorch run for the "
"first iteration. The computation will proceed as normal, but this should be noted."
)
else:
self._logger.warning(
"Due to the limited GPU memory execution manager does not create a deep copy of this model. "
"Therefore, the initial weights might be slightly altered during the forward run. "
"This could result in a minor discrepancy between the ORTModule and the PyTorch run for the "
"first iteration. The computation will continue as usual, but this should be noted."
)
(
output_names,
output_dynamic_axes,
self._module_output_schema,
) = _io.parse_outputs_for_onnx_export_and_extract_schema(
self._original_module, inputs, kwargs, self._logger, self._device
self._original_module, inputs, kwargs, self._logger, self._device, need_deep_copy
)
self._input_info.dynamic_axes.update(output_dynamic_axes)

Expand Down
46 changes: 41 additions & 5 deletions orttraining/orttraining/python/training/ortmodule/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,25 +543,61 @@ def _add_input(name, input_value, onnx_graph, onnx_graph_input_names):
)


def calculate_total_parameter_size_in_bytes(module: torch.nn.Module) -> int:
    """Return the combined size of all parameters of `module`, in bytes.

    Each parameter contributes `numel() * element_size()` bytes; the result is
    the sum over every parameter yielded by `module.parameters()`.
    """
    return sum(param.numel() * param.element_size() for param in module.parameters())


def can_module_be_deep_cloned(module: torch.nn.Module, device: Optional[torch.device]) -> bool:
    """Check whether `module` can safely be deep-cloned on `device`.

    If twice the total parameter size would not fit in (90% of) the device's
    total memory, the module must not be cloned:
    > Initially there is one set of parameters;
    > parse_outputs_for_onnx_export_and_extract_schema wants to clone the full
      module, including the frozen weights;
    > The PyTorch ONNX exporter will clone the trainable parameters.

    So as long as the module can be cloned in
    parse_outputs_for_onnx_export_and_extract_schema, it is safe to export the
    model without OOM. This function returns whether that clone can be made.

    Args:
        module: The module to be cloned.
        device: The device to be used for cloning. Non-CUDA devices (or None)
            are always considered safe to clone on.
    """
    if device is not None and device.type == "cuda":
        # Two full copies of the parameters must fit; keep a 10% buffer free.
        required_bytes = 2 * calculate_total_parameter_size_in_bytes(module)
        memory_budget = torch.cuda.get_device_properties(device).total_memory * 0.90
        return required_bytes < memory_budget

    return True


def parse_outputs_for_onnx_export_and_extract_schema(
module,
args: Sequence[ORTModelInputOutputType],
kwargs: Mapping[str, ORTModelInputOutputType],
logger: Logger,
device: Optional[torch.device],
clone_module: bool,
):
# Perform a forward call to grab outputs
output_names = None
output_dynamic_axes = None
is_deepcopy = False
deep_copied = False
logger.info("Running model forward to infer output schema and dynamic axes...")
with torch.no_grad():
# Deepcopy inputs, since input values may change after model run.
sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*args, **kwargs)
pengwa marked this conversation as resolved.
Show resolved Hide resolved
try:
# Deepcopy model, in case model is stateful and changes after model run.
model_copy = copy.deepcopy(module)
is_deepcopy = True
if clone_module:
# Deepcopy model, in case model is stateful and changes after model run.
model_copy = copy.deepcopy(module)
deep_copied = True
else:
model_copy = module
except Exception:
model_copy = module
logger.warning(
Expand All @@ -576,7 +612,7 @@ def parse_outputs_for_onnx_export_and_extract_schema(
output_names, output_dynamic_axes = _parse_outputs_and_extract_names_and_dynamic_axes(sample_outputs)

output_schema = _extract_schema(sample_outputs, device)
if is_deepcopy:
if deep_copied:
del model_copy
gc.collect()
if torch.cuda.is_available():
Expand Down
5 changes: 5 additions & 0 deletions orttraining/orttraining/python/training/ortmodule/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ def __init__(self, logger: Logger):
# Experimental features.
self.enable_zero_stage3_support = False # Once enabled, cannot be disabled.

self.deepcopy_before_model_export = True

# Override the feature config if it exists in os env.
self._override_from_env_vars()

Expand Down Expand Up @@ -367,3 +369,6 @@ def _override_from_env_vars(self):
# Experimental features.
if "ORTMODULE_ENABLE_ZERO_STAGE3" in os.environ and int(os.getenv("ORTMODULE_ENABLE_ZERO_STAGE3")) == 1:
self.enable_zero_stage3_support = True

if "ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT" in os.environ:
self.deepcopy_before_model_export = int(os.getenv("ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT")) == 1
Loading