From 17811c4db01dc0bcd0a09057e53f2d65422323f8 Mon Sep 17 00:00:00 2001
From: "Peng Wang(AI FWK)"
Date: Mon, 25 Mar 2024 12:41:17 +0000
Subject: [PATCH] fix memory stats printing

---
 .../ortmodule/_graph_execution_manager.py      | 13 ++++---
 .../training/ortmodule/_inference_manager.py   |  3 +-
 .../training/ortmodule/_runtime_inspector.py   | 13 ++++---
 .../training/ortmodule/_training_manager.py    |  4 +--
 .../python/orttraining_test_ortmodule_api.py   | 36 +++++++++++++++++++
 5 files changed, 54 insertions(+), 15 deletions(-)

diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index 568c92b71277f..2a0c6805f44b5 100755
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -54,6 +54,7 @@ def __init__(
         self,
         module: _FlattenedModule,
         debug_options: DebugOptions,
+        export_mode: int,
         fallback_manager: _FallbackManager,
         logger: logging.Logger,
     ):
@@ -88,16 +89,12 @@ def __init__(

         self._first_skip_check_warning = True

-        # Inspector for runtime information, for example input data, memory usage, etc.
-        self._runtime_inspector = RuntimeInspector(self._logger, self._original_module)
-        self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)
-
         # Tracker for ORTModule model export, session creation overhead.
         self.time_tracker = _logger.TimeTracker()

         # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL
         # To be instantiated in the concrete implementation of GraphExecutionManager
-        self._export_mode = None
+        self._export_mode = export_mode

         # Exporter can take extra arguments for ORTModule extensions
         # It cannot overlap with required/immutable arguments (validated in runtime)
@@ -129,6 +126,12 @@ def __init__(

         # Re-export will be avoided if _skip_check is enabled.
         self._original_model_has_changed = False

+        # Inspector for runtime information, for example input data, memory usage, etc.
+        self._runtime_inspector = RuntimeInspector(
+            self._logger, self._original_module, self._export_mode == torch.onnx.TrainingMode.TRAINING
+        )
+        self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)
+
         # Load ATen operator executor extension.
         load_aten_op_executor_cpp_extension()
diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py
index 6690af9b71bf1..13145c7c79091 100644
--- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py
@@ -28,8 +28,7 @@ class InferenceManager(GraphExecutionManager):
     """

     def __init__(self, model, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger):
-        super().__init__(model, debug_options, fallback_manager, logger)
-        self._export_mode = torch.onnx.TrainingMode.EVAL
+        super().__init__(model, debug_options, torch.onnx.TrainingMode.EVAL, fallback_manager, logger)

     @staticmethod
     def execution_session_run_forward(
diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
index d3fe132609a90..6d9c9fd5c66e4 100644
--- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
+++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
@@ -46,11 +46,11 @@ class RuntimeInspector:
     Runtime inspector for ORTModule.
     """

-    def __init__(self, logger: Logger, module: torch.nn.Module):
+    def __init__(self, logger: Logger, module: torch.nn.Module, training: bool):
         self._logger = logger

         self.input_density_ob: Union[InputDensityObserver, None] = None
-        self.memory_ob = MemoryObserver(module, self._logger)
+        self.memory_ob = MemoryObserver(module, self._logger, training)

     def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None:
         """Initialize input inspector from the given ONNX model and user input names.
@@ -479,7 +479,7 @@ class MemoryObserver:
     NORMALIZER_FACTOR = float(1024 * 1024)
     NORMALIZER_UNIT = "MiB"

-    def __init__(self, m: torch.nn.Module, logger: Logger):
+    def __init__(self, m: torch.nn.Module, logger: Logger, training: bool):
         self._logger = logger
         self._is_enabled = True

@@ -503,7 +503,10 @@ def __init__(self, m: torch.nn.Module, logger: Logger):
         self._rank_info = f"[{self._rank}/{self._world_size}]"

         self._pre_phase = Phase.INVALID
-        self._last_phase = Phase.POST_BACKWARD if m.training else Phase.POST_FORWARD
+
+        # Cannot infer whether it is for training or inferencing purpose from module.mode,
+        # because it is probably not set correctly when this happens.
+        self._last_phase = Phase.POST_BACKWARD if training else Phase.POST_FORWARD

         self._is_first_inspect = True

@@ -721,7 +724,7 @@ def _get_user_config_without_freq(configs: str):
             notes.append(saving_recommendation)

             saving_recommendation = (
-                "[Memory Optimizer] memory saving is calculated based on the 1st batch symbolic dim values:\n"
+                "[Memory Optimizer] Memory saving is calculated based on the 1st batch symbolic dim values:\n"
             )
             for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items():
                 saving_recommendation += f" {dim_param}={dim_value},"
diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
index 5fa332d12f01c..a7426bce38a40 100644
--- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
@@ -38,9 +38,7 @@ def __init__(
         fallback_manager: _FallbackManager,
         logger: Logger,
     ):
-        super().__init__(model, debug_options, fallback_manager, logger)
-
-        self._export_mode = torch.onnx.TrainingMode.TRAINING
+        super().__init__(model, debug_options, torch.onnx.TrainingMode.TRAINING, fallback_manager, logger)
         self._forward_class = self._create_autofunction_class()

     @staticmethod
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 7afad9145ed27..ea5f4d26d7a12 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6496,3 +6496,39 @@ def run_step(model, x, y, z):
     torch.cuda.synchronize()
     if original_val is not None:
         os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val
+
+
+def test_bert_memory_inspection(caplog):
+    original_val = os.environ.get("ORTMODULE_PRINT_MEMORY_STATS", None)
+
+    # Create a PyTorch model with dropout disabled.
+    pt_model = _get_bert_for_sequence_classification_model(
+        "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
+    )
+
+    os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = "1"
+    pt_model.eval()  # Put it in evaluation mode intentionally, in case some initialization in ORTModule mistakenly uses module.is_training for its checks.
+    ort_model = ORTModule(
+        copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO)  # The logged memory info is at INFO level.
+    )
+
+    def run_step(model, x, y, z):
+        outputs = model(x, y, None, None, None, None, z)
+        loss = outputs[0]
+        loss.backward()
+
+    ort_model.train()
+    for _ in range(32):
+        x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
+        run_step(ort_model, x, y, z)
+
+    info_records = [
+        record.message for record in caplog.records if record.levelname == "INFO" and "(MiB) | phase:" in record.message
+    ]
+
+    assert len(info_records) == 4 * 11
+
+    # Make sure the environment variable is restored to its original value after the run completes.
+    torch.cuda.synchronize()
+    if original_val is not None:
+        os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = original_val
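
For reference, a minimal usage sketch (not part of the patch) of the feature exercised by the new test: per-step memory statistics are enabled via the ORTMODULE_PRINT_MEMORY_STATS environment variable and surface as INFO-level log records when the model is wrapped with ORTModule and DebugOptions(log_level=LogLevel.INFO). The toy linear model, tensor shapes, and step count below are illustrative placeholders; only the environment variable and the ORTModule/DebugOptions/LogLevel API come from the patch above.

```python
# Usage sketch only; assumes onnxruntime-training with CUDA support is installed.
import os

import torch
from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule

# Per-step memory statistics are controlled by this environment variable.
os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = "1"

# Placeholder model; any trainable module works the same way.
pt_model = torch.nn.Linear(784, 10).to("cuda")

# Memory records are emitted at INFO level, so raise the log level accordingly.
ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.INFO))
ort_model.train()

for _ in range(4):
    x = torch.randn(32, 784, device="cuda")
    loss = ort_model(x).sum()
    loss.backward()
```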