Commit: refinement

pengwa committed Dec 21, 2023
1 parent cc31965 commit 856da38
Showing 11 changed files with 317 additions and 50 deletions.
======================================================================
@@ -10,7 +10,7 @@

 import torch
 import torch.utils.checkpoint
-from onnx import ModelProto
+from onnx import ModelProto, helper
 from packaging import version
 from torch.onnx import symbolic_helper

@@ -393,6 +393,31 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto:
             node.name = f"{op_name_prefix}_id_{index}"
             index += 1
 
+    from onnxruntime.training.utils.hooks._mem_statistics_subscriber import _InspectMemoryUsage
+    from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation
+    from onnxruntime.training.utils.hooks._subscriber_manager import _IncrementStep
+
+    _allowed_unsafe_run_python_op_names = [
+        get_fully_qualified_class_name(_InspectMemoryUsage),
+        get_fully_qualified_class_name(_IncrementStep),
+        get_fully_qualified_class_name(_InspectActivation),
+    ]
+
+    for node in exported_model.graph.node:
+        if node.op_type == "PythonOp":
+            func_name = None
+            safe_run_mode_attr = None
+            for attr in node.attribute:
+                if attr.name == "func_name":
+                    func_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
+                if attr.name == "safe_run_mode":
+                    safe_run_mode_attr = attr
+
+            if func_name in _allowed_unsafe_run_python_op_names:
+                if safe_run_mode_attr:
+                    node.attribute.remove(safe_run_mode_attr)
+                node.attribute.append(helper.make_attribute("safe_run_mode", 0))
+
     return exported_model
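For context: an ONNX node attribute cannot be reassigned in place by name, which is why the patch removes the old safe_run_mode attribute and appends a fresh one. A minimal sketch of the same swap on a toy node (the func_name value is hypothetical):

    from onnx import helper

    # Toy PythonOp carrying safe_run_mode=1; extra make_node kwargs become attributes.
    node = helper.make_node(
        "PythonOp",
        inputs=["x"],
        outputs=["ctx", "y"],
        domain="com.microsoft",
        func_name="my.module.MyFunction",  # hypothetical value
        safe_run_mode=1,
    )

    # Swap the attribute: remove the old one, append a replacement built by helper.
    old = next(a for a in node.attribute if a.name == "safe_run_mode")
    node.attribute.remove(old)
    node.attribute.append(helper.make_attribute("safe_run_mode", 0))

    assert next(a for a in node.attribute if a.name == "safe_run_mode").i == 0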


======================================================================
@@ -509,6 +509,8 @@ def __init__(self, m: torch.nn.Module, logger: Logger):

         self._is_first_inspect = True
 
+        self._m = m
+
     def is_enabled(self) -> bool:
         """Check if memory inspector is enabled."""
         return self._is_enabled
@@ -621,10 +623,12 @@ def inspect_memory(self, cur_phase: Phase):
         need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0)
 
         if need_print:
-            self._logger.info(
-                log_memory_usage(
-                    _convert_phase_to_string(cur_phase), rank_0_only=True, step_info=f"step {self._current_step}"
-                )
+            log_memory_usage(
+                _convert_phase_to_string(cur_phase),
+                rank_0_only=True,
+                step_info=f"step {self._current_step}",
+                logger=self._logger,
+                module=self._m,
             )
 
         if cur_phase == self._last_phase:
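The need_print gate above logs each of the first ten steps and then only power-of-two steps, so long runs do not flood the log. The bit trick: a positive integer n is a power of two exactly when n & (n - 1) == 0. A quick self-contained check:

    def need_print(step: int) -> bool:
        # Mirrors the gate above: first ten steps always, then powers of two.
        return step < 10 or (step & (step - 1) == 0)

    assert [s for s in range(20) if need_print(s)] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16]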
======================================================================
@@ -237,7 +237,7 @@ def forward(self, *inputs, **kwargs):
             # Only change this after the first time a warning is issued.
             self._first_skip_check_warning = False
             self._logger.info(
-                "Fast path enabled - skipping checks.Rebuild graph: %s, Execution agent: %s, Device check: %s",
+                "Fast path enabled - skipping checks. Rebuild graph: %s, Execution agent: %s, Device check: %s",
                 self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
                 self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
                 self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE),
======================================================================
@@ -179,16 +179,13 @@ def _get_func_name(node: NodeProto) -> Optional[str]:
     exported_model.graph.node.insert(0, weight_pull_node)
 
     # Update safe_run_mode attribute for PythonOp.
-    from onnxruntime.training.utils.hooks._subscriber_manager import _IncrementStep
-
     _allowed_unsafe_run_python_op_names = [
         get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction),
         get_fully_qualified_class_name(ORTZeROOffloadPostForwardFunction),
         func_full_qual_name,
         DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME,
         DEEPSPEED_POST_BACKWARD_FUNCTION_NAME,
         DEEPSPEED_LINEAR_FUNCTION_NAME,
-        get_fully_qualified_class_name(_IncrementStep),
     ]
 
     for node in exported_model.graph.node:
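The allow-list above keys off fully qualified class names. get_fully_qualified_class_name is imported from onnxruntime.training.utils; conceptually it boils down to something like this (a sketch, not the actual implementation):

    def get_fully_qualified_class_name(cls: type) -> str:
        # e.g. "onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep"
        return f"{cls.__module__}.{cls.__qualname__}"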
======================================================================
@@ -8,12 +8,14 @@

 __all__ = [
     "StatisticsSubscriber",
+    "MemoryStatisticsSubscriber",
     "GlobalSubscriberManager",
     "inspect_activation",
     "ZeROOffloadSubscriber",
     "configure_ort_compatible_zero_stage3",
 ]
 
+from ._mem_statistics_subscriber import MemoryStatisticsSubscriber
 from ._statistics_subscriber import StatisticsSubscriber, _InspectActivation
 from ._subscriber_manager import SubscriberManager
 from ._zero_offload_subscriber import ZeROOffloadSubscriber, configure_ort_compatible_zero_stage3
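With MemoryStatisticsSubscriber now exported from the hooks package, enabling it should follow the same pattern as the existing StatisticsSubscriber. A hedged sketch, assuming the GlobalSubscriberManager.subscribe API exported alongside it (model is any module trained under ORTModule; the step bounds are illustrative):

    from onnxruntime.training.utils.hooks import GlobalSubscriberManager, MemoryStatisticsSubscriber

    # Print memory statistics for steps [0, 8) only.
    GlobalSubscriberManager.subscribe(model, [MemoryStatisticsSubscriber(start_step=0, end_step=8)])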
======================================================================
@@ -0,0 +1,134 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------


from typing import List, Optional, Tuple, Union

import onnx
import torch

from onnxruntime.training.utils import (
    ORTModelInputOutputType,
    extract_data_and_schema,
    log_memory_usage,
    unflatten_data_using_schema,
)

from ._subscriber_base import RuntimeStates, SubscriberBase


_PRE_FW_PASS_PHASE = "pre-fw-pass"
_POST_FW_PASS_PHASE = "post-fw-pass"
_PRE_BW_PASS_PHASE = "pre-bw-pass"
_POST_BW_PASS_PHASE = "post-bw-pass"


class _InspectMemoryUsage(torch.autograd.Function):
    """This class is used to print the memory statistics in the forward and backward passes."""

    @staticmethod
    def forward(
        ctx,
        phase: str,
        run_ctx: RuntimeStates,
        module: torch.nn.Module,
        *input_tensor_list: Tuple[torch.Tensor, ...],
    ) -> Tuple[torch.Tensor, ...]:
        """Make sure there is the same number of `tensor` inputs and outputs."""
        ctx.current_step = run_ctx.global_states.execution_step
        ctx.phase = phase
        ctx.module = module

        assert ctx.phase in [_PRE_FW_PASS_PHASE, _POST_FW_PASS_PHASE], f"Invalid phase {ctx.phase}"

        # The step is not always consistent with the step in users' training loops.
        # It is a counter of how many times the forward+backward pass is called.
        log_memory_usage(ctx.phase, rank_0_only=True, step_info=f"step {ctx.current_step}", module=ctx.module)

        return tuple(t.detach().requires_grad_(t.requires_grad) for t in input_tensor_list)

    @staticmethod
    def backward(ctx, *grad_output: Tuple[Optional[torch.Tensor], ...]) -> Tuple[Optional[torch.Tensor], ...]:
        phase = ctx.phase
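        # Autograd replays nodes in reverse order, so the function recorded at the
        # pre-forward boundary fires at the end of backward, and the one recorded
        # at the post-forward boundary fires at the start of backward.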
        if ctx.phase == _PRE_FW_PASS_PHASE:
            phase = _POST_BW_PASS_PHASE
        elif ctx.phase == _POST_FW_PASS_PHASE:
            phase = _PRE_BW_PASS_PHASE
        log_memory_usage(phase, rank_0_only=True, step_info=f"step {ctx.current_step}", module=ctx.module)
        return (None, None, None, *grad_output)

    @staticmethod
    def infer_shape(
        node: onnx.NodeProto,
        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
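        # Memory inspection is a pass-through: each tensor input maps one-to-one
        # to a tensor output, so the inferred shapes and dtypes mirror the inputs.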
        return tensor_input_shapes, tensor_input_dtypes

    @staticmethod
    def alias_input(node_proto_str: str):
        node = onnx.NodeProto()
        node.ParseFromString(node_proto_str)
        non_tensor_fw_input_count = 3
        fw_output_count = len(node.output) - 1  # Exclude the first output appended in ONNX.
        fw_alias_map = [-1] * fw_output_count
        bw_alias_map = [-1] * (non_tensor_fw_input_count + len(node.input))

        for i in range(fw_output_count):
            fw_alias_map[i] = i + non_tensor_fw_input_count

        tensor_input_index = 0
        for i in range(len(bw_alias_map)):
            if i < non_tensor_fw_input_count:
                continue
            bw_alias_map[i] = tensor_input_index
            tensor_input_index += 1
        return fw_alias_map, bw_alias_map
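To make the alias maps concrete, assume a PythonOp with two tensor inputs; the three leading non-tensor forward inputs are phase, run_ctx and module from forward above. A worked example (illustrative, not part of the patch):

    # Hypothetical node: node.input = ["t0", "t1"], node.output = ["ctx", "y0", "y1"].
    # Forward inputs counted with non-tensors first: [phase, run_ctx, module, t0, t1].
    fw_alias_map = [3, 4]              # y0 aliases t0 (input 3), y1 aliases t1 (input 4)
    bw_alias_map = [-1, -1, -1, 0, 1]  # gradients flow straight back to t0 and t1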



class MemoryStatisticsSubscriber(SubscriberBase):
    """
    This subscriber is used to print the memory statistics in the forward and backward passes.
    """

    def __init__(
        self,
        start_step: Union[None, int] = None,
        end_step: Union[None, int] = None,
    ):
        """
        Steps in [start_step, end_step) will run subscriber actions.

        Args:
            start_step: the first step that runs subscriber actions.
            end_step: the end step (exclusive) that runs subscriber actions.
        """
        super().__init__(start_step=start_step, end_step=end_step)

    def pre_forward_outmost_module_apply_impl(
        self,
        run_ctx: RuntimeStates,
        module: torch.nn.Module,
        args: ORTModelInputOutputType,
        kwargs: ORTModelInputOutputType,
    ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
        flatten_args_tensor_list, args_schema = extract_data_and_schema(args)
        flatten_kwargs_tensor_list, kwargs_schema = extract_data_and_schema(kwargs)
        flatten_out = _InspectMemoryUsage.apply(
            _PRE_FW_PASS_PHASE, run_ctx, module, *(flatten_args_tensor_list + flatten_kwargs_tensor_list)
        )
        args_tensors = flatten_out[: len(flatten_args_tensor_list)]
        kwargs_tensors = flatten_out[len(flatten_args_tensor_list) :]
        restored_args = unflatten_data_using_schema(args_tensors, args_schema)
        restored_kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema)

        return restored_args, restored_kwargs


    def post_forward_outmost_module_apply_impl(
        self,
        run_ctx: RuntimeStates,
        module: torch.nn.Module,
        args: ORTModelInputOutputType,
        outputs: ORTModelInputOutputType,
    ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
        flatten_output_tensor_list, output_schema = extract_data_and_schema(outputs)
        output_tensors = _InspectMemoryUsage.apply(_POST_FW_PASS_PHASE, run_ctx, module, *flatten_output_tensor_list)
        restored_outputs = unflatten_data_using_schema(output_tensors, output_schema)

        return args, restored_outputs
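The extract/unflatten pair used by both hooks above is what lets a single autograd.Function see every tensor inside arbitrarily nested args, kwargs, and outputs. A minimal round trip with the same helpers (toy data):

    import torch

    from onnxruntime.training.utils import extract_data_and_schema, unflatten_data_using_schema

    data = {"x": torch.ones(2), "pair": (torch.zeros(3), torch.ones(1))}

    # Flatten to a plain list of tensors plus a schema describing the nesting.
    tensors, schema = extract_data_and_schema(data)
    assert len(tensors) == 3

    # ... the tensor list can now pass through an autograd.Function ...

    # Rebuild the original nesting from the (possibly replaced) tensors.
    restored = unflatten_data_using_schema(tensors, schema)
    assert torch.equal(restored["x"], data["x"])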
======================================================================
@@ -156,12 +156,12 @@ def __init__(
         )
 
     def post_forward_tensor_apply_impl(
-        self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor
+        self, run_ctx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor
     ) -> torch.Tensor:
-        module_index = run_rtx.global_states.module_to_module_index[module]
+        module_index = run_ctx.global_states.module_to_module_index[module]
         name = f"{module.__class__.__name__}_{module_index}_{tensor_index}th_output"
         return _InspectActivation.apply(
-            name, module_index, run_rtx, tensor, self.module_post_forward_impl, self.module_pre_backward_impl
+            name, module_index, run_ctx, tensor, self.module_post_forward_impl, self.module_pre_backward_impl
         )
 
     def module_post_forward_impl(self, activation: torch.Tensor, depth: int, name: str, step: int):
(The diffs for the remaining changed files did not load.)