From 856da38aff47c85952168bebc604fbb9e4c94154 Mon Sep 17 00:00:00 2001
From: Peng Wang
Date: Wed, 20 Dec 2023 19:29:56 -0800
Subject: [PATCH] refinement

---
 .../_custom_autograd_function_exporter.py     |  27 +++-
 .../training/ortmodule/_runtime_inspector.py  |  12 +-
 .../training/ortmodule/_training_manager.py   |   2 +-
 .../ortmodule/_zero_stage3_compatibility.py   |   3 -
 .../python/training/utils/hooks/__init__.py   |   2 +
 .../utils/hooks/_mem_statistics_subscriber.py | 134 ++++++++++++++++++
 .../utils/hooks/_statistics_subscriber.py     |   6 +-
 .../training/utils/hooks/_subscriber_base.py  |  84 +++++++----
 .../utils/hooks/_subscriber_manager.py        |  42 +++++-
 .../utils/hooks/_zero_offload_subscriber.py   |   8 +-
 .../training/utils/torch_profile_utils.py     |  47 +++++-
 11 files changed, 317 insertions(+), 50 deletions(-)
 create mode 100644 orttraining/orttraining/python/training/utils/hooks/_mem_statistics_subscriber.py

diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
index f10416a9bb0f4..58f3e64bc6757 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
@@ -10,7 +10,7 @@
 import torch
 import torch.utils.checkpoint
-from onnx import ModelProto
+from onnx import ModelProto, helper
 from packaging import version
 from torch.onnx import symbolic_helper
 
@@ -393,6 +393,31 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model
             node.name = f"{op_name_prefix}_id_{index}"
             index += 1
 
+    from onnxruntime.training.utils.hooks._mem_statistics_subscriber import _InspectMemoryUsage
+    from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation
+    from onnxruntime.training.utils.hooks._subscriber_manager import _IncrementStep
+
+    _allowed_unsafe_run_python_op_names = [
+        get_fully_qualified_class_name(_InspectMemoryUsage),
+        get_fully_qualified_class_name(_IncrementStep),
+        get_fully_qualified_class_name(_InspectActivation),
+    ]
+
+    for node in exported_model.graph.node:
+        if node.op_type == "PythonOp":
+            func_name = None
+            safe_run_mode_attr = None
+            for attr in node.attribute:
+                if attr.name == "func_name":
+                    func_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
+                if attr.name == "safe_run_mode":
+                    safe_run_mode_attr = attr
+
+            if func_name in _allowed_unsafe_run_python_op_names:
+                if safe_run_mode_attr:
+                    node.attribute.remove(safe_run_mode_attr)
+                node.attribute.append(helper.make_attribute("safe_run_mode", 0))
+
     return exported_model
 
diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
index 86760097dca8c..772b9bd9e31ae 100644
--- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
+++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
@@ -509,6 +509,8 @@ def __init__(self, m: torch.nn.Module, logger: Logger):
 
         self._is_first_inspect = True
 
+        self._m = m
+
     def is_enabled(self) -> bool:
         """Check if memory inspector is enabled."""
         return self._is_enabled
@@ -621,10 +623,12 @@ def inspect_memory(self, cur_phase: Phase):
             need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0)
 
             if need_print:
