Skip to content

Commit

Permalink
move files
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Nov 23, 2023
1 parent 1df5dd7 commit deedc44
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
#ifdef ENABLE_TRAINING
#include "core/framework/partial_graph_execution_state.h"
#include "core/framework/stream_execution_context.h"
#include "orttraining/core/optimizer/memory_optimizer.h"
#include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h"
#endif

using namespace ONNX_NAMESPACE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include "core/graph/graph_utils.h"
#include "core/optimizer/utils.h"
#include "orttraining/core/graph/recompute_graph_utils.h"
#include "orttraining/core/optimizer/memory_optimizer.h"
#include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h"
#include "orttraining/core/optimizer/memory_optimizer/common.h"
#include "orttraining/core/optimizer/memory_optimizer/optimization_planner.h"
#include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ namespace onnxruntime {
/**
@Class MemoryOptimizer
(TODO) move to orttraining/orttraining/core/optimizer/memory_optimizer/ folder.
Find recompute subgraphs and enable them according to user configs. The way we collect subgraphs
(in orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h) in brief is:
1. Find all nodes that generate stashed activations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -626,9 +626,6 @@ def _log_feature_stats(self):
if get_rank() != 0:
return

if self._runtime_inspector.memory_ob.is_enabled() and self._debug_options.log_level <= LogLevel.DEVINFO:
self._logger.info(self._runtime_inspector.memory_ob.memory_optimization_opportunity_table_str)

tbl = PTable(sortable=True)

def _add_record(tbl, columns):
Expand All @@ -654,14 +651,17 @@ def _add_record(tbl, columns):
],
)

opt_config_to_display = self._runtime_options.memory_optimizer_config
if self._runtime_options.memory_optimization_level == _MemoryOptimizationLevel.AGGRESSIVE_FULL_RECOMPUTE:
opt_config_to_display = "ALL_RECOMPUTE_CONFIGS"
mem_row = _add_record(
tbl,
[
"Memory Optimizer",
len(self._runtime_options.memory_optimizer_config) > 0,
(
f"Memory Optimization Level: [{_MemoryOptimizationLevel.to_string(self._runtime_options.memory_optimization_level)}], "
f"Optimization Config: [{self._runtime_options.memory_optimizer_config}], "
f"Optimization Config: [{opt_config_to_display}], "
f"Probe Level: [{self._runtime_options.probe_level}]"
if len(self._runtime_options.memory_optimizer_config) > 0
else "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,..."
Expand Down
4 changes: 3 additions & 1 deletion orttraining/orttraining/python/training/ortmodule/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def is_disabled(self):
class _MemoryOptimizationLevel(IntFlag):
"""Enumeration to specify memory optimization level"""

USER_SPECIFIED = 0 # Fully respect user specified config
USER_SPECIFIED = 0 # Fully respect user-specified config
AGGRESSIVE_FULL_RECOMPUTE = 1 # Enable all recomputable subgraphs

@staticmethod
Expand All @@ -206,6 +206,8 @@ def to_string(memory_optimization_level):
if memory_optimization_level == _MemoryOptimizationLevel.AGGRESSIVE_FULL_RECOMPUTE:
return "AGGRESSIVE_FULL_RECOMPUTE"

return ""


class _RuntimeOptions:
"""Configurable runtime options for ORTModule."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include "test/capturing_sink.h"
#include "test/test_environment.h"
#include "test/util/include/asserts.h"
#include "orttraining/core/optimizer/memory_optimizer.h"
#include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h"

using namespace std;
using namespace ONNX_NAMESPACE;
Expand Down

0 comments on commit deedc44

Please sign in to comment.