Skip to content

Commit

Permalink
fix memory stats printing
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Mar 25, 2024
1 parent 3dfe4a5 commit 17811c4
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
self,
module: _FlattenedModule,
debug_options: DebugOptions,
export_mode: int,
fallback_manager: _FallbackManager,
logger: logging.Logger,
):
Expand Down Expand Up @@ -88,16 +89,12 @@ def __init__(

self._first_skip_check_warning = True

# Inspector for runtime information, for example input data, memory usage, etc.
self._runtime_inspector = RuntimeInspector(self._logger, self._original_module)
self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)

# Tracker for ORTModule model export, session creation overhead.
self.time_tracker = _logger.TimeTracker()

# Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL
# To be instantiated in the concrete implementation of GraphExecutionManager
self._export_mode = None
self._export_mode = export_mode

# Exporter can take extra arguments for ORTModule extensions
# It cannot overlap with required/immutable arguments (validated in runtime)
Expand Down Expand Up @@ -129,6 +126,12 @@ def __init__(
# Re-export will be avoided if _skip_check is enabled.
self._original_model_has_changed = False

# Inspector for runtime information, for example input data, memory usage, etc.
self._runtime_inspector = RuntimeInspector(
self._logger, self._original_module, self._export_mode == torch.onnx.TrainingMode.TRAINING
)
self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)

# Load ATen operator executor extension.
load_aten_op_executor_cpp_extension()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ class InferenceManager(GraphExecutionManager):
"""

def __init__(self, model, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger):
super().__init__(model, debug_options, fallback_manager, logger)
self._export_mode = torch.onnx.TrainingMode.EVAL
super().__init__(model, debug_options, torch.onnx.TrainingMode.EVAL, fallback_manager, logger)

@staticmethod
def execution_session_run_forward(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ class RuntimeInspector:
Runtime inspector for ORTModule.
"""

def __init__(self, logger: Logger, module: torch.nn.Module):
def __init__(self, logger: Logger, module: torch.nn.Module, training: bool):
self._logger = logger

self.input_density_ob: Union[InputDensityObserver, None] = None
self.memory_ob = MemoryObserver(module, self._logger)
self.memory_ob = MemoryObserver(module, self._logger, training)

def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None:
"""Initialize input inspector from the given ONNX model and user input names.
Expand Down Expand Up @@ -479,7 +479,7 @@ class MemoryObserver:
NORMALIZER_FACTOR = float(1024 * 1024)
NORMALIZER_UNIT = "MiB"

def __init__(self, m: torch.nn.Module, logger: Logger):
def __init__(self, m: torch.nn.Module, logger: Logger, training: bool):
self._logger = logger
self._is_enabled = True

Expand All @@ -503,7 +503,10 @@ def __init__(self, m: torch.nn.Module, logger: Logger):

self._rank_info = f"[{self._rank}/{self._world_size}]"
self._pre_phase = Phase.INVALID
self._last_phase = Phase.POST_BACKWARD if m.training else Phase.POST_FORWARD

# Cannot infer whether this is for training or inference from module.training,
# because it is probably not set correctly at the time this runs.
self._last_phase = Phase.POST_BACKWARD if training else Phase.POST_FORWARD

self._is_first_inspect = True

Expand Down Expand Up @@ -721,7 +724,7 @@ def _get_user_config_without_freq(configs: str):
notes.append(saving_recommendation)

saving_recommendation = (
"[Memory Optimizer] memory saving is calculated based on the 1st batch symbolic dim values:\n"
"[Memory Optimizer] Memory saving is calculated based on the 1st batch symbolic dim values:\n"
)
for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items():
saving_recommendation += f" {dim_param}={dim_value},"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ def __init__(
fallback_manager: _FallbackManager,
logger: Logger,
):
super().__init__(model, debug_options, fallback_manager, logger)

self._export_mode = torch.onnx.TrainingMode.TRAINING
super().__init__(model, debug_options, torch.onnx.TrainingMode.TRAINING, fallback_manager, logger)
self._forward_class = self._create_autofunction_class()

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6496,3 +6496,39 @@ def run_step(model, x, y, z):
torch.cuda.synchronize()
if original_val is not None:
os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val


def test_bert_memory_inspection(caplog):
    """Check that ORTModule emits memory-statistics log records while training.

    The module is wrapped by ORTModule while it is still in eval mode, then
    switched to train mode before running steps, so that any initialization
    which wrongly relies on the module's training flag at construction time
    is exercised.

    Args:
        caplog: pytest log-capture fixture used to collect the INFO records.
    """
    env_name = "ORTMODULE_PRINT_MEMORY_STATS"
    original_val = os.environ.get(env_name, None)

    # Create PyTorch model with dropout disabled.
    pt_model = _get_bert_for_sequence_classification_model(
        "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
    )

    os.environ[env_name] = "1"
    try:
        # Intentionally put the model in evaluation mode before wrapping it, in case
        # some initialization in ORTModule mistakenly uses the module's training
        # flag for its checks.
        pt_model.eval()
        ort_model = ORTModule(
            copy.deepcopy(pt_model),
            DebugOptions(log_level=LogLevel.INFO),  # The logged memory info is in INFO level.
        )

        def run_step(model, x, y, z):
            outputs = model(x, y, None, None, None, None, z)
            loss = outputs[0]
            loss.backward()

        ort_model.train()
        for _ in range(32):
            x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
            run_step(ort_model, x, y, z)

        info_records = [
            record.message
            for record in caplog.records
            if record.levelname == "INFO" and "(MiB) | phase:" in record.message
        ]

        # NOTE(review): 4 * 11 is the expected number of matching records for the
        # 32 steps above; how it breaks down (records per printed step x number of
        # printed steps) is not visible from this test -- confirm against the
        # memory observer's logging schedule.
        assert len(info_records) == 4 * 11

        torch.cuda.synchronize()
    finally:
        # Always restore the environment variable, including when it was
        # originally unset or when the test fails, so the setting does not leak
        # into other tests running in the same process.
        if original_val is None:
            os.environ.pop(env_name, None)
        else:
            os.environ[env_name] = original_val

0 comments on commit 17811c4

Please sign in to comment.