-                self._logger.info(
-                    log_memory_usage(
-                        _convert_phase_to_string(cur_phase), rank_0_only=True, step_info=f"step {self._current_step}"
step_info=f"step {self._current_step}" - ) + log_memory_usage( + _convert_phase_to_string(cur_phase), + rank_0_only=True, + step_info=f"step {self._current_step}", + logger=self._logger, + module=self._m, ) if cur_phase == self._last_phase: diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 5b2c673ce94cb..b76b473de1641 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -237,7 +237,7 @@ def forward(self, *inputs, **kwargs): # Only change this after the firs time a warning is issued. self._first_skip_check_warning = False self._logger.info( - "Fast path enabled - skipping checks.Rebuild graph: %s, Execution agent: %s, Device check: %s", + "Fast path enabled - skipping checks. Rebuild graph: %s, Execution agent: %s, Device check: %s", self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT), self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT), self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE), diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index ff110c431d300..4f3f693f70155 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -179,8 +179,6 @@ def _get_func_name(node: NodeProto) -> Optional[str]: exported_model.graph.node.insert(0, weight_pull_node) # Update safe_run_mode attribute for PythonOp. - from onnxruntime.training.utils.hooks._subscriber_manager import _IncrementStep - _allowed_unsafe_run_python_op_names = [ get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction), get_fully_qualified_class_name(ORTZeROOffloadPostForwardFunction), @@ -188,7 +186,6 @@ def _get_func_name(node: NodeProto) -> Optional[str]: DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME, DEEPSPEED_POST_BACKWARD_FUNCTION_NAME, DEEPSPEED_LINEAR_FUNCTION_NAME, - get_fully_qualified_class_name(_IncrementStep), ] for node in exported_model.graph.node: diff --git a/orttraining/orttraining/python/training/utils/hooks/__init__.py b/orttraining/orttraining/python/training/utils/hooks/__init__.py index 89c0d44abbb7a..7e9217578c224 100644 --- a/orttraining/orttraining/python/training/utils/hooks/__init__.py +++ b/orttraining/orttraining/python/training/utils/hooks/__init__.py @@ -8,12 +8,14 @@ __all__ = [ "StatisticsSubscriber", + "MemoryStatisticsSubscriber", "GlobalSubscriberManager", "inspect_activation", "ZeROOffloadSubscriber", "configure_ort_compatible_zero_stage3", ] +from ._mem_statistics_subscriber import MemoryStatisticsSubscriber from ._statistics_subscriber import StatisticsSubscriber, _InspectActivation from ._subscriber_manager import SubscriberManager from ._zero_offload_subscriber import ZeROOffloadSubscriber, configure_ort_compatible_zero_stage3 diff --git a/orttraining/orttraining/python/training/utils/hooks/_mem_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_mem_statistics_subscriber.py new file mode 100644 index 0000000000000..9ce06ac503316 --- /dev/null +++ b/orttraining/orttraining/python/training/utils/hooks/_mem_statistics_subscriber.py @@ -0,0 +1,134 @@ +# ------------------------------------------------------------------------- +# 
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+
+from typing import List, Optional, Tuple, Union
+
+import onnx
+import torch
+
+from onnxruntime.training.utils import (
+    ORTModelInputOutputType,
+    extract_data_and_schema,
+    log_memory_usage,
+    unflatten_data_using_schema,
+)
+
+from ._subscriber_base import RuntimeStates, SubscriberBase
+
+_PRE_FW_PASS_PHASE = "pre-fw-pass"
+_POST_FW_PASS_PHASE = "post-fw-pass"
+_PRE_BW_PASS_PHASE = "pre-bw-pass"
+_POST_BW_PASS_PHASE = "post-bw-pass"
+
+
+class _InspectMemoryUsage(torch.autograd.Function):
+    """This class is used to print the memory statistics in the forward and backward passes."""
+
+    @staticmethod
+    def forward(
+        ctx, phase: str, run_ctx: RuntimeStates, module: torch.nn.Module, *input_tensor_list: Tuple[torch.Tensor, ...]
+    ) -> Tuple[torch.Tensor, ...]:
+        """Make sure there are the same number of `tensor` inputs and outputs."""
+        ctx.current_step = run_ctx.global_states.execution_step
+        ctx.phase = phase
+        ctx.module = module
+
+        assert ctx.phase in [_PRE_FW_PASS_PHASE, _POST_FW_PASS_PHASE], f"Invalid phase {ctx.phase}"
+
+        # The step is not always consistent with the step in users' training loops.
+        # It is a counter of how many times the forward+backward pass is called.
+        log_memory_usage(f"{ctx.phase}", rank_0_only=True, step_info=f"step {ctx.current_step}", module=ctx.module)
+
+        return tuple(t.detach().requires_grad_(t.requires_grad) for t in input_tensor_list)
+
+    @staticmethod
+    def backward(ctx, *grad_output: Tuple[Optional[torch.Tensor], ...]) -> Tuple[Optional[torch.Tensor], ...]:
+        phase = ctx.phase
+        if ctx.phase == _PRE_FW_PASS_PHASE:
+            phase = _POST_BW_PASS_PHASE
+        elif ctx.phase == _POST_FW_PASS_PHASE:
+            phase = _PRE_BW_PASS_PHASE
+        log_memory_usage(f"{phase}", rank_0_only=True, step_info=f"step {ctx.current_step}", module=ctx.module)
+        return (None, None, None, *grad_output)
+
+    @staticmethod
+    def infer_shape(
+        node: onnx.NodeProto,
+        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
+        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
+    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
+        return tensor_input_shapes, tensor_input_dtypes
+
+    @staticmethod
+    def alias_input(node_proto_str: str):
+        node = onnx.NodeProto()
+        node.ParseFromString(node_proto_str)
+        non_tensor_fw_input_count = 3
+        fw_output_count = len(node.output) - 1  # exclude the first output appended in ONNX
+        fw_alias_map = [-1] * fw_output_count
+        bw_alias_map = [-1] * (non_tensor_fw_input_count + len(node.input))
+
+        for i in range(fw_output_count):
+            fw_alias_map[i] = i + non_tensor_fw_input_count
+
+        tensor_input_index = 0
+        for i in range(len(bw_alias_map)):
+            if i < non_tensor_fw_input_count:
+                continue
+            bw_alias_map[i] = tensor_input_index
+            tensor_input_index += 1
+        return fw_alias_map, bw_alias_map
+
+
+class MemoryStatisticsSubscriber(SubscriberBase):
+    """
+    This subscriber is used to print the memory statistics in the forward and backward passes.
+    """
+
+    def __init__(
+        self,
+        start_step: Union[None, int] = None,
+        end_step: Union[None, int] = None,
+    ):
+        """
+        Steps in [start_step, end_step) will run subscriber actions.
+
+        Args:
+            start_step: the first step that runs subscriber actions.
+            end_step: the end step (exclusive) that runs subscriber actions.
+ """ + super().__init__(start_step=start_step, end_step=end_step) + + def pre_forward_outmost_module_apply_impl( + self, + run_ctx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + kwargs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + + flatten_args_tensor_list, args_schema = extract_data_and_schema(args) + flatten_kwargs_tensor_list, kwargs_schema = extract_data_and_schema(kwargs) + flatten_out = _InspectMemoryUsage.apply(_PRE_FW_PASS_PHASE, run_ctx, module, + *(flatten_args_tensor_list + flatten_kwargs_tensor_list)) + args_tensors = flatten_out[:len(flatten_args_tensor_list)] + kwargs_tensors = flatten_out[len(flatten_args_tensor_list):] + restored_args = unflatten_data_using_schema(args_tensors, args_schema) + restored_kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema) + + return restored_args, restored_kwargs + + + def post_forward_outmost_module_apply_impl( + self, + run_ctx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + + flatten_output_tensor_list, output_schema = extract_data_and_schema(outputs) + output_tensors = _InspectMemoryUsage.apply(_POST_FW_PASS_PHASE, run_ctx, module, *flatten_output_tensor_list) + restored_outputs = unflatten_data_using_schema(output_tensors, output_schema) + + return args, restored_outputs diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py index c5be17236ac06..35f8ada6507fe 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py @@ -156,12 +156,12 @@ def __init__( ) def post_forward_tensor_apply_impl( - self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + self, run_ctx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor ) -> torch.Tensor: - module_index = run_rtx.global_states.module_to_module_index[module] + module_index = run_ctx.global_states.module_to_module_index[module] name = f"{module.__class__.__name__}_{module_index}_{tensor_index}th_output" return _InspectActivation.apply( - name, module_index, run_rtx, tensor, self.module_post_forward_impl, self.module_pre_backward_impl + name, module_index, run_ctx, tensor, self.module_post_forward_impl, self.module_pre_backward_impl ) def module_post_forward_impl(self, activation: torch.Tensor, depth: int, name: str, step: int): diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py index 1b9a6fc91ec3c..59286d2c8f9d7 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py @@ -62,7 +62,7 @@ def __init__(self, start_step: Optional[int], end_step: Optional[int]): def pre_forward_module_apply( self, - run_rtx: RuntimeStates, + run_ctx: RuntimeStates, module: torch.nn.Module, args: ORTModelInputOutputType, kwargs: ORTModelInputOutputType, @@ -70,7 +70,7 @@ def pre_forward_module_apply( """This function is called inside the nn.Module's pre-forward hook. Args: - run_rtx (RuntimeStates): The runtime states of SubscriberManager. 
+            run_ctx (RuntimeStates): The runtime states of SubscriberManager.
             module (torch.nn.Module): The module that is being executed.
             args (ORTModelInputOutputType): The positional arguments that are passed to the module's pre-forward hook.
             kwargs (ORTModelInputOutputType): The keyword arguments that are passed to the module's pre-forward hook.
@@ -79,15 +79,15 @@
             Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated args and kwargs.
         """
-        if self._need_skip_step(run_rtx.global_states.execution_step):
+        if self._need_skip_step(run_ctx.global_states.execution_step):
             return args, kwargs
 
-        updated_args, updated_kwargs = self.pre_forward_module_apply_impl(run_rtx, module, args, kwargs)
+        updated_args, updated_kwargs = self.pre_forward_module_apply_impl(run_ctx, module, args, kwargs)
         return updated_args, updated_kwargs
 
     def pre_forward_module_apply_impl(
         self,
-        run_rtx: RuntimeStates,
+        run_ctx: RuntimeStates,
         module: torch.nn.Module,
         args: ORTModelInputOutputType,
         kwargs: ORTModelInputOutputType,
@@ -95,29 +95,29 @@
     ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
         return args, kwargs
 
     def pre_forward_tensor_apply(
-        self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor
+        self, run_ctx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor
     ) -> torch.Tensor:
         """This function is called inside the nn.Module's pre-forward hook.
 
         Args:
-            run_rtx (RuntimeStates): The runtime states of SubscriberManager.
+            run_ctx (RuntimeStates): The runtime states of SubscriberManager.
             module (torch.nn.Module): The module that is being executed.
             tensor_index (int): The index of the tensor in the input tensor list.
             tensor (torch.Tensor): The tensor is one of module's forward inputs.
         """
-        if self._need_skip_step(run_rtx.global_states.execution_step):
+        if self._need_skip_step(run_ctx.global_states.execution_step):
             return tensor
 
-        return self.pre_forward_tensor_apply_impl(run_rtx, module, tensor_index, tensor)
+        return self.pre_forward_tensor_apply_impl(run_ctx, module, tensor_index, tensor)
 
     def pre_forward_tensor_apply_impl(
-        self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor
+        self, run_ctx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor
     ) -> torch.Tensor:
         return tensor
 
     def post_forward_module_apply(
         self,
-        run_rtx: RuntimeStates,
+        run_ctx: RuntimeStates,
         module: torch.nn.Module,
         args: ORTModelInputOutputType,
         outputs: ORTModelInputOutputType,
@@ -125,7 +125,7 @@
         """This function is called inside the nn.Module's post-forward hook.
 
         Args:
-            run_rtx (RuntimeStates): The runtime states of SubscriberManager.
+            run_ctx (RuntimeStates): The runtime states of SubscriberManager.
             module (torch.nn.Module): The module that is being executed.
             args (ORTModelInputOutputType): The inputs arguments that are passed to the module's
               post-forward hook as input.
@@ -135,14 +135,14 @@
         Returns:
             Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated inputs and outputs.
""" - if self._need_skip_step(run_rtx.global_states.execution_step): + if self._need_skip_step(run_ctx.global_states.execution_step): return args, outputs - return self.post_forward_module_apply_impl(run_rtx, module, args, outputs) + return self.post_forward_module_apply_impl(run_ctx, module, args, outputs) def post_forward_module_apply_impl( self, - run_rtx: RuntimeStates, + run_ctx: RuntimeStates, module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, @@ -150,12 +150,12 @@ def post_forward_module_apply_impl( return args, outputs def post_forward_tensor_apply( - self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + self, run_ctx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor ) -> torch.Tensor: """This function is called inside the nn.Module's post-forward hook. Args: - run_rtx (RuntimeStates): The runtime states of SubscriberManager. + run_ctx (RuntimeStates): The runtime states of SubscriberManager. module (torch.nn.Module): The module that is being executed. tensor_index (int): The index of the tensor in the output tensor list. tensor (torch.Tensor): The tensor is one of module's forward outputs. @@ -163,19 +163,53 @@ def post_forward_tensor_apply( Returns: torch.Tensor: Updated tensor. """ - if self._need_skip_step(run_rtx.global_states.execution_step): + if self._need_skip_step(run_ctx.global_states.execution_step): return tensor - return self.post_forward_tensor_apply_impl(run_rtx, module, tensor_index, tensor) + return self.post_forward_tensor_apply_impl(run_ctx, module, tensor_index, tensor) def post_forward_tensor_apply_impl( - self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + self, run_ctx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor ) -> torch.Tensor: return tensor + def pre_forward_outmost_module_apply( + self, + run_ctx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + kwargs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + """This function is called inside the nn.Module's pre-forward hook. + + Args: + run_ctx (RuntimeStates): The runtime states of SubscriberManager. + module (torch.nn.Module): The module that is being executed. + args (ORTModelInputOutputType): The positional arguments that are passed to the module's pre-forward hook. + kwargs (ORTModelInputOutputType): The keyword arguments that are passed to the module's pre-forward hook. + + Returns: + Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated args and kwargs. + + """ + if self._need_skip_step(run_ctx.global_states.execution_step): + return args, kwargs + + updated_args, updated_kwargs = self.pre_forward_outmost_module_apply_impl(run_ctx, module, args, kwargs) + return updated_args, updated_kwargs + + def pre_forward_outmost_module_apply_impl( + self, + run_ctx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + kwargs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + return args, kwargs + def post_forward_outmost_module_apply( self, - run_rtx: RuntimeStates, + run_ctx: RuntimeStates, module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, @@ -183,7 +217,7 @@ def post_forward_outmost_module_apply( """This function is called inside the outmost nn.Module's post-forward hook. 
 
         Args:
-            run_rtx (RuntimeStates): The runtime states of SubscriberManager.
+            run_ctx (RuntimeStates): The runtime states of SubscriberManager.
             module (torch.nn.Module): The module that is being executed.
             args (ORTModelInputOutputType): The inputs arguments that are passed to the module's
               post-forward hook as input.
@@ -193,14 +227,14 @@
         Returns:
             Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated inputs and outputs.
         """
-        if self._need_skip_step(run_rtx.global_states.execution_step):
+        if self._need_skip_step(run_ctx.global_states.execution_step):
             return args, outputs
 
-        return self.post_forward_outmost_module_apply_impl(run_rtx, module, args, outputs)
+        return self.post_forward_outmost_module_apply_impl(run_ctx, module, args, outputs)
 
     def post_forward_outmost_module_apply_impl(
         self,
-        run_rtx: RuntimeStates,
+        run_ctx: RuntimeStates,
         module: torch.nn.Module,
         args: ORTModelInputOutputType,
         outputs: ORTModelInputOutputType,
diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py
index c9c06dabab4de..1656d4eac3d7c 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py
@@ -155,6 +155,9 @@ def _initialize(self, module: torch.nn.Module):
             raise RuntimeError("No subscribers are registered.")
 
         def _pre_forward_outmost_module_hook(module, module_inputs):
+            # Plain pre-hooks (no kwargs support) must return only the updated args.
+            return _pre_forward_outmost_module_with_kwarg_hook(module, module_inputs, {})[0]
+
+        def _pre_forward_outmost_module_with_kwarg_hook(module, module_inputs, kwargs):
             # This check is to support the case where module is first registered in the subscriber manager,
             # then the module and hook are copied, when new module instance runs to the hook, the global states
             # are not reset, so the logic depends on the global states will fail. So in the outer-most pre-forward hook
@@ -168,9 +171,20 @@ def _pre_forward_outmost_module_hook(module, module_inputs):
                     "Initialize global states for the first time, this should only happen once for each outmost module."
                 )
                 self._initialize_one_time_global_states(module)
+
+            # Call subscribers' pre-forward custom actions for the outmost module.
+            for sub in self._subscribers:
+                module_inputs, kwargs = sub.pre_forward_outmost_module_apply(self._run_ctx, module, module_inputs, kwargs)
+            return module_inputs, kwargs
 
-        module.register_forward_pre_hook(_pre_forward_outmost_module_hook)
+        # "with_kwargs" is not available in older versions of PyTorch.
+ if "with_kwargs" in inspect.signature(module.register_forward_pre_hook).parameters: + self._pre_forward_hooks.append( + module.register_forward_pre_hook(_pre_forward_outmost_module_with_kwarg_hook, with_kwargs=True) + ) + else: + self._pre_forward_hooks.append(module.register_forward_pre_hook(_pre_forward_outmost_module_hook)) next_module_index = [0] self._register_hooks_recursively(module, 1, next_module_index) @@ -189,7 +203,7 @@ def _post_forward_outmost_module_hook(module, module_inputs, module_outputs): return restored_outputs - module.register_forward_hook(_post_forward_outmost_module_hook) + self._pre_forward_hooks.append(module.register_forward_hook(_post_forward_outmost_module_hook)) def _initialize_one_time_global_states(self, module: torch.nn.Module): def _reset_recursively(module: torch.nn.Module, depth: int, next_module_index: List[int]): @@ -244,6 +258,18 @@ def _pre_forward_module_with_kwargs_hook(module, module_inputs, kwargs): for sub in self._subscribers: module_inputs, kwargs = sub.pre_forward_module_apply(self._run_ctx, module, module_inputs, kwargs) + if len(self._subscribers) == 0: + return module_inputs, kwargs + + # If there is no tensor level post forward func override, we can skip the following tensor level hook. + if all( + [ + sub.__class__.pre_forward_tensor_apply_impl == SubscriberBase.pre_forward_tensor_apply_impl + for sub in self._subscribers + ] + ): + return module_inputs, kwargs + # Tensor level hook flatten_positional_input_tensor_list, input_schema = extract_data_and_schema(module_inputs) flatten_keyword_input_tensor_list, keyword_input_schema = extract_data_and_schema(kwargs) @@ -272,6 +298,18 @@ def _post_forward_module_hook(module, module_inputs, module_outputs): for sub in self._subscribers: _, module_outputs = sub.post_forward_module_apply(self._run_ctx, module, module_inputs, module_outputs) + if len(self._subscribers) == 0: + return module_outputs + + # If there is no tensor level post forward func override, we can skip the following tensor level hook. 
+            if all(
+                [
+                    sub.__class__.post_forward_tensor_apply_impl == SubscriberBase.post_forward_tensor_apply_impl
+                    for sub in self._subscribers
+                ]
+            ):
+                return module_outputs
+
             # Tensor level hook
             flatten_output_tensor_list, output_schema = extract_data_and_schema(module_outputs)
             for sub in self._subscribers:
diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
index e6004319ef5ea..17008e6f3ce01 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
@@ -170,8 +170,6 @@ def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", sta
         torch.nn.functional.linear = _zero3_linear_wrap_ort_compatible
 
 except ImportError as e:
-    warnings.warn(f"DeepSpeed import error {e}")
-
     def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir=None, stats_overwrite=False):
         raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.")
 
@@ -476,7 +474,7 @@ def __init__(self, offloader, one_time_init: _ZeROOffloadOneTimeInitializer, ena
     @nvtx_function_decorator
     def pre_forward_module_apply_impl(
         self,
-        run_rtx: RuntimeStates,
+        run_ctx: RuntimeStates,
         module: torch.nn.Module,
         args: ORTModelInputOutputType,
         kwargs: ORTModelInputOutputType,
@@ -552,7 +550,7 @@ def _wrap_pre_forward_module_hook(module):
     @nvtx_function_decorator
     def post_forward_module_apply_impl(
         self,
-        run_rtx: RuntimeStates,
+        run_ctx: RuntimeStates,
         module: torch.nn.Module,
         args: ORTModelInputOutputType,
         outputs: ORTModelInputOutputType,
@@ -611,7 +609,7 @@ def _wrap_post_forward_module_hook(module, input, outputs):
     @nvtx_function_decorator
     def post_forward_outmost_module_apply_impl(
         self,
-        run_rtx: RuntimeStates,
+        run_ctx: RuntimeStates,
         module: torch.nn.Module,
         args: ORTModelInputOutputType,
         outputs: ORTModelInputOutputType,
diff --git a/orttraining/orttraining/python/training/utils/torch_profile_utils.py b/orttraining/orttraining/python/training/utils/torch_profile_utils.py
index e3ebb8ed22ea5..b7a0976967e30 100644
--- a/orttraining/orttraining/python/training/utils/torch_profile_utils.py
+++ b/orttraining/orttraining/python/training/utils/torch_profile_utils.py
@@ -32,13 +32,21 @@ def wrapped_fn(*args, **kwargs):
         return wrapped_fn
 
-def log_memory_usage(cur_phase: str, rank_0_only=True, step_info="") -> str:
+def log_memory_usage(cur_phase: str, rank_0_only=True, step_info="", logger=None, module=None):
+    """Log memory usage for the current phase.
+
+    Args:
+        cur_phase (str): The current phase.
+        rank_0_only (bool, optional): Only log the memory usage for rank 0. Defaults to True.
+        step_info (str, optional): The step information. Defaults to "".
+        logger (logging.Logger, optional): The logger to log the memory usage. Defaults to None, which means print to stdout.
+        module (torch.nn.Module, optional): The module to get parameter, buffer and grad sizes. Defaults to None.
+ """ rank = 0 if rank_0_only is True: if torch.distributed.is_initialized(): rank = torch.distributed.get_rank() if rank != 0: - return "" + return _normalizer_factor = float(1024 * 1024) _normalizer_unit = "MiB" @@ -62,21 +70,48 @@ def output_to_list(x): memory_use_info = output_to_list(sp.check_output(nvm_cmd.split(), stderr=sp.STDOUT))[1:] except sp.CalledProcessError as e: raise RuntimeError(f"command '{e.cmd}' return with error (code {e.returncode}): {e.output}") from None - memory_use_values = [str(x.split()[0]) for i, x in enumerate(memory_use_info)] + memory_use_value = [int(x.split()[0]) for i, x in enumerate(memory_use_info)][rank] mem_stats = [ ["phase", cur_phase], + ["nvm smi", memory_use_value], ["allocated", cur_mem_allocated], # current memory allocated for tensors ["max allocated", max_mem_allocated], # peak memory allocated for tensors ["cached", cur_mem_cached], # current memory cached for the caching allocator ["max cached", max_mem_cached], # peak memory cached for caching allocator. ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory - ["nvm smi", ",".join(memory_use_values)], ] - summ = f"{rank} {step_info} memory ({_normalizer_unit})" + # Calculate the total size of parameters and gradients in the model + if module: + param_total_size = 0 + grad_total_size = 0 + for p in module.parameters(): + if p.is_cuda: + param_total_size += p.numel() * p.element_size() + if p.grad is not None and p.grad.is_cuda: + grad_total_size += p.grad.numel() * p.grad.element_size() + + # Calculate the total size of buffers in the model + buffer_total_size = 0 + for b in module.buffers(): + if b.is_cuda: + buffer_total_size += b.numel() * b.element_size() + + mem_stats.extend( + [ + ["param size", _normalize(param_total_size)], + ["grad size", _normalize(grad_total_size)], + ["buffer size", _normalize(buffer_total_size)], + ] + ) + + summ = f"rank-{rank} {step_info} memory ({_normalizer_unit})" for stat in mem_stats: summ += f" | {stat[0]}: {stat[1]}" - return summ + if logger is None: + print(summ) + else: + logger.info(summ)