diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 9b77832abb6f1..55e1215810cca 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -485,12 +485,15 @@ void ListAllCombinations(const InlinedVector> new_combination = current_combination; - new_combination.push_back(plan); - ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations); - } + const InlinedVector>>& + plan_combination_list_at_cur_index = all_possible_node_optimization_plans[index]; + // For the index-th reused buffer, iterate all possible complete plans. + for (size_t i = 0; i < plan_combination_list_at_cur_index.size(); ++i) { + const auto& plan_combination = plan_combination_list_at_cur_index[i]; + InlinedVector> new_combination = current_combination; + // Append the chosen complete plan and continue exploring the next reused buffer at index + 1. + new_combination.insert(new_combination.end(), plan_combination.begin(), plan_combination.end()); + ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations); } MO_LOG_DEBUG_INFO(logger, "Exit ListAllCombinations"); @@ -520,17 +523,29 @@ void IterateNodeOptimizationPlan(const std::shared_ptr } InlinedVector>>> - all_possible_node_optimization_plans; - all_possible_node_optimization_plans.resize(plan->reuse_buffers.size()); + all_possible_node_optimization_plans(plan->reuse_buffers.size()); size_t i = 0; for (const auto& p : plan->reuse_buffers) { MO_LOG_DEBUG_INFO(logger, ">>>reuse buffer: " + std::to_string(p.first)); - IterateNode(p.second.first, node_to_optimization_plans_map, {}, logger, all_possible_node_optimization_plans[i]); + + // If the reused node is part of the current node's optimization plan, then we just add the current combination to the result.
+ if (plan->GetOptimizationType() == OptimizationType::RecomputeWithCompromise || plan->GetOptimizationType() == OptimizationType::Recompute) { + const auto& recompute_subgraph = + dynamic_cast(plan.get())->GetNodesInTopoOrder(); + if (std::find(recompute_subgraph.begin(), recompute_subgraph.end(), p.second.first) != recompute_subgraph.end()) { + all_possible_node_optimization_plans[i].push_back(current_combination); + } + } + + if (all_possible_node_optimization_plans[i].size() == 0) { + IterateNode(p.second.first, node_to_optimization_plans_map, current_combination, logger, all_possible_node_optimization_plans[i]); + } + ++i; } - ListAllCombinations(all_possible_node_optimization_plans, 0, current_combination, logger, all_combinations); + ListAllCombinations(all_possible_node_optimization_plans, 0, {}, logger, all_combinations); MO_LOG_DEBUG_INFO(logger, "Exit IterateNodeOptimizationPlan: " + plan->GetClusterId()); } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc index 64e99a4a0bca5..4ce896c5350b0 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc @@ -15,35 +15,6 @@ namespace onnxruntime::optimizer::memory_optimizer { -std::string NodeOptimizationPlanBase::GetMemorySavingSymbolicString() const { - std::string saving_str; - for (auto output_index : activation_output_indices_) { - // If the output is reusing other node's buffer, then no memory saving. - if (reuse_buffers.find(output_index) != reuse_buffers.end()) { - continue; - } - - const auto& output_def = node->OutputDefs()[output_index]; - MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto()); - ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ", - DataTypeImpl::ToString(ml_data_type)); - const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); - ORT_ENFORCE(nullptr != tensor_type_base); - MLDataType elt_type = tensor_type_base->GetElementType(); - const auto byte_count_per_element = elt_type->Size(); - if (!saving_str.empty()) { - saving_str += " + "; - } - saving_str = "(" + GetActivationOutputDimParamString(output_index) + " * " + - std::to_string(byte_count_per_element) + " * " + - std::to_string(GetSaveRatio()) + ")"; - } - if (saving_str.empty()) { - return saving_str; - } - return "(" + saving_str + ")"; -} - Status MemoryOptimizationPlanner::UpdateNodePlansFromExecutionPlan(const GraphViewer& graph_viewer, const OrtValueNameIdxMap& ortvalue_name_to_idx_map, const SequentialExecutionPlan& p_seq_exec_plan) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h index c585b2810b39d..789f530b29f1d 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h @@ -83,7 +83,7 @@ class NodeOptimizationPlanBase { /** * Get a symbolic string to represent the memory saving for this optimization plan. 
*/ - std::string GetMemorySavingSymbolicString() const; + virtual std::string GetMemorySavingSymbolicString() const = 0; std::string GetActivationOutputDimParamString(size_t index) const { ORT_ENFORCE(activation_output_dim_params_.find(index) != activation_output_dim_params_.end(), diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 52dea571a1eaf..18fa785ea7c45 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -72,11 +72,13 @@ const InlinedHashMap& GetAllowedRecompu {"Add", AllowedRecomputeNodeConfig{{0, 1}}}, {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}}, {"Div", AllowedRecomputeNodeConfig{{0, 1}}}, + {"Equal", AllowedRecomputeNodeConfig{{0, 1}}}, {"Mul", AllowedRecomputeNodeConfig{{0, 1}}}, {"Sub", AllowedRecomputeNodeConfig{{0, 1}}}, // Data layout /// The shape input is trivial whether it exists or not in backward. + {"Shape", AllowedRecomputeNodeConfig{{0}}}, {"Reshape", AllowedRecomputeNodeConfig{{0}}}, {"Squeeze", AllowedRecomputeNodeConfig{{0}}}, {"Transpose", AllowedRecomputeNodeConfig{{0}}}, @@ -92,6 +94,7 @@ const InlinedHashMap& GetAllowedRecompu {"Expand", AllowedRecomputeNodeConfig{{0}}}, {"FastGelu", AllowedRecomputeNodeConfig{{0}}}, {"Gelu", AllowedRecomputeNodeConfig{{0}}}, + {"QuickGelu", AllowedRecomputeNodeConfig{{0}}}, // Ternary elementwise {"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}}, diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h index d9693835313b8..ab114d970191e 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h @@ -86,6 +86,51 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase { std::string GetNodesInTopoOrderStr() const; + std::string GetMemorySavingSymbolicString() const override { + std::string saving_str; + for (auto output_index : GetActivationOutputIndices()) { + // If the output reuses another node's buffer, then there is no memory saving. + std::string cur_output_saving_str; + + bool is_reused = reuse_buffers.find(output_index) != reuse_buffers.end(); + bool is_src_node_in_cur_node_subgraph = false; + if (is_reused) { + // Here we assume the src_node is the real owner of the buffer, so we don't need to trace further. + const auto* src_node = reuse_buffers.at(output_index).first; + is_src_node_in_cur_node_subgraph = std::find(nodes_in_topological_order_.begin(), + nodes_in_topological_order_.end(), + src_node) != nodes_in_topological_order_.end(); + } + + if (!is_reused || is_src_node_in_cur_node_subgraph) { + // When is_src_node_in_cur_node_subgraph is true, we still use the output to calculate the saving, because + // the reused buffer has the same size.
+ const auto& output_def = node->OutputDefs()[output_index]; + MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto()); + ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ", + DataTypeImpl::ToString(ml_data_type)); + const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); + ORT_ENFORCE(nullptr != tensor_type_base); + MLDataType elt_type = tensor_type_base->GetElementType(); + const auto byte_count_per_element = elt_type->Size(); + cur_output_saving_str = GetActivationOutputDimParamString(output_index) + " * " + + std::to_string(byte_count_per_element) + " * " + + std::to_string(GetSaveRatio()); + } else { + cur_output_saving_str = "0"; + } + + if (!saving_str.empty()) { + saving_str += " + "; + } + + saving_str = "(" + cur_output_saving_str + ")"; + } + + ORT_ENFORCE(!saving_str.empty(), "saving_str should not be empty for node: ", node->OpType(), " ", node->Name()); + return "(" + saving_str + ")"; + } + private: bool compromise_recompute_; InlinedVector nodes_in_topological_order_; diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index f10416a9bb0f4..58f3e64bc6757 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -10,7 +10,7 @@ import torch import torch.utils.checkpoint -from onnx import ModelProto +from onnx import ModelProto, helper from packaging import version from torch.onnx import symbolic_helper @@ -393,6 +393,31 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model node.name = f"{op_name_prefix}_id_{index}" index += 1 + from onnxruntime.training.utils.hooks._mem_statistics_subscriber import _InspectMemoryUsage + from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation + from onnxruntime.training.utils.hooks._subscriber_manager import _IncrementStep + + _allowed_unsafe_run_python_op_names = [ + get_fully_qualified_class_name(_InspectMemoryUsage), + get_fully_qualified_class_name(_IncrementStep), + get_fully_qualified_class_name(_InspectActivation), + ] + + for node in exported_model.graph.node: + if node.op_type == "PythonOp": + func_name = None + safe_run_mode_attr = None + for attr in node.attribute: + if attr.name == "func_name": + func_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s + if attr.name == "safe_run_mode": + safe_run_mode_attr = attr + + if func_name in _allowed_unsafe_run_python_op_names: + if safe_run_mode_attr: + node.attribute.remove(safe_run_mode_attr) + node.attribute.append(helper.make_attribute("safe_run_mode", 0)) + return exported_model diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index 078ce4d27cd6f..772b9bd9e31ae 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -14,7 +14,7 @@ from sympy import Symbol, simplify from sympy.parsing.sympy_parser import parse_expr -from onnxruntime.training.utils import PTable +from onnxruntime.training.utils import PTable, log_memory_usage from ._execution_agent import TrainingAgent from .options import _MemoryOptimizationLevel, _RuntimeOptions @@ -509,6 
+509,8 @@ def __init__(self, m: torch.nn.Module, logger: Logger): self._is_first_inspect = True + self._m = m + def is_enabled(self) -> bool: """Check if memory inspector is enabled.""" return self._is_enabled @@ -621,29 +623,13 @@ def inspect_memory(self, cur_phase: Phase): need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0) if need_print: - cur_mem_allocated = self._normalize(torch.cuda.memory_allocated()) - max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated()) - cur_mem_cached = self._normalize(torch.cuda.memory_reserved()) - max_mem_cached = self._normalize(torch.cuda.max_memory_reserved()) - torch_mem_stat = torch.cuda.memory_stats() - cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) - max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) - - mem_stats = [ - ["phase", _convert_phase_to_string(cur_phase)], - ["allocated", cur_mem_allocated], # current memory allocated for tensors - ["max allocated", max_mem_allocated], # peak memory allocated for tensors - ["cached", cur_mem_cached], # current memory cached for the caching allocator - ["max cached", max_mem_cached], # peak memory cached for caching allocator. - ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory - ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory - ] - - summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})" - for stat in mem_stats: - summ += f" | {stat[0]}: {stat[1]}" - - self._logger.info(summ) + log_memory_usage( + _convert_phase_to_string(cur_phase), + rank_0_only=True, + step_info=f"step {self._current_step}", + logger=self._logger, + module=self._m, + ) if cur_phase == self._last_phase: self._increase_step() @@ -655,9 +641,6 @@ def inspect_memory(self, cur_phase: Phase): def _increase_step(self): self._current_step += 1 - def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str: - return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}" - def display_memory_optimization_plans(self, memory_optimizer_config, details=False) -> Tuple[List[str], PTable]: mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map) diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 5b2c673ce94cb..b76b473de1641 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -237,7 +237,7 @@ def forward(self, *inputs, **kwargs): # Only change this after the firs time a warning is issued. self._first_skip_check_warning = False self._logger.info( - "Fast path enabled - skipping checks.Rebuild graph: %s, Execution agent: %s, Device check: %s", + "Fast path enabled - skipping checks. 
Rebuild graph: %s, Execution agent: %s, Device check: %s", self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT), self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT), self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE), diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index ff110c431d300..4f3f693f70155 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -179,8 +179,6 @@ def _get_func_name(node: NodeProto) -> Optional[str]: exported_model.graph.node.insert(0, weight_pull_node) # Update safe_run_mode attribute for PythonOp. - from onnxruntime.training.utils.hooks._subscriber_manager import _IncrementStep - _allowed_unsafe_run_python_op_names = [ get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction), get_fully_qualified_class_name(ORTZeROOffloadPostForwardFunction), @@ -188,7 +186,6 @@ def _get_func_name(node: NodeProto) -> Optional[str]: DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME, DEEPSPEED_POST_BACKWARD_FUNCTION_NAME, DEEPSPEED_LINEAR_FUNCTION_NAME, - get_fully_qualified_class_name(_IncrementStep), ] for node in exported_model.graph.node: diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc index 9e24022b8448d..599bdf813907b 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc @@ -181,7 +181,6 @@ py::object finalize_training_mode_forward( } if (kernel_info.is_first_run) { - std::cout << "666666666666666666666666. 
py_fn->materialize_grads:" << py_fn->materialize_grads << std::endl; get_materialize_grads_once(forward_output_tensors, py_fn->materialize_grads, kernel_info); if (kernel_info.safe_run_enabled) { diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index b4a518d573998..ecfb7d7907f3c 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -12,6 +12,7 @@ unflatten_data_using_schema, ) from onnxruntime.training.utils.torch_profile_utils import ( + log_memory_usage, nvtx_function_decorator, torch_nvtx_range_pop, torch_nvtx_range_push, @@ -31,6 +32,7 @@ "torch_nvtx_range_push", "torch_nvtx_range_pop", "nvtx_function_decorator", + "log_memory_usage", "pytorch_type_to_onnx_dtype", "onnx_dtype_to_pytorch_dtype", "pytorch_scalar_type_to_pytorch_dtype", diff --git a/orttraining/orttraining/python/training/utils/hooks/__init__.py b/orttraining/orttraining/python/training/utils/hooks/__init__.py index 89c0d44abbb7a..7e9217578c224 100644 --- a/orttraining/orttraining/python/training/utils/hooks/__init__.py +++ b/orttraining/orttraining/python/training/utils/hooks/__init__.py @@ -8,12 +8,14 @@ __all__ = [ "StatisticsSubscriber", + "MemoryStatisticsSubscriber", "GlobalSubscriberManager", "inspect_activation", "ZeROOffloadSubscriber", "configure_ort_compatible_zero_stage3", ] +from ._mem_statistics_subscriber import MemoryStatisticsSubscriber from ._statistics_subscriber import StatisticsSubscriber, _InspectActivation from ._subscriber_manager import SubscriberManager from ._zero_offload_subscriber import ZeROOffloadSubscriber, configure_ort_compatible_zero_stage3 diff --git a/orttraining/orttraining/python/training/utils/hooks/_mem_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_mem_statistics_subscriber.py new file mode 100644 index 0000000000000..e0c522662785d --- /dev/null +++ b/orttraining/orttraining/python/training/utils/hooks/_mem_statistics_subscriber.py @@ -0,0 +1,135 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + + +from typing import List, Optional, Tuple, Union + +import onnx +import torch + +from onnxruntime.training.utils import ( + ORTModelInputOutputType, + extract_data_and_schema, + log_memory_usage, + unflatten_data_using_schema, +) + +from ._subscriber_base import RuntimeStates, SubscriberBase + +_PRE_FW_PASS_PHASE = "pre-fw-pass" +_POST_FW_PASS_PHASE = "post-fw-pass" +_PRE_BW_PASS_PHASE = "pre-bw-pass" +_POST_BW_PASS_PHASE = "post-bw-pass" + + +class _InspectMemoryUsage(torch.autograd.Function): + """This class is used to print the memory statistics in the forward and backward passes.""" + + @staticmethod + def forward( + ctx, phase: str, run_ctx: RuntimeStates, module: torch.nn.Module, *input_tensor_list: Tuple[torch.Tensor, ...] + ) -> Tuple[torch.Tensor, ...]: + """Make sure there is the same number of `tensor` inputs and outputs.""" + ctx.current_step = run_ctx.global_states.execution_step + ctx.phase = phase + ctx.module = module + + assert ctx.phase in [_PRE_FW_PASS_PHASE, _POST_FW_PASS_PHASE], f"Invalid phase {ctx.phase}" + + # The step is not always consistent with the step in users' training loops. + # It is a counter of how many times the forward+backward pass is called. 
+ log_memory_usage(f"{ctx.phase}", rank_0_only=True, step_info=f"step {ctx.current_step}", module=ctx.module) + + return tuple(t.detach().requires_grad_(t.requires_grad) for t in input_tensor_list) + + @staticmethod + def backward(ctx, *grad_output: Tuple[Optional[torch.Tensor], ...]) -> Tuple[Optional[torch.Tensor], ...]: + phase = ctx.phase + if ctx.phase == _PRE_FW_PASS_PHASE: + phase = _POST_BW_PASS_PHASE + elif ctx.phase == _POST_FW_PASS_PHASE: + phase = _PRE_BW_PASS_PHASE + log_memory_usage(f"{phase}", rank_0_only=True, step_info=f"step {ctx.current_step}", module=ctx.module) + return (None, None, None, *tuple(g for g in grad_output)) + + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + return tensor_input_shapes, tensor_input_dtypes + + @staticmethod + def alias_input(node_proto_str: str): + node = onnx.NodeProto() + node.ParseFromString(node_proto_str) + non_tensor_fw_input_count = 3 + fw_output_count = len(node.output) - 1 # exclude the first output appended in ONNX + fw_alias_map = [-1] * fw_output_count + bw_alias_map = [-1] * (non_tensor_fw_input_count + len(node.input)) + + for i in range(fw_output_count): + fw_alias_map[i] = i + non_tensor_fw_input_count + + tensor_input_index = 0 + for i in range(len(bw_alias_map)): + if i < non_tensor_fw_input_count: + continue + bw_alias_map[i] = tensor_input_index + tensor_input_index += 1 + return fw_alias_map, bw_alias_map + + +class MemoryStatisticsSubscriber(SubscriberBase): + """ + This subscriber is used to print the memory statistics in the forward and backward passes. + """ + + def __init__( + self, + start_step: Union[None, int] = None, + end_step: Union[None, int] = None, + ): + """ + Steps in [start_step, end_step) will run subscriber actions. + + Args: + start_step: the first step that runs subscriber actions. + end_step: the end step (exclusively) that runs subscriber actions. 
+ """ + super().__init__(start_step=start_step, end_step=end_step) + + def pre_forward_outmost_module_apply_impl( + self, + run_ctx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + kwargs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + flatten_args_tensor_list, args_schema = extract_data_and_schema(args) + flatten_kwargs_tensor_list, kwargs_schema = extract_data_and_schema(kwargs) + flatten_out = _InspectMemoryUsage.apply( + _PRE_FW_PASS_PHASE, run_ctx, module, *(flatten_args_tensor_list + flatten_kwargs_tensor_list) + ) + args_tensors = flatten_out[: len(flatten_args_tensor_list)] + kwargs_tensors = flatten_out[len(flatten_args_tensor_list) :] + restored_args = unflatten_data_using_schema(args_tensors, args_schema) + restored_kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema) + + return restored_args, restored_kwargs + + def post_forward_outmost_module_apply_impl( + self, + run_ctx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + flatten_output_tensor_list, output_schema = extract_data_and_schema(outputs) + output_tensors = _InspectMemoryUsage.apply(_POST_FW_PASS_PHASE, run_ctx, module, *flatten_output_tensor_list) + restored_outputs = unflatten_data_using_schema(output_tensors, output_schema) + + return args, restored_outputs diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py index c5be17236ac06..35f8ada6507fe 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py @@ -156,12 +156,12 @@ def __init__( ) def post_forward_tensor_apply_impl( - self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + self, run_ctx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor ) -> torch.Tensor: - module_index = run_rtx.global_states.module_to_module_index[module] + module_index = run_ctx.global_states.module_to_module_index[module] name = f"{module.__class__.__name__}_{module_index}_{tensor_index}th_output" return _InspectActivation.apply( - name, module_index, run_rtx, tensor, self.module_post_forward_impl, self.module_pre_backward_impl + name, module_index, run_ctx, tensor, self.module_post_forward_impl, self.module_pre_backward_impl ) def module_post_forward_impl(self, activation: torch.Tensor, depth: int, name: str, step: int): diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py index 1b9a6fc91ec3c..59286d2c8f9d7 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py @@ -62,7 +62,7 @@ def __init__(self, start_step: Optional[int], end_step: Optional[int]): def pre_forward_module_apply( self, - run_rtx: RuntimeStates, + run_ctx: RuntimeStates, module: torch.nn.Module, args: ORTModelInputOutputType, kwargs: ORTModelInputOutputType, @@ -70,7 +70,7 @@ def pre_forward_module_apply( """This function is called inside the nn.Module's pre-forward hook. Args: - run_rtx (RuntimeStates): The runtime states of SubscriberManager. 
+ run_ctx (RuntimeStates): The runtime states of SubscriberManager. module (torch.nn.Module): The module that is being executed. args (ORTModelInputOutputType): The positional arguments that are passed to the module's pre-forward hook. kwargs (ORTModelInputOutputType): The keyword arguments that are passed to the module's pre-forward hook. @@ -79,15 +79,15 @@ def pre_forward_module_apply( Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated args and kwargs. """ - if self._need_skip_step(run_rtx.global_states.execution_step): + if self._need_skip_step(run_ctx.global_states.execution_step): return args, kwargs - updated_args, updated_kwargs = self.pre_forward_module_apply_impl(run_rtx, module, args, kwargs) + updated_args, updated_kwargs = self.pre_forward_module_apply_impl(run_ctx, module, args, kwargs) return updated_args, updated_kwargs def pre_forward_module_apply_impl( self, - run_rtx: RuntimeStates, + run_ctx: RuntimeStates, module: torch.nn.Module, args: ORTModelInputOutputType, kwargs: ORTModelInputOutputType, @@ -95,29 +95,29 @@ def pre_forward_module_apply_impl( return args, kwargs def pre_forward_tensor_apply( - self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + self, run_ctx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor ) -> torch.Tensor: """This function is called inside the nn.Module's pre-forward hook. Args: - run_rtx (RuntimeStates): The runtime states of SubscriberManager. + run_ctx (RuntimeStates): The runtime states of SubscriberManager. module (torch.nn.Module): The module that is being executed. tensor_index (int): The index of the tensor in the input tensor list. tensor (torch.Tensor): The tensor is one of module's forward inputs. """ - if self._need_skip_step(run_rtx.global_states.execution_step): + if self._need_skip_step(run_ctx.global_states.execution_step): return tensor - return self.pre_forward_tensor_apply_impl(run_rtx, module, tensor_index, tensor) + return self.pre_forward_tensor_apply_impl(run_ctx, module, tensor_index, tensor) def pre_forward_tensor_apply_impl( - self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + self, run_ctx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor ) -> torch.Tensor: return tensor def post_forward_module_apply( self, - run_rtx: RuntimeStates, + run_ctx: RuntimeStates, module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, @@ -125,7 +125,7 @@ def post_forward_module_apply( """This function is called inside the nn.Module's post-forward hook. Args: - run_rtx (RuntimeStates): The runtime states of SubscriberManager. + run_ctx (RuntimeStates): The runtime states of SubscriberManager. module (torch.nn.Module): The module that is being executed. args (ORTModelInputOutputType): The inputs arguments that are passed to the module's post-forward hook as input. @@ -135,14 +135,14 @@ def post_forward_module_apply( Returns: Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated inputs and outputs. 
""" - if self._need_skip_step(run_rtx.global_states.execution_step): + if self._need_skip_step(run_ctx.global_states.execution_step): return args, outputs - return self.post_forward_module_apply_impl(run_rtx, module, args, outputs) + return self.post_forward_module_apply_impl(run_ctx, module, args, outputs) def post_forward_module_apply_impl( self, - run_rtx: RuntimeStates, + run_ctx: RuntimeStates, module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, @@ -150,12 +150,12 @@ def post_forward_module_apply_impl( return args, outputs def post_forward_tensor_apply( - self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + self, run_ctx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor ) -> torch.Tensor: """This function is called inside the nn.Module's post-forward hook. Args: - run_rtx (RuntimeStates): The runtime states of SubscriberManager. + run_ctx (RuntimeStates): The runtime states of SubscriberManager. module (torch.nn.Module): The module that is being executed. tensor_index (int): The index of the tensor in the output tensor list. tensor (torch.Tensor): The tensor is one of module's forward outputs. @@ -163,19 +163,53 @@ def post_forward_tensor_apply( Returns: torch.Tensor: Updated tensor. """ - if self._need_skip_step(run_rtx.global_states.execution_step): + if self._need_skip_step(run_ctx.global_states.execution_step): return tensor - return self.post_forward_tensor_apply_impl(run_rtx, module, tensor_index, tensor) + return self.post_forward_tensor_apply_impl(run_ctx, module, tensor_index, tensor) def post_forward_tensor_apply_impl( - self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + self, run_ctx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor ) -> torch.Tensor: return tensor + def pre_forward_outmost_module_apply( + self, + run_ctx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + kwargs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + """This function is called inside the nn.Module's pre-forward hook. + + Args: + run_ctx (RuntimeStates): The runtime states of SubscriberManager. + module (torch.nn.Module): The module that is being executed. + args (ORTModelInputOutputType): The positional arguments that are passed to the module's pre-forward hook. + kwargs (ORTModelInputOutputType): The keyword arguments that are passed to the module's pre-forward hook. + + Returns: + Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated args and kwargs. + + """ + if self._need_skip_step(run_ctx.global_states.execution_step): + return args, kwargs + + updated_args, updated_kwargs = self.pre_forward_outmost_module_apply_impl(run_ctx, module, args, kwargs) + return updated_args, updated_kwargs + + def pre_forward_outmost_module_apply_impl( + self, + run_ctx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + kwargs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + return args, kwargs + def post_forward_outmost_module_apply( self, - run_rtx: RuntimeStates, + run_ctx: RuntimeStates, module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, @@ -183,7 +217,7 @@ def post_forward_outmost_module_apply( """This function is called inside the outmost nn.Module's post-forward hook. 
Args: - run_rtx (RuntimeStates): The runtime states of SubscriberManager. + run_ctx (RuntimeStates): The runtime states of SubscriberManager. module (torch.nn.Module): The module that is being executed. args (ORTModelInputOutputType): The inputs arguments that are passed to the module's post-forward hook as input. @@ -193,14 +227,14 @@ def post_forward_outmost_module_apply( Returns: Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated inputs and outputs. """ - if self._need_skip_step(run_rtx.global_states.execution_step): + if self._need_skip_step(run_ctx.global_states.execution_step): return args, outputs - return self.post_forward_outmost_module_apply_impl(run_rtx, module, args, outputs) + return self.post_forward_outmost_module_apply_impl(run_ctx, module, args, outputs) def post_forward_outmost_module_apply_impl( self, - run_rtx: RuntimeStates, + run_ctx: RuntimeStates, module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py index c9c06dabab4de..1656d4eac3d7c 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py @@ -155,6 +155,9 @@ def _initialize(self, module: torch.nn.Module): raise RuntimeError("No subscribers are registered.") def _pre_forward_outmost_module_hook(module, module_inputs): + return _pre_forward_outmost_module_with_kwarg_hook(module, module_inputs, {}) + + def _pre_forward_outmost_module_with_kwarg_hook(module, module_inputs, kwargs): # This check is to support the case where module is first registered in the subscriber manager, # then the module and hook are copied, when new module instance runs to the hook, the global states # are not reset, so the logic depends on the global states will fail. So in the outer-most pre-forward hook @@ -168,9 +171,20 @@ def _pre_forward_outmost_module_hook(module, module_inputs): "Initialize global states for the first time, this should only happen once for each outmost module." ) self._initialize_one_time_global_states(module) + + # Call pre outmost module forward custom actions for subscribers + for sub in self._subscribers: + module_inputs = sub.pre_forward_outmost_module_apply(self._run_ctx, module, module_inputs, kwargs) + return module_inputs - module.register_forward_pre_hook(_pre_forward_outmost_module_hook) + # "with_kwargs" is not available for low versions of PyTorch. 
+ if "with_kwargs" in inspect.signature(module.register_forward_pre_hook).parameters: + self._pre_forward_hooks.append( + module.register_forward_pre_hook(_pre_forward_outmost_module_with_kwarg_hook, with_kwargs=True) + ) + else: + self._pre_forward_hooks.append(module.register_forward_pre_hook(_pre_forward_outmost_module_hook)) next_module_index = [0] self._register_hooks_recursively(module, 1, next_module_index) @@ -189,7 +203,7 @@ def _post_forward_outmost_module_hook(module, module_inputs, module_outputs): return restored_outputs - module.register_forward_hook(_post_forward_outmost_module_hook) + self._pre_forward_hooks.append(module.register_forward_hook(_post_forward_outmost_module_hook)) def _initialize_one_time_global_states(self, module: torch.nn.Module): def _reset_recursively(module: torch.nn.Module, depth: int, next_module_index: List[int]): @@ -244,6 +258,18 @@ def _pre_forward_module_with_kwargs_hook(module, module_inputs, kwargs): for sub in self._subscribers: module_inputs, kwargs = sub.pre_forward_module_apply(self._run_ctx, module, module_inputs, kwargs) + if len(self._subscribers) == 0: + return module_inputs, kwargs + + # If there is no tensor level post forward func override, we can skip the following tensor level hook. + if all( + [ + sub.__class__.pre_forward_tensor_apply_impl == SubscriberBase.pre_forward_tensor_apply_impl + for sub in self._subscribers + ] + ): + return module_inputs, kwargs + # Tensor level hook flatten_positional_input_tensor_list, input_schema = extract_data_and_schema(module_inputs) flatten_keyword_input_tensor_list, keyword_input_schema = extract_data_and_schema(kwargs) @@ -272,6 +298,18 @@ def _post_forward_module_hook(module, module_inputs, module_outputs): for sub in self._subscribers: _, module_outputs = sub.post_forward_module_apply(self._run_ctx, module, module_inputs, module_outputs) + if len(self._subscribers) == 0: + return module_outputs + + # If there is no tensor level post forward func override, we can skip the following tensor level hook. 
+ if all( + [ + sub.__class__.post_forward_tensor_apply_impl == SubscriberBase.post_forward_tensor_apply_impl + for sub in self._subscribers + ] + ): + return module_outputs + # Tensor level hook flatten_output_tensor_list, output_schema = extract_data_and_schema(module_outputs) for sub in self._subscribers: diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index e6004319ef5ea..6b318d333fade 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -5,7 +5,6 @@ import ctypes import inspect -import warnings from collections import OrderedDict from datetime import timedelta from types import CodeType, FunctionType @@ -169,8 +168,7 @@ def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", sta if torch.nn.functional.linear is zero3_linear_wrap: torch.nn.functional.linear = _zero3_linear_wrap_ort_compatible -except ImportError as e: - warnings.warn(f"DeepSpeed import error {e}") +except ImportError: def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir=None, stats_overwrite=False): raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.") @@ -476,7 +474,7 @@ def __init__(self, offloader, one_time_init: _ZeROOffloadOneTimeInitializer, ena @nvtx_function_decorator def pre_forward_module_apply_impl( self, - run_rtx: RuntimeStates, + run_ctx: RuntimeStates, module: torch.nn.Module, args: ORTModelInputOutputType, kwargs: ORTModelInputOutputType, @@ -552,7 +550,7 @@ def _wrap_pre_forward_module_hook(module): @nvtx_function_decorator def post_forward_module_apply_impl( self, - run_rtx: RuntimeStates, + run_ctx: RuntimeStates, module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, @@ -611,7 +609,7 @@ def _wrap_post_forward_module_hook(module, input, outputs): @nvtx_function_decorator def post_forward_outmost_module_apply_impl( self, - run_rtx: RuntimeStates, + run_ctx: RuntimeStates, module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, diff --git a/orttraining/orttraining/python/training/utils/torch_profile_utils.py b/orttraining/orttraining/python/training/utils/torch_profile_utils.py index 382d7dac142fe..b7a0976967e30 100644 --- a/orttraining/orttraining/python/training/utils/torch_profile_utils.py +++ b/orttraining/orttraining/python/training/utils/torch_profile_utils.py @@ -3,6 +3,10 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + +import subprocess as sp + import torch @@ -26,3 +30,88 @@ def wrapped_fn(*args, **kwargs): return ret_val return wrapped_fn + + +def log_memory_usage(cur_phase: str, rank_0_only=True, step_info="", logger=None, module=None): + """Log memory usage for the current phase. + Args: + cur_phase (str): The current phase. + rank_0_only (bool, optional): Only log the memory usage for rank 0. Defaults to True. + step_info (str, optional): The step information. Defaults to "". + logger (logging.Logger, optional): The logger to log the memory usage. Defaults to None, which means print to stdout. + module (torch.nn.Module, optional): The module to get parameter, buffer and grad sizes. Defaults to None. 
+ """ + rank = 0 + if rank_0_only is True: + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + if rank != 0: + return + + _normalizer_factor = float(1024 * 1024) + _normalizer_unit = "MiB" + + def _normalize(mem_size_in_bytes: float | int) -> str: + return f"{float(mem_size_in_bytes) / _normalizer_factor:.0f}" + + cur_mem_allocated = _normalize(torch.cuda.memory_allocated()) + max_mem_allocated = _normalize(torch.cuda.max_memory_allocated()) + cur_mem_cached = _normalize(torch.cuda.memory_reserved()) + max_mem_cached = _normalize(torch.cuda.max_memory_reserved()) + torch_mem_stat = torch.cuda.memory_stats() + cur_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) + max_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) + + def output_to_list(x): + return x.decode("ascii").split("\n")[:-1] + + nvm_cmd = "nvidia-smi --query-gpu=memory.used --format=csv" + try: + memory_use_info = output_to_list(sp.check_output(nvm_cmd.split(), stderr=sp.STDOUT))[1:] + except sp.CalledProcessError as e: + raise RuntimeError(f"command '{e.cmd}' return with error (code {e.returncode}): {e.output}") from None + memory_use_value = [int(x.split()[0]) for i, x in enumerate(memory_use_info)][rank] + + mem_stats = [ + ["phase", cur_phase], + ["nvm smi", memory_use_value], + ["allocated", cur_mem_allocated], # current memory allocated for tensors + ["max allocated", max_mem_allocated], # peak memory allocated for tensors + ["cached", cur_mem_cached], # current memory cached for the caching allocator + ["max cached", max_mem_cached], # peak memory cached for caching allocator. + ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory + ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory + ] + + # Calculate the total size of parameters and gradients in the model + if module: + param_total_size = 0 + grad_total_size = 0 + for p in module.parameters(): + if p.is_cuda: + param_total_size += p.numel() * p.element_size() + if p.grad is not None and p.grad.is_cuda: + grad_total_size += p.grad.numel() * p.grad.element_size() + + # Calculate the total size of buffers in the model + buffer_total_size = 0 + for b in module.buffers(): + if b.is_cuda: + buffer_total_size += b.numel() * b.element_size() + + mem_stats.extend( + [ + ["param size", _normalize(param_total_size)], + ["grad size", _normalize(grad_total_size)], + ["buffer size", _normalize(buffer_total_size)], + ] + ) + + summ = f"rank-{rank} {step_info} memory ({_normalizer_unit})" + for stat in mem_stats: + summ += f" | {stat[0]}: {stat[1]}" + + if logger is None: + print(summ) + else: + logger.info(summ)