From 0a36e1dc698c625b0c70eeedcd0b561b016da7b1 Mon Sep 17 00:00:00 2001
From: pengwa
Date: Tue, 26 Mar 2024 21:25:59 +0800
Subject: [PATCH] Fix memory stats printing (#20061)

### Fix memory stats printing

Memory stats printing fails when the module is in eval mode at the time ORTModule wraps it. In that case, the runtime inspector owned by the training manager should have its training flag set to true, but it got false (because the existing logic read the boolean from module.training). A runtime inspector that is part of a training manager or an inference manager should be told explicitly whether it serves training, so we cannot depend on the state of module.training during ORTModule initialization. (An illustrative sketch of this failure mode appears after the patch.)

### Motivation and Context

---
 .../ortmodule/_graph_execution_manager.py     | 24 +++++++++----
 .../training/ortmodule/_inference_manager.py  |  3 +-
 .../training/ortmodule/_runtime_inspector.py  | 27 +++++++++++---
 .../training/ortmodule/_training_manager.py   |  4 +--
 .../python/orttraining_test_ortmodule_api.py  | 36 +++++++++++++++++++
 5 files changed, 78 insertions(+), 16 deletions(-)

diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index 568c92b71277f..5123594bff387 100755
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -54,10 +54,20 @@ def __init__(
         self,
         module: _FlattenedModule,
         debug_options: DebugOptions,
+        export_mode: int,
         fallback_manager: _FallbackManager,
         logger: logging.Logger,
     ):
-        """Manages construction and execution of ONNX graphs"""
+        """Manages construction and execution of ONNX graphs.
+
+        Args:
+            module: The flattened PyTorch module to be executed.
+            debug_options: Debug options for ORTModule.
+            export_mode: Export mode; must be torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL.
+            fallback_manager: Fallback manager to handle exceptions.
+            logger: Logger for ORTModule.
+
+        """
 
         super().__init__(module._original_module)
 
@@ -88,16 +98,12 @@ def __init__(
 
         self._first_skip_check_warning = True
 
-        # Inspector for runtime information, for example input data, memory usage, etc.
-        self._runtime_inspector = RuntimeInspector(self._logger, self._original_module)
-        self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)
-
         # Tracker for ORTModule model export, session creation overhead.
         self.time_tracker = _logger.TimeTracker()
 
         # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL
         # To be instantiated in the concrete implementation of GraphExecutionManager
-        self._export_mode = None
+        self._export_mode = export_mode
 
         # Exporter can take extra arguments for ORTModule extensions
         # It cannot overlap with required/immutable arguments (validated in runtime)
@@ -129,6 +135,12 @@ def __init__(
         # Re-export will be avoided if _skip_check is enabled.
         self._original_model_has_changed = False
 
+        # Inspector for runtime information, for example input data, memory usage, etc.
+        self._runtime_inspector = RuntimeInspector(
+            self._logger, self._original_module, self._export_mode == torch.onnx.TrainingMode.TRAINING
+        )
+        self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)
+
         # Load ATen operator executor extension.
         load_aten_op_executor_cpp_extension()

diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py
index 6690af9b71bf1..13145c7c79091 100644
--- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py
@@ -28,8 +28,7 @@ class InferenceManager(GraphExecutionManager):
     """
 
     def __init__(self, model, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger):
-        super().__init__(model, debug_options, fallback_manager, logger)
-        self._export_mode = torch.onnx.TrainingMode.EVAL
+        super().__init__(model, debug_options, torch.onnx.TrainingMode.EVAL, fallback_manager, logger)
 
     @staticmethod
     def execution_session_run_forward(
diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
index d3fe132609a90..5c86070430e81 100644
--- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
+++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
@@ -46,11 +46,18 @@ class RuntimeInspector:
     Runtime inspector for ORTModule.
    """
 
-    def __init__(self, logger: Logger, module: torch.nn.Module):
+    def __init__(self, logger: Logger, module: torch.nn.Module, training: bool):
+        """Initialize the runtime inspector.
+
+        Args:
+            logger: Logger.
+            module: Torch module.
+            training: A boolean indicating whether the module is in training mode.
+        """
         self._logger = logger
 
         self.input_density_ob: Union[InputDensityObserver, None] = None
-        self.memory_ob = MemoryObserver(module, self._logger)
+        self.memory_ob = MemoryObserver(module, self._logger, training)
 
     def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None:
         """Initialize input inspector from the given ONNX model and user input names.
@@ -479,7 +486,14 @@ class MemoryObserver:
     NORMALIZER_FACTOR = float(1024 * 1024)
     NORMALIZER_UNIT = "MiB"
 
-    def __init__(self, m: torch.nn.Module, logger: Logger):
+    def __init__(self, m: torch.nn.Module, logger: Logger, training: bool):
+        """Initialize the memory observer.
+
+        Args:
+            m: Torch module.
+            logger: Logger.
+            training: A boolean indicating whether the module is in training mode.
+        """
         self._logger = logger
         self._is_enabled = True
 
@@ -503,7 +517,10 @@ def __init__(self, m: torch.nn.Module, logger: Logger):
         self._rank_info = f"[{self._rank}/{self._world_size}]"
         self._pre_phase = Phase.INVALID
-        self._last_phase = Phase.POST_BACKWARD if m.training else Phase.POST_FORWARD
+
+        # Cannot infer whether this is for training or inference from module.training,
+        # because it is probably not set correctly when this happens.
+        self._last_phase = Phase.POST_BACKWARD if training else Phase.POST_FORWARD
 
         self._is_first_inspect = True
 
@@ -721,7 +738,7 @@ def _get_user_config_without_freq(configs: str):
             notes.append(saving_recommendation)
 
         saving_recommendation = (
-            "[Memory Optimizer] memory saving is calculated based on the 1st batch symbolic dim values:\n"
+            "[Memory Optimizer] Memory saving is calculated based on the 1st batch symbolic dim values:\n"
         )
         for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items():
             saving_recommendation += f" {dim_param}={dim_value},"
diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
index 5fa332d12f01c..a7426bce38a40 100644
--- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
@@ -38,9 +38,7 @@ def __init__(
         fallback_manager: _FallbackManager,
         logger: Logger,
     ):
-        super().__init__(model, debug_options, fallback_manager, logger)
-
-        self._export_mode = torch.onnx.TrainingMode.TRAINING
+        super().__init__(model, debug_options, torch.onnx.TrainingMode.TRAINING, fallback_manager, logger)
 
         self._forward_class = self._create_autofunction_class()
 
     @staticmethod
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index d6f55e787c320..da217eb76949c 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6499,6 +6499,42 @@ def run_step(model, x, y, z):
     os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val
 
 
+def test_bert_memory_inspection(caplog):
+    original_val = os.environ.get("ORTMODULE_PRINT_MEMORY_STATS", None)
+
+    # Create the PyTorch model with dropout disabled.
+    pt_model = _get_bert_for_sequence_classification_model(
+        "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
+    )
+
+    os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = "1"
+    pt_model.eval()  # Put it in eval mode intentionally, in case some initialization in ORTModule mistakenly uses module.training for its checks.
+    ort_model = ORTModule(
+        copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO)  # The logged memory info is at INFO level.
+    )
+
+    def run_step(model, x, y, z):
+        outputs = model(x, y, None, None, None, None, z)
+        loss = outputs[0]
+        loss.backward()
+
+    ort_model.train()
+    for _ in range(32):
+        x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
+        run_step(ort_model, x, y, z)
+
+    info_records = [
+        record.message for record in caplog.records if record.levelname == "INFO" and "(MiB) | phase:" in record.message
+    ]
+
+    assert len(info_records) == 4 * 11
+
+    # Make sure the environment variable is restored to its original value after the run completes.
+    torch.cuda.synchronize()
+    if original_val is not None:
+        os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = original_val
+
+
 @pytest.mark.parametrize("softmax_compute_type", [torch.float16, torch.float32])
 def test_overridden_softmax_export(softmax_compute_type):
     class CustomSoftmaxExportTest(torch.nn.Module):
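
For reviewers, here is a minimal, self-contained sketch of the failure mode this patch fixes (illustrative only; `MemoryObserverSketch` and its parameters are simplified stand-ins, not the actual ORTModule classes). Before the change, the memory observer derived its expected final phase from `module.training` at wrap time, so wrapping an eval-mode module under a training manager left it expecting `POST_FORWARD`, while training steps actually end at `POST_BACKWARD`; passing the flag explicitly decouples the observer from the module's mode during ORTModule initialization.

```python
from enum import Enum
from typing import Optional


class Phase(Enum):
    INVALID = 0
    POST_FORWARD = 1
    POST_BACKWARD = 2


class MemoryObserverSketch:
    """Simplified stand-in for MemoryObserver (illustration, not the real class)."""

    def __init__(self, module_training: bool, explicit_training: Optional[bool] = None):
        # Before the fix: the observer trusted module.training at wrap time.
        # After the fix: the owning manager passes the training flag explicitly.
        training = explicit_training if explicit_training is not None else module_training
        self._last_phase = Phase.POST_BACKWARD if training else Phase.POST_FORWARD


# The scenario exercised by test_bert_memory_inspection: the module is in eval
# mode when ORTModule wraps it, and is switched to train mode only afterwards.
module_training_at_wrap = False  # pt_model.eval() was called before wrapping

buggy = MemoryObserverSketch(module_training_at_wrap)
assert buggy._last_phase is Phase.POST_FORWARD  # wrong for a training manager

fixed = MemoryObserverSketch(module_training_at_wrap, explicit_training=True)
assert fixed._last_phase is Phase.POST_BACKWARD  # consistent with training steps
```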