Model post process for zero stage3 training #17187

Merged 10 commits on Sep 22, 2023
orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py

@@ -28,7 +28,8 @@ class PythonOpShapeInferStore:

     @classmethod
     def register(cls, kclass: torch.autograd.Function) -> None:
-        """Register a shape inference function for a torch.autograd.Function if there is staticmethod "infer_shape" defined.
+        """Register a shape inference function for a torch.autograd.Function if there is staticmethod
+        "infer_shape" defined.
         The signature of the shape inference function should be:
             @staticmethod
@@ -51,6 +52,11 @@ def infer_shape(
         if hasattr(kclass, "infer_shape") and kclass_name not in cls._CLASS_MAP:
             cls._CLASS_MAP[kclass_name] = kclass.infer_shape

+    @classmethod
+    def register_func(cls, name: str, func: Callable) -> None:
+        """Register a shape inference function for a torch.autograd.Function by name."""
+        cls._CLASS_MAP[name] = func
+
     @classmethod
     def get_shape_infer(cls, name: str) -> Optional[Callable]:
         return cls._CLASS_MAP.get(name, None)
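With register_func, a shape-inference hook can now be attached by op name alone, without defining a staticmethod on the autograd.Function. A minimal sketch, assuming the infer_shape signature documented in this class and an assumed import path for the store; the op name and function are illustrative, not part of the PR:

    from onnxruntime.training.ortmodule._custom_autograd_function_exporter import (  # path assumed
        PythonOpShapeInferStore,
    )

    def _identity_infer_shape(node, tensor_input_shapes, tensor_input_dtypes):
        # Elementwise ops keep shapes and dtypes unchanged, so echo the inputs back.
        return tensor_input_shapes, tensor_input_dtypes

    # Keyed by the function's fully qualified name rather than the class object.
    PythonOpShapeInferStore.register_func("my_pkg.ops.MyIdentityFunction", _identity_infer_shape)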
@@ -228,9 +234,9 @@ def _export_pt_1_10(g, n, *args, **kwargs):
                 input_float_tuples.extend(list(arg))
                 continue

-            is_inspect_activation = (
-                func_full_qual_name == "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation"
-            )
+            from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation
+
+            is_inspect_activation = func_full_qual_name == get_fully_qualified_class_name(_InspectActivation)
             if is_inspect_activation and isinstance(arg, str):
                 # _InspectActivation is a special case where the first argument is a string
                 # that is used to determine the activation name to be inspected.
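Replacing the hard-coded qualified name with get_fully_qualified_class_name(_InspectActivation) keeps the check valid now that the class lives in _statistics_subscriber rather than _subscriber_manager. A sketch of what such a helper presumably computes (an assumption about the real onnxruntime.training.utils helper, not its verified source):

    def get_fully_qualified_class_name(klass: type) -> str:
        # "module.path.ClassName", e.g.
        # "onnxruntime.training.utils.hooks._statistics_subscriber._InspectActivation"
        return f"{klass.__module__}.{klass.__qualname__}"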
@@ -307,14 +313,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
 _export = wrap_custom_export_function(_export_pt_1_10)


-def _post_process_after_export(exported_model: ModelProto, enable_custom_autograd_function: bool) -> ModelProto:
-    """Post process the exported model."""
-    if enable_custom_autograd_function:
-        exported_model = _post_process_enabling_autograd_function(exported_model)
-    return exported_model
-
-
-def _post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto:
+def post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto:
     # Loop all PythonOp, append "_ctx" as the first output.
     index = 0
     for node in exported_model.graph.node:
@@ -330,8 +329,7 @@ def _post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto:
                 op_name_prefix = kclass_name
                 break

-        if not node.name:
-            node.name = f"{op_name_prefix}_id_{index}"
-            index += 1
+        node.name = f"{op_name_prefix}_id_{index}"
+        index += 1

     return exported_model
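The last hunk changes naming from conditional to unconditional: every PythonOp node now receives a deterministic index-based name, instead of only nodes that happened to be unnamed. A simplified, self-contained illustration of the new behavior (the real code derives the prefix from the node's registered autograd class name):

    import onnx

    def name_python_ops(model: onnx.ModelProto) -> onnx.ModelProto:
        index = 0
        for node in model.graph.node:
            if node.op_type != "PythonOp":
                continue
            # Overwrite any pre-existing name; the old code skipped named nodes.
            node.name = f"PythonOp_id_{index}"
            index += 1
        return model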
======================================================================

@@ -376,6 +376,16 @@ def wrap_all_outputs(result):
         result = backward_function(*wrapped_args)

         # Extract results as DLPack tensor list.
+        if isinstance(result, torch.Tensor):
+            result = [result]
+        elif isinstance(result, (tuple, list)):
+            result = list(result)
+        else:
+            raise wrap_exception(
+                ORTModuleIOError,
+                TypeError(f"ORTModule does not support the following model output type {type(result)}."),
+            )
+
         wrapped_returned_args = wrap_all_outputs(result)

         torch_interop_utils.unregister_grad_fn(id(ctx))
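The added guard normalizes whatever the user's backward returns into a plain list before the DLPack wrapping step, and rejects anything else with a typed ORTModule error. The same normalization as a standalone sketch (generic names, without the PR's wrap_exception plumbing):

    import torch

    def normalize_backward_outputs(result):
        if isinstance(result, torch.Tensor):
            return [result]            # a single gradient tensor
        if isinstance(result, (tuple, list)):
            return list(result)        # a tuple/list of gradients (entries may be None)
        raise TypeError(f"Unsupported backward output type: {type(result)}")

    # normalize_backward_outputs(torch.ones(2))          -> [tensor([1., 1.])]
    # normalize_backward_outputs((None, torch.ones(2)))  -> [None, tensor([1., 1.])]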
======================================================================

@@ -19,11 +19,10 @@
 import onnxruntime
 from onnxruntime.capi import _pybind_state as C
 from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
-from onnxruntime.training.utils import ORTModelInputOutputSchemaType
+from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch
 from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3

 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
-from ._custom_autograd_function_exporter import _post_process_after_export
 from ._fallback import (
     ORTModuleDeviceException,
     ORTModuleONNXModelException,
@@ -141,9 +140,14 @@ def __init__(

         register_triton_op_executor()

+        self._zero_stage3_param_map = {}
         if self._runtime_options.enable_zero_stage3_support:
             # Cannot toggle feature enabling/disabling after the first time enabled.
-            configure_ort_compatible_zero_stage3()
+            from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_zero_stage3_params
+
+            self._zero_stage3_param_map = _get_all_zero_stage3_params(self._flattened_module)
+
+            configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True)

     def _get_torch_gpu_allocator_function_addresses(self):
         if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available():
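The parameter map is captured once at construction. A hedged sketch of what _get_all_zero_stage3_params plausibly produces: full parameter names mapped to the parameters DeepSpeed ZeRO stage 3 has partitioned (the ds_id attribute test is an assumption about DeepSpeed's tagging, not this PR's code):

    from typing import Dict

    import torch

    def get_all_zero_stage3_params_sketch(module: torch.nn.Module) -> Dict[str, torch.nn.Parameter]:
        # DeepSpeed ZeRO-3 attaches bookkeeping attributes (e.g. ds_id) to every
        # parameter it partitions; ordinary parameters lack them.
        return {name: p for name, p in module.named_parameters() if hasattr(p, "ds_id")}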
@@ -345,7 +349,8 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inputs
             )
             if os.path.exists(cache_dir) and os.path.isfile(filename):
                 self._logger.info(
-                    f"Cached model detected! Cached model will be used to save export and initialization time. If you want the model to be re-exported then DELETE {filename}."
+                    f"Cached model detected! Cached model will be used to save export and initialization time."
+                    f"If you want the model to be re-exported then DELETE {filename}."
                 )
                 exported_model = onnx.load(filename)
                 return exported_model
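A small note on the reflowed message: adjacent string literals concatenate with no separator, so as written the log reads "...time.If you want...". Keeping a trailing space in the first fragment avoids that:

    msg = (
        "Cached model detected! Cached model will be used to save export and initialization time. "  # trailing space
        "If you want the model to be re-exported then DELETE model.onnx."
    )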
@@ -409,9 +414,24 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inputs
             )
             exported_model = onnx.load_model_from_string(f.getvalue())

-            exported_model = _post_process_after_export(
-                exported_model, self._runtime_options.enable_custom_autograd_function
-            )
+            if self._runtime_options.enable_custom_autograd_function:
+                from ._custom_autograd_function_exporter import post_process_enabling_autograd_function
+
+                exported_model = post_process_enabling_autograd_function(exported_model)
+
+            if self._runtime_options.enable_zero_stage3_support:
+                from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat
+
+                exported_model = post_processing_enable_zero_stage3_compat(
+                    exported_model,
+                    self._zero_stage3_param_map,
+                    [name for name, _ in self._flattened_module.named_parameters()],
+                )
+
+            # Cannot append the pull weight trigger name to the input names here, otherwise the later check (
+            # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18)
+            # finds an input info mismatch and re-initializes the graph builder.
+            # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)

             # Cache model for future runs
             if cache_dir:
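Net effect: the single _post_process_after_export entry point is gone, and each feature applies its own pass right after export. Condensed sketch of the resulting flow, using the names from the hunks above (import paths assumed from the package layout):

    from onnxruntime.training.ortmodule._custom_autograd_function_exporter import (  # path assumed
        post_process_enabling_autograd_function,
    )
    from onnxruntime.training.ortmodule._zero_stage3_compatibility import (  # path assumed
        post_processing_enable_zero_stage3_compat,
    )

    def post_process(exported_model, options, zero_stage3_param_map, param_names):
        if options.enable_custom_autograd_function:
            exported_model = post_process_enabling_autograd_function(exported_model)
        if options.enable_zero_stage3_support:
            exported_model = post_processing_enable_zero_stage3_compat(
                exported_model, zero_stage3_param_map, param_names
            )
        return exported_model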
@@ -477,7 +497,14 @@ def _initialize_graph_builder(self):
         grad_builder_config = C.OrtModuleGraphBuilderConfiguration()
         grad_builder_config.initializer_names = initializer_names
         grad_builder_config.initializer_names_to_train = initializer_names_to_train
-        grad_builder_config.input_names_require_grad = self._input_info.require_grad_names
+
+        input_names_require_grad = self._input_info.require_grad_names
+        if self._runtime_options.enable_zero_stage3_support:
+            from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME
+
+            # Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph.
+            input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)
+        grad_builder_config.input_names_require_grad = input_names_require_grad
         grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING
         grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization
         grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(
@@ -553,6 +580,9 @@ def _enable_conditional_optimizations(
                 inputs, kwargs
             )

+            if self._runtime_options.enable_zero_stage3_support:
+                self._append_pull_weight_trigger_as_input(kwargs, detected_device)
+
             _, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers(
                 self._graph_initializers,
                 self._graph_builder.get_graph_info().user_input_names,
@@ -562,6 +592,7 @@ def _enable_conditional_optimizations(
                 kwargs,
                 detected_device,
                 self._runtime_inspector,
+                self._zero_stage3_param_map,
             )

         # Enable sparsity-based optimization when applicable.
@@ -587,6 +618,21 @@ def _enable_conditional_optimizations(
         if self._runtime_options.print_memory_stat:
             self._runtime_inspector.enable_memory_inspector(self._original_module)

+    def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device):
+        from ._zero_stage3_compatibility import (
+            STAGE3_PULL_WEIGHT_TRIGGER_NAME,
+            STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,
+            STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
+        )
+
+        kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros(
+            STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
+            dtype=onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE),
+            device=device,
+        ).requires_grad_()
+
+        return kwargs
+
     def _log_feature_stats(self):
         if get_rank() != 0:
             return
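The trigger itself is just a tiny tensor that requires grad: injecting it as an extra named input gives the gradient graph an edge through which ZeRO stage-3 weights can be pulled before they are needed. Illustrated with assumed constant values (the real name, shape, and ONNX dtype come from _zero_stage3_compatibility):

    import torch

    STAGE3_PULL_WEIGHT_TRIGGER_NAME = "pull_weight_trigger"  # assumed value
    trigger = torch.zeros([1], dtype=torch.float32).requires_grad_()  # assumed shape/dtype

    kwargs = {STAGE3_PULL_WEIGHT_TRIGGER_NAME: trigger}  # flows in as one more graph input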
======================================================================

@@ -159,6 +159,9 @@ def forward(self, *inputs, **kwargs):
         # Assert that the input and model device match
         _utils._check_same_device(self._device, "Input argument to forward", *inputs)

+        if self._runtime_options.enable_zero_stage3_support:
+            self._append_pull_weight_trigger_as_input(kwargs, self._device)
+
         prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
             self._graph_initializers,
             self._graph_info.user_input_names,
@@ -168,6 +171,7 @@ def forward(self, *inputs, **kwargs):
             kwargs,
             self._device,
             self._runtime_inspector,
+            self._zero_stage3_param_map,
         )

         user_outputs, _ = InferenceManager.execution_session_run_forward(
======================================================================

orttraining/orttraining/python/training/ortmodule/_io.py (7 additions, 1 deletion)

@@ -168,6 +168,7 @@ def _combine_input_buffers_initializers(
     kwargs: Mapping[str, ORTModelInputOutputType],
     device: torch.device,
     rt_inspector: RuntimeInspector,
+    zero_stage3_offload_param_map: Optional[Dict[str, torch.nn.parameter.Parameter]],
 ):
     """Creates forward `*inputs` list from user input and PyTorch initializers
@@ -254,7 +255,12 @@ def _expand_inputs(current_input, non_none_inputs, name=""):
     )

     # params is a list of all initializers known to the onnx graph
-    result.extend(params)
+    if zero_stage3_offload_param_map:
+        for p in params:
+            if p not in zero_stage3_offload_param_map.values():
+                result.append(p)
+    else:
+        result.extend(params)

     return result, embed_sparsity_results, label_sparsity_results
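The new branch keeps ZeRO-3 managed parameters out of the ordinary input list, since they now arrive through the pull-weight-trigger path instead. One design note: p not in map.values() relies on Python's identity fast path for the matching element and can fall back to tensor __eq__ for the rest; an explicit identity set expresses the same filter without invoking tensor comparison at all (a sketch, not the PR's code):

    from typing import Dict, List

    import torch

    def extend_with_non_offloaded(
        result: List[torch.Tensor],
        params: List[torch.nn.Parameter],
        offload_param_map: Dict[str, torch.nn.Parameter],
    ) -> None:
        # Membership by object identity only; never calls __eq__ on tensors.
        offloaded_ids = {id(p) for p in offload_param_map.values()}
        result.extend(p for p in params if id(p) not in offloaded_ids)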

======================================================================

@@ -311,6 +311,9 @@ def forward(self, *inputs, **kwargs):

         self._gradient_accumulation_manager.maybe_update_cache_before_run()

+        if self._runtime_options.enable_zero_stage3_support:
+            self._append_pull_weight_trigger_as_input(kwargs, self._device)
+
         prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
             self._graph_initializers,
             self._graph_info.user_input_names,
@@ -320,6 +323,7 @@ def forward(self, *inputs, **kwargs):
             kwargs,
             self._device,
             self._runtime_inspector,
+            self._zero_stage3_param_map,
         )

         outputs = unflatten_user_output(