From e81dc8f839af8d31257c9067712a9224893534a2 Mon Sep 17 00:00:00 2001
From: Peng Wang
Date: Wed, 20 Dec 2023 19:40:29 -0800
Subject: [PATCH] lint

---
 .../utils/hooks/_mem_statistics_subscriber.py | 31 ++++++++++---------
 .../utils/hooks/_zero_offload_subscriber.py   |  4 +--
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/orttraining/orttraining/python/training/utils/hooks/_mem_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_mem_statistics_subscriber.py
index 9ce06ac503316..e0c522662785d 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_mem_statistics_subscriber.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_mem_statistics_subscriber.py
@@ -9,25 +9,29 @@
 import onnx
 import torch
 
-from onnxruntime.training.utils import log_memory_usage, extract_data_and_schema, unflatten_data_using_schema, ORTModelInputOutputType
-
+from onnxruntime.training.utils import (
+    ORTModelInputOutputType,
+    extract_data_and_schema,
+    log_memory_usage,
+    unflatten_data_using_schema,
+)
 from ._subscriber_base import RuntimeStates, SubscriberBase
 
-
 _PRE_FW_PASS_PHASE = "pre-fw-pass"
 _POST_FW_PASS_PHASE = "post-fw-pass"
 _PRE_BW_PASS_PHASE = "pre-bw-pass"
 _POST_BW_PASS_PHASE = "post-bw-pass"
 
+
 class _InspectMemoryUsage(torch.autograd.Function):
     """This class is used to print the memory statistics in the forward and backward passes."""
 
     @staticmethod
-    def forward(ctx, phase: str, run_ctx: RuntimeStates, module: torch.nn.Module,
-                *input_tensor_list: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
-        """Make sure there is the same number of `tensor` inputs and outputs.
-        """
+    def forward(
+        ctx, phase: str, run_ctx: RuntimeStates, module: torch.nn.Module, *input_tensor_list: Tuple[torch.Tensor, ...]
+    ) -> Tuple[torch.Tensor, ...]:
+        """Make sure there is the same number of `tensor` inputs and outputs."""
         ctx.current_step = run_ctx.global_states.execution_step
         ctx.phase = phase
         ctx.module = module
 
@@ -79,7 +83,6 @@ def alias_input(node_proto_str: str):
     return fw_alias_map, bw_alias_map
 
 
-
 class MemoryStatisticsSubscriber(SubscriberBase):
     """
     This subscriber is used to print the memory statistics in the forward and backward passes.
@@ -106,19 +109,18 @@ def pre_forward_outmost_module_apply_impl(
         args: ORTModelInputOutputType,
         kwargs: ORTModelInputOutputType,
     ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
-
         flatten_args_tensor_list, args_schema = extract_data_and_schema(args)
         flatten_kwargs_tensor_list, kwargs_schema = extract_data_and_schema(kwargs)
-        flatten_out = _InspectMemoryUsage.apply(_PRE_FW_PASS_PHASE, run_ctx, module,
-                                                *(flatten_args_tensor_list + flatten_kwargs_tensor_list))
-        args_tensors = flatten_out[:len(flatten_args_tensor_list)]
-        kwargs_tensors = flatten_out[len(flatten_args_tensor_list):]
+        flatten_out = _InspectMemoryUsage.apply(
+            _PRE_FW_PASS_PHASE, run_ctx, module, *(flatten_args_tensor_list + flatten_kwargs_tensor_list)
+        )
+        args_tensors = flatten_out[: len(flatten_args_tensor_list)]
+        kwargs_tensors = flatten_out[len(flatten_args_tensor_list) :]
         restored_args = unflatten_data_using_schema(args_tensors, args_schema)
         restored_kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema)
 
         return restored_args, restored_kwargs
 
-
     def post_forward_outmost_module_apply_impl(
         self,
         run_ctx: RuntimeStates,
@@ -126,7 +128,6 @@ def post_forward_outmost_module_apply_impl(
         args: ORTModelInputOutputType,
         outputs: ORTModelInputOutputType,
     ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
-
         flatten_output_tensor_list, output_schema = extract_data_and_schema(outputs)
         output_tensors = _InspectMemoryUsage.apply(_POST_FW_PASS_PHASE, run_ctx, module, *flatten_output_tensor_list)
         restored_outputs = unflatten_data_using_schema(output_tensors, output_schema)
diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
index 17008e6f3ce01..6b318d333fade 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
@@ -5,7 +5,6 @@
 
 import ctypes
 import inspect
-import warnings
 from collections import OrderedDict
 from datetime import timedelta
 from types import CodeType, FunctionType
@@ -169,7 +168,8 @@ def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", sta
 
     if torch.nn.functional.linear is zero3_linear_wrap:
         torch.nn.functional.linear = _zero3_linear_wrap_ort_compatible
 
-except ImportError as e:
+except ImportError:
+
     def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir=None, stats_overwrite=False):
         raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.")