Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Dec 21, 2023
1 parent 856da38 commit e81dc8f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,29 @@
import onnx
import torch

from onnxruntime.training.utils import log_memory_usage, extract_data_and_schema, unflatten_data_using_schema, ORTModelInputOutputType

from onnxruntime.training.utils import (
ORTModelInputOutputType,
extract_data_and_schema,
log_memory_usage,
unflatten_data_using_schema,
)

from ._subscriber_base import RuntimeStates, SubscriberBase


_PRE_FW_PASS_PHASE = "pre-fw-pass"
_POST_FW_PASS_PHASE = "post-fw-pass"
_PRE_BW_PASS_PHASE = "pre-bw-pass"
_POST_BW_PASS_PHASE = "post-bw-pass"


class _InspectMemoryUsage(torch.autograd.Function):
"""This class is used to print the memory statistics in the forward and backward passes."""

@staticmethod
def forward(ctx, phase: str, run_ctx: RuntimeStates, module: torch.nn.Module,
*input_tensor_list: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
"""Make sure there is the same number of `tensor` inputs and outputs.
"""
def forward(
ctx, phase: str, run_ctx: RuntimeStates, module: torch.nn.Module, *input_tensor_list: Tuple[torch.Tensor, ...]
) -> Tuple[torch.Tensor, ...]:
"""Make sure there is the same number of `tensor` inputs and outputs."""
ctx.current_step = run_ctx.global_states.execution_step
ctx.phase = phase
ctx.module = module
Expand Down Expand Up @@ -79,7 +83,6 @@ def alias_input(node_proto_str: str):
return fw_alias_map, bw_alias_map



class MemoryStatisticsSubscriber(SubscriberBase):
"""
This subscriber is used to print the memory statistics in the forward and backward passes.
Expand All @@ -106,27 +109,25 @@ def pre_forward_outmost_module_apply_impl(
args: ORTModelInputOutputType,
kwargs: ORTModelInputOutputType,
) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:

flatten_args_tensor_list, args_schema = extract_data_and_schema(args)
flatten_kwargs_tensor_list, kwargs_schema = extract_data_and_schema(kwargs)
flatten_out = _InspectMemoryUsage.apply(_PRE_FW_PASS_PHASE, run_ctx, module,
*(flatten_args_tensor_list + flatten_kwargs_tensor_list))
args_tensors = flatten_out[:len(flatten_args_tensor_list)]
kwargs_tensors = flatten_out[len(flatten_args_tensor_list):]
flatten_out = _InspectMemoryUsage.apply(
_PRE_FW_PASS_PHASE, run_ctx, module, *(flatten_args_tensor_list + flatten_kwargs_tensor_list)
)
args_tensors = flatten_out[: len(flatten_args_tensor_list)]
kwargs_tensors = flatten_out[len(flatten_args_tensor_list) :]
restored_args = unflatten_data_using_schema(args_tensors, args_schema)
restored_kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema)

return restored_args, restored_kwargs


def post_forward_outmost_module_apply_impl(
self,
run_ctx: RuntimeStates,
module: torch.nn.Module,
args: ORTModelInputOutputType,
outputs: ORTModelInputOutputType,
) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:

flatten_output_tensor_list, output_schema = extract_data_and_schema(outputs)
output_tensors = _InspectMemoryUsage.apply(_POST_FW_PASS_PHASE, run_ctx, module, *flatten_output_tensor_list)
restored_outputs = unflatten_data_using_schema(output_tensors, output_schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import ctypes
import inspect
import warnings
from collections import OrderedDict
from datetime import timedelta
from types import CodeType, FunctionType
Expand Down Expand Up @@ -169,7 +168,8 @@ def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", sta
if torch.nn.functional.linear is zero3_linear_wrap:
torch.nn.functional.linear = _zero3_linear_wrap_ort_compatible

except ImportError as e:
except ImportError:

def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir=None, stats_overwrite=False):
raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.")

Expand Down

0 comments on commit e81dc8f

Please sign in to comment.