diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 568c92b71277f..5123594bff387 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -54,10 +54,20 @@ def __init__( self, module: _FlattenedModule, debug_options: DebugOptions, + export_mode: int, fallback_manager: _FallbackManager, logger: logging.Logger, ): - """Manages construction and execution of ONNX graphs""" + """Manages construction and execution of ONNX graphs. + + Args: + module: The flattened PyTorch module to be executed. + debug_options: Debug options for ORTModule. + export_mode: export mode, should be torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL. + fallback_manager: Fallback manager to handle exceptions. + logger: Logger for ORTModule. + + """ super().__init__(module._original_module) @@ -88,16 +98,12 @@ def __init__( self._first_skip_check_warning = True - # Inspector for runtime information, for example input data, memory usage, etc. - self._runtime_inspector = RuntimeInspector(self._logger, self._original_module) - self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step) - # Tracker for ORTModule model export, session creation overhead. self.time_tracker = _logger.TimeTracker() # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL # To be instantiated in the concrete implementation of GraphExecutionManager - self._export_mode = None + self._export_mode = export_mode # Exporter can take extra arguments for ORTModule extensions # It cannot overlap with required/immutable arguments (validated in runtime) @@ -129,6 +135,12 @@ def __init__( # Re-export will be avoided if _skip_check is enabled. 
self._original_model_has_changed = False + # Inspector for runtime information, for example input data, memory usage, etc. + self._runtime_inspector = RuntimeInspector( + self._logger, self._original_module, self._export_mode == torch.onnx.TrainingMode.TRAINING + ) + self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step) + # Load ATen operator executor extension. load_aten_op_executor_cpp_extension() diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 6690af9b71bf1..13145c7c79091 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -28,8 +28,7 @@ class InferenceManager(GraphExecutionManager): """ def __init__(self, model, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger): - super().__init__(model, debug_options, fallback_manager, logger) - self._export_mode = torch.onnx.TrainingMode.EVAL + super().__init__(model, debug_options, torch.onnx.TrainingMode.EVAL, fallback_manager, logger) @staticmethod def execution_session_run_forward( diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index d3fe132609a90..5c86070430e81 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -46,11 +46,18 @@ class RuntimeInspector: Runtime inspector for ORTModule. """ - def __init__(self, logger: Logger, module: torch.nn.Module): + def __init__(self, logger: Logger, module: torch.nn.Module, training: bool): + """Initialize runtime inspector. + + Args: + logger: Logger. + module: Torch module. + training: a boolean indicating whether the module is in training mode. 
+ """ self._logger = logger self.input_density_ob: Union[InputDensityObserver, None] = None - self.memory_ob = MemoryObserver(module, self._logger) + self.memory_ob = MemoryObserver(module, self._logger, training) def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None: """Initialize input inspector from the given ONNX model and user input names. @@ -479,7 +486,14 @@ class MemoryObserver: NORMALIZER_FACTOR = float(1024 * 1024) NORMALIZER_UNIT = "MiB" - def __init__(self, m: torch.nn.Module, logger: Logger): + def __init__(self, m: torch.nn.Module, logger: Logger, training: bool): + """Initialize memory observer. + + Args: + m: Torch module. + logger: Logger. + training: a boolean indicating whether the module is in training mode. + """ self._logger = logger self._is_enabled = True @@ -503,7 +517,10 @@ def __init__(self, m: torch.nn.Module, logger: Logger): self._rank_info = f"[{self._rank}/{self._world_size}]" self._pre_phase = Phase.INVALID - self._last_phase = Phase.POST_BACKWARD if m.training else Phase.POST_FORWARD + + # Cannot infer whether this is for training or inference from module.training, + # because it is probably not set correctly when this happens. 
+ self._last_phase = Phase.POST_BACKWARD if training else Phase.POST_FORWARD self._is_first_inspect = True @@ -721,7 +738,7 @@ def _get_user_config_without_freq(configs: str): notes.append(saving_recommendation) saving_recommendation = ( - "[Memory Optimizer] memory saving is calculated based on the 1st batch symbolic dim values:\n" + "[Memory Optimizer] Memory saving is calculated based on the 1st batch symbolic dim values:\n" ) for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items(): saving_recommendation += f" {dim_param}={dim_value}," diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 5fa332d12f01c..a7426bce38a40 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -38,9 +38,7 @@ def __init__( fallback_manager: _FallbackManager, logger: Logger, ): - super().__init__(model, debug_options, fallback_manager, logger) - - self._export_mode = torch.onnx.TrainingMode.TRAINING + super().__init__(model, debug_options, torch.onnx.TrainingMode.TRAINING, fallback_manager, logger) self._forward_class = self._create_autofunction_class() @staticmethod diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index d6f55e787c320..da217eb76949c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6499,6 +6499,42 @@ def run_step(model, x, y, z): os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val +def test_bert_memory_inspection(caplog): + original_val = os.environ.get("ORTMODULE_PRINT_MEMORY_STATS", None) + + # Create PyTorch model with dropout disabled. 
+ pt_model = _get_bert_for_sequence_classification_model( + "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0 + ) + + os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = "1" + pt_model.eval() # Put it in evaluation mode intentionally, in case some initialization in ORTModule uses module.training for its checks by mistake. + ort_model = ORTModule( + copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO) # The logged memory info is in INFO level. + ) + + def run_step(model, x, y, z): + outputs = model(x, y, None, None, None, None, z) + loss = outputs[0] + loss.backward() + + ort_model.train() + for _ in range(32): + x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda") + run_step(ort_model, x, y, z) + + info_records = [ + record.message for record in caplog.records if record.levelname == "INFO" and "(MiB) | phase:" in record.message + ] + + assert len(info_records) == 4 * 11 + + # Make sure environment variable is restored to its original value after the run is completed. + torch.cuda.synchronize() + if original_val is not None: + os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = original_val + + @pytest.mark.parametrize("softmax_compute_type", [torch.float16, torch.float32]) def test_overridden_softmax_export(softmax_compute_type): class CustomSoftmaxExportTest(torch.nn.Module):