Fix memory stats printing #20061

Merged: 5 commits, Mar 26, 2024
@@ -54,10 +54,20 @@ def __init__(
self,
module: _FlattenedModule,
debug_options: DebugOptions,
export_mode: int,
fallback_manager: _FallbackManager,
logger: logging.Logger,
):
"""Manages construction and execution of ONNX graphs"""
"""Manages construction and execution of ONNX graphs.

Args:
module: The flattened PyTorch module to be executed.
debug_options: Debug options for ORTModule.
export_mode: Export mode; either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL.
fallback_manager: Fallback manager to handle exceptions.
logger: Logger for ORTModule.

"""

super().__init__(module._original_module)

@@ -88,16 +98,12 @@ def __init__(

self._first_skip_check_warning = True

# Inspector for runtime information, for example input data, memory usage, etc.
self._runtime_inspector = RuntimeInspector(self._logger, self._original_module)
self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)

# Tracker for ORTModule model export, session creation overhead.
self.time_tracker = _logger.TimeTracker()

# Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL
# To be instantiated in the concrete implementation of GraphExecutionManager
self._export_mode = None
self._export_mode = export_mode

# Exporter can take extra arguments for ORTModule extensions
# It cannot overlap with required/immutable arguments (validated in runtime)
@@ -129,6 +135,12 @@ def __init__(
# Re-export will be avoided if _skip_check is enabled.
self._original_model_has_changed = False

# Inspector for runtime information, for example input data, memory usage, etc.
self._runtime_inspector = RuntimeInspector(
self._logger, self._original_module, self._export_mode == torch.onnx.TrainingMode.TRAINING
)
self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)

# Load ATen operator executor extension.
load_aten_op_executor_cpp_extension()

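Note: with the inspector now constructed after `self._export_mode` is known, per-step memory stats are driven by the `ORTMODULE_PRINT_MEMORY_STATS` environment variable and surface at INFO log level (see the test added at the end of this diff). Below is a minimal usage sketch, not part of this PR; it assumes `onnxruntime-training` is installed and a device where the memory stats are meaningful (the test uses CUDA), and the imports mirror the ones used in the test file.

```python
import os

import torch
from torch import nn
from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule

# Ask ORTModule to print per-step memory stats; they are emitted at INFO level.
os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = "1"

model = ORTModule(nn.Linear(8, 4), DebugOptions(log_level=LogLevel.INFO))
model.train()  # training mode: the observer expects both a forward and a backward phase per step

x = torch.randn(2, 8)
loss = model(x).sum()
loss.backward()  # phase-by-phase memory usage is logged by the RuntimeInspector's MemoryObserver
```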
@@ -28,8 +28,7 @@ class InferenceManager(GraphExecutionManager):
"""

def __init__(self, model, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger):
super().__init__(model, debug_options, fallback_manager, logger)
self._export_mode = torch.onnx.TrainingMode.EVAL
super().__init__(model, debug_options, torch.onnx.TrainingMode.EVAL, fallback_manager, logger)

@staticmethod
def execution_session_run_forward(
@@ -46,11 +46,18 @@ class RuntimeInspector:
Runtime inspector for ORTModule.
"""

def __init__(self, logger: Logger, module: torch.nn.Module):
def __init__(self, logger: Logger, module: torch.nn.Module, training: bool):
"""Initialize runtime inspector.

Args:
logger: Logger.
module: Torch module.
training: A boolean indicating whether the module is in training mode.
"""
self._logger = logger

self.input_density_ob: Union[InputDensityObserver, None] = None
self.memory_ob = MemoryObserver(module, self._logger)
self.memory_ob = MemoryObserver(module, self._logger, training)

def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None:
"""Initialize input inspector from the given ONNX model and user input names.
@@ -479,7 +486,14 @@ class MemoryObserver:
NORMALIZER_FACTOR = float(1024 * 1024)
NORMALIZER_UNIT = "MiB"

def __init__(self, m: torch.nn.Module, logger: Logger):
def __init__(self, m: torch.nn.Module, logger: Logger, training: bool):
"""Initialize memory observer.

Args:
m: Torch module.
logger: Logger.
training: A boolean indicating whether the module is in training mode.
"""
self._logger = logger
self._is_enabled = True

@@ -503,7 +517,10 @@ def __init__(self, m: torch.nn.Module, logger: Logger, training: bool):

self._rank_info = f"[{self._rank}/{self._world_size}]"
self._pre_phase = Phase.INVALID
self._last_phase = Phase.POST_BACKWARD if m.training else Phase.POST_FORWARD

# Cannot infer whether the module is used for training or inference from module.training,
# because it may not be set correctly at this point.
self._last_phase = Phase.POST_BACKWARD if training else Phase.POST_FORWARD

self._is_first_inspect = True
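
The explicit `training` argument matters because of ordering: a user may wrap a module with ORTModule while it is still in eval mode and only call `.train()` afterwards, which is exactly what the new test at the bottom of this diff does. A hedged sketch of that ordering (illustrative model, assuming `onnxruntime-training` is installed):

```python
from torch import nn
from onnxruntime.training.ortmodule import ORTModule

pt_model = nn.Linear(4, 2)
pt_model.eval()                  # module.training is False here...
ort_model = ORTModule(pt_model)  # ...so reading module.training during construction would
                                 # suggest inference-only phases (last phase POST_FORWARD)
ort_model.train()                # training is enabled only after construction; the observer
                                 # must still treat POST_BACKWARD as the last phase of a step
```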

@@ -721,7 +738,7 @@ def _get_user_config_without_freq(configs: str):
notes.append(saving_recommendation)

saving_recommendation = (
"[Memory Optimizer] memory saving is calculated based on the 1st batch symbolic dim values:\n"
"[Memory Optimizer] Memory saving is calculated based on the 1st batch symbolic dim values:\n"
)
for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items():
saving_recommendation += f" {dim_param}={dim_value},"
@@ -38,9 +38,7 @@ def __init__(
fallback_manager: _FallbackManager,
logger: Logger,
):
super().__init__(model, debug_options, fallback_manager, logger)

self._export_mode = torch.onnx.TrainingMode.TRAINING
super().__init__(model, debug_options, torch.onnx.TrainingMode.TRAINING, fallback_manager, logger)
self._forward_class = self._create_autofunction_class()

@staticmethod
@@ -6499,6 +6499,42 @@ def run_step(model, x, y, z):
os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val


def test_bert_memory_inspection(caplog):
original_val = os.environ.get("ORTMODULE_PRINT_MEMORY_STATS", None)

# Create PyTorch model with dropout disabled.
pt_model = _get_bert_for_sequence_classification_model(
"cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
)

os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = "1"
pt_model.eval()  # Intentionally put the module in eval mode, in case some initialization in ORTModule mistakenly uses module.training for its checks.
ort_model = ORTModule(
copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO)  # The memory info is logged at INFO level.
)

def run_step(model, x, y, z):
outputs = model(x, y, None, None, None, None, z)
loss = outputs[0]
loss.backward()

ort_model.train()
for _ in range(32):
x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
run_step(ort_model, x, y, z)

info_records = [
record.message for record in caplog.records if record.levelname == "INFO" and "(MiB) | phase:" in record.message
]

assert len(info_records) == 4 * 11

# Make sure the environment variable is restored to its original value after the run completes.
torch.cuda.synchronize()
if original_val is not None:
os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = original_val


@pytest.mark.parametrize("softmax_compute_type", [torch.float16, torch.float32])
def test_overridden_softmax_export(softmax_compute_type):
class CustomSoftmaxExportTest(torch.nn.Module):