Fix memory stats printing (microsoft#20061)
### Fix memory stats printing

Memory stats printing fails when the module is in eval mode at the time it is
wrapped by ORTModule. In that case, the runtime inspector owned by the training
manager should have its training flag set to true, but it gets false, because the
existing logic reads the flag from module.training. The runtime inspector, as part
of the training manager or inference manager, should be told explicitly whether it
serves training, so we cannot depend on the state of module.training during
ORTModule initialization.
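
For illustration only, a minimal sketch of the pitfall described above; the toy module is hypothetical and not part of this change:

```python
import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


model = TinyModel()
model.eval()  # The module happens to be in eval mode when it gets wrapped.

# A wrapper that reads module.training during its initialization captures False here ...
captured_training_flag = model.training

model.train()  # ... even though the user switches to training mode right afterwards.
print(captured_training_flag, model.training)  # False True
```

This is why the export mode, and with it the inspector's training flag, is now passed explicitly to the graph execution manager instead of being derived from module.training.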

pengwa authored and Ted Themistokleous committed May 7, 2024
1 parent ba81503 commit 0a36e1d
Showing 5 changed files with 78 additions and 16 deletions.
@@ -54,10 +54,20 @@ def __init__(
         self,
         module: _FlattenedModule,
         debug_options: DebugOptions,
+        export_mode: int,
         fallback_manager: _FallbackManager,
         logger: logging.Logger,
     ):
-        """Manages construction and execution of ONNX graphs"""
+        """Manages construction and execution of ONNX graphs.
+
+        Args:
+            module: The flattened PyTorch module to be executed.
+            debug_options: Debug options for ORTModule.
+            export_mode: Export mode; either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL.
+            fallback_manager: Fallback manager to handle exceptions.
+            logger: Logger for ORTModule.
+        """

         super().__init__(module._original_module)

@@ -88,16 +98,12 @@ def __init__(

         self._first_skip_check_warning = True

-        # Inspector for runtime information, for example input data, memory usage, etc.
-        self._runtime_inspector = RuntimeInspector(self._logger, self._original_module)
-        self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)

         # Tracker for ORTModule model export, session creation overhead.
         self.time_tracker = _logger.TimeTracker()

         # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL
-        # To be instantiated in the concrete implementation of GraphExecutionManager
-        self._export_mode = None
+        self._export_mode = export_mode

         # Exporter can take extra arguments for ORTModule extensions
         # It cannot overlap with required/immutable arguments (validated in runtime)
@@ -129,6 +135,12 @@ def __init__(
         # Re-export will be avoided if _skip_check is enabled.
         self._original_model_has_changed = False

+        # Inspector for runtime information, for example input data, memory usage, etc.
+        self._runtime_inspector = RuntimeInspector(
+            self._logger, self._original_module, self._export_mode == torch.onnx.TrainingMode.TRAINING
+        )
+        self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)
+
         # Load ATen operator executor extension.
         load_aten_op_executor_cpp_extension()

@@ -28,8 +28,7 @@ class InferenceManager(GraphExecutionManager):
"""

def __init__(self, model, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger):
super().__init__(model, debug_options, fallback_manager, logger)
self._export_mode = torch.onnx.TrainingMode.EVAL
super().__init__(model, debug_options, torch.onnx.TrainingMode.EVAL, fallback_manager, logger)

@staticmethod
def execution_session_run_forward(
@@ -46,11 +46,18 @@ class RuntimeInspector:
     Runtime inspector for ORTModule.
     """

-    def __init__(self, logger: Logger, module: torch.nn.Module):
+    def __init__(self, logger: Logger, module: torch.nn.Module, training: bool):
+        """Initialize runtime inspector.
+
+        Args:
+            logger: Logger.
+            module: Torch module.
+            training: A boolean indicating whether the module is in training mode.
+        """
         self._logger = logger

         self.input_density_ob: Union[InputDensityObserver, None] = None
-        self.memory_ob = MemoryObserver(module, self._logger)
+        self.memory_ob = MemoryObserver(module, self._logger, training)

     def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None:
         """Initialize input inspector from the given ONNX model and user input names.
@@ -479,7 +486,14 @@ class MemoryObserver:
     NORMALIZER_FACTOR = float(1024 * 1024)
     NORMALIZER_UNIT = "MiB"

-    def __init__(self, m: torch.nn.Module, logger: Logger):
+    def __init__(self, m: torch.nn.Module, logger: Logger, training: bool):
+        """Initialize memory observer.
+
+        Args:
+            m: Torch module.
+            logger: Logger.
+            training: A boolean indicating whether the module is in training mode.
+        """
         self._logger = logger
         self._is_enabled = True

@@ -503,7 +517,10 @@ def __init__(self, m: torch.nn.Module, logger: Logger):

         self._rank_info = f"[{self._rank}/{self._world_size}]"
         self._pre_phase = Phase.INVALID
-        self._last_phase = Phase.POST_BACKWARD if m.training else Phase.POST_FORWARD
+
+        # Cannot infer whether this is for training or inference from module.training,
+        # because it is probably not set correctly at this point.
+        self._last_phase = Phase.POST_BACKWARD if training else Phase.POST_FORWARD

         self._is_first_inspect = True

@@ -721,7 +738,7 @@ def _get_user_config_without_freq(configs: str):
         notes.append(saving_recommendation)

         saving_recommendation = (
-            "[Memory Optimizer] memory saving is calculated based on the 1st batch symbolic dim values:\n"
+            "[Memory Optimizer] Memory saving is calculated based on the 1st batch symbolic dim values:\n"
         )
         for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items():
             saving_recommendation += f" {dim_param}={dim_value},"
@@ -38,9 +38,7 @@ def __init__(
         fallback_manager: _FallbackManager,
         logger: Logger,
     ):
-        super().__init__(model, debug_options, fallback_manager, logger)
-
-        self._export_mode = torch.onnx.TrainingMode.TRAINING
+        super().__init__(model, debug_options, torch.onnx.TrainingMode.TRAINING, fallback_manager, logger)
         self._forward_class = self._create_autofunction_class()

     @staticmethod
@@ -6499,6 +6499,42 @@ def run_step(model, x, y, z):
     os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val


+def test_bert_memory_inspection(caplog):
+    original_val = os.environ.get("ORTMODULE_PRINT_MEMORY_STATS", None)
+
+    # Create PyTorch model with dropout disabled.
+    pt_model = _get_bert_for_sequence_classification_model(
+        "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
+    )
+
+    os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = "1"
+    pt_model.eval()  # Put the model in eval mode intentionally, in case some initialization in ORTModule mistakenly uses module.training for its checks.
+    ort_model = ORTModule(
+        copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO)  # The logged memory info is at INFO level.
+    )
+
+    def run_step(model, x, y, z):
+        outputs = model(x, y, None, None, None, None, z)
+        loss = outputs[0]
+        loss.backward()
+
+    ort_model.train()
+    for _ in range(32):
+        x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
+        run_step(ort_model, x, y, z)
+
+    info_records = [
+        record.message for record in caplog.records if record.levelname == "INFO" and "(MiB) | phase:" in record.message
+    ]
+
+    assert len(info_records) == 4 * 11
+
+    # Make sure the environment variable is restored to its original value after the run completes.
+    torch.cuda.synchronize()
+    if original_val is not None:
+        os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = original_val
+
+
 @pytest.mark.parametrize("softmax_compute_type", [torch.float16, torch.float32])
 def test_overridden_softmax_export(softmax_compute_type):
     class CustomSoftmaxExportTest(torch.nn.Module):
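
For context, a minimal usage sketch of the feature the new test exercises: per-step memory statistics printing. The toy model and shapes are illustrative; the environment variable, DebugOptions, and LogLevel usage mirror the test above:

```python
import os

import torch
from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule

# Turn on per-step memory statistics, as the test above does.
os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = "1"

model = torch.nn.Linear(4, 4).cuda()
model.eval()  # Wrapping while in eval mode must no longer break the memory inspector.
ort_model = ORTModule(model, DebugOptions(log_level=LogLevel.INFO))  # Stats are logged at INFO level.

ort_model.train()
x = torch.randn(2, 4, device="cuda")
ort_model(x).sum().backward()  # Memory stats are printed around the forward and backward phases.
```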
