Memory tool chain improvement #18890

Closed · wants to merge 9 commits
@@ -485,12 +485,15 @@
return;
}

for (const auto& plans : all_possible_node_optimization_plans[index]) {
for (const auto& plan : plans) {
InlinedVector<std::shared_ptr<NodeOptimizationPlanBase>> new_combination = current_combination;
new_combination.push_back(plan);
ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations);
}
const InlinedVector<InlinedVector<std::shared_ptr<NodeOptimizationPlanBase>>>&
plan_combination_list_at_cur_index = all_possible_node_optimization_plans[index];
// For the index-th reused buffer, iterate all possible complete plans.
for (size_t i = 0; i < plan_combination_list_at_cur_index.size(); ++i) {
const auto& plan_combination = plan_combination_list_at_cur_index[i];
InlinedVector<std::shared_ptr<NodeOptimizationPlanBase>> new_combination = current_combination;
// Append the chosen complete plan and continue exploring the next reused buffer at index + 1.
new_combination.insert(new_combination.end(), plan_combination.begin(), plan_combination.end());
ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations);
}

MO_LOG_DEBUG_INFO(logger, "Exit ListAllCombinations");
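For intuition, the rewritten loop above enumerates every combination that picks one complete plan list per reused buffer, i.e. a Cartesian product explored by recursion. A minimal Python sketch of the same idea (illustrative only, not part of this PR; all names are hypothetical):

# Illustrative sketch of the recursion in ListAllCombinations: pick one complete
# plan-combination for the buffer at `index`, then recurse into the next buffer.
def list_all_combinations(per_buffer_plan_lists, index=0, current=()):
    if index == len(per_buffer_plan_lists):
        if current:
            yield list(current)
        return
    for plan_combination in per_buffer_plan_lists[index]:
        yield from list_all_combinations(per_buffer_plan_lists, index + 1,
                                         current + tuple(plan_combination))

# Two reused buffers, each offering two candidate plan-combinations:
combos = list(list_all_combinations([[["A1"], ["A2"]], [["B1"], ["B2", "B3"]]]))
# combos == [['A1', 'B1'], ['A1', 'B2', 'B3'], ['A2', 'B1'], ['A2', 'B2', 'B3']]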
@@ -520,17 +523,29 @@
}

InlinedVector<InlinedVector<InlinedVector<std::shared_ptr<NodeOptimizationPlanBase>>>>
all_possible_node_optimization_plans;
all_possible_node_optimization_plans.resize(plan->reuse_buffers.size());
all_possible_node_optimization_plans(plan->reuse_buffers.size());

size_t i = 0;
for (const auto& p : plan->reuse_buffers) {
MO_LOG_DEBUG_INFO(logger, ">>>reuse buffer: " + std::to_string(p.first));
IterateNode(p.second.first, node_to_optimization_plans_map, {}, logger, all_possible_node_optimization_plans[i]);

// If the reused node is part of the current node's optimization plan, just add the current combination to the result.
if (plan->GetOptimizationType() == OptimizationType::RecomputeWithCompromise || plan->GetOptimizationType() == OptimizationType::Recompute) {

[cpplint warning] orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc:533: Lines should be <= 120 characters long [whitespace/line_length] [2]
const auto& recompute_subgraph =
dynamic_cast<NodeRecomputePlan*>(plan.get())->GetNodesInTopoOrder();
if (std::find(recompute_subgraph.begin(), recompute_subgraph.end(), p.second.first) != recompute_subgraph.end()) {
all_possible_node_optimization_plans[i].push_back(current_combination);
}
}

if (all_possible_node_optimization_plans[i].size() == 0) {
IterateNode(p.second.first, node_to_optimization_plans_map, current_combination, logger, all_possible_node_optimization_plans[i]);

[cpplint warning] orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc:542: Lines should be <= 120 characters long [whitespace/line_length] [2]
}

++i;
}

ListAllCombinations(all_possible_node_optimization_plans, 0, current_combination, logger, all_combinations);
ListAllCombinations(all_possible_node_optimization_plans, 0, {}, logger, all_combinations);

MO_LOG_DEBUG_INFO(logger, "Exit IterateNodeOptimizationPlan: " + plan->GetClusterId());
}
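The reused-buffer handling above boils down to: if the plan being expanded is a recompute plan and the buffer's source node already sits inside the recomputed subgraph, the current combination is complete for that buffer; otherwise the source node's own plans are explored. A hedged Python sketch of that decision (hypothetical names, not the actual implementation):

# Illustrative sketch: candidate plan-combinations contributed by one reused buffer.
def candidates_for_reused_buffer(src_node, recomputed_nodes, current_combination,
                                 explore_source_node):
    if src_node in recomputed_nodes:
        # The buffer's owner is regenerated by this recompute plan itself,
        # so the current combination already covers it.
        return [list(current_combination)]
    # Otherwise the owner needs its own optimization plan(s); delegate to the
    # recursive exploration (the IterateNode equivalent).
    return explore_source_node(src_node, current_combination)

# Usage sketch:
print(candidates_for_reused_buffer("node_a", {"node_a", "node_b"}, ["planX"],
                                   lambda n, c: []))  # -> [['planX']]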
@@ -15,35 +15,6 @@

namespace onnxruntime::optimizer::memory_optimizer {

std::string NodeOptimizationPlanBase::GetMemorySavingSymbolicString() const {
std::string saving_str;
for (auto output_index : activation_output_indices_) {
// If the output is reusing other node's buffer, then no memory saving.
if (reuse_buffers.find(output_index) != reuse_buffers.end()) {
continue;
}

const auto& output_def = node->OutputDefs()[output_index];
MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto());
ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ",
DataTypeImpl::ToString(ml_data_type));
const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType();
ORT_ENFORCE(nullptr != tensor_type_base);
MLDataType elt_type = tensor_type_base->GetElementType();
const auto byte_count_per_element = elt_type->Size();
if (!saving_str.empty()) {
saving_str += " + ";
}
saving_str = "(" + GetActivationOutputDimParamString(output_index) + " * " +
std::to_string(byte_count_per_element) + " * " +
std::to_string(GetSaveRatio()) + ")";
}
if (saving_str.empty()) {
return saving_str;
}
return "(" + saving_str + ")";
}

Status MemoryOptimizationPlanner::UpdateNodePlansFromExecutionPlan(const GraphViewer& graph_viewer,
const OrtValueNameIdxMap& ortvalue_name_to_idx_map,
const SequentialExecutionPlan& p_seq_exec_plan) {
@@ -83,7 +83,7 @@ class NodeOptimizationPlanBase {
/**
* Get a symbolic string to represent the memory saving for this optimization plan.
*/
std::string GetMemorySavingSymbolicString() const;
virtual std::string GetMemorySavingSymbolicString() const = 0;

std::string GetActivationOutputDimParamString(size_t index) const {
ORT_ENFORCE(activation_output_dim_params_.find(index) != activation_output_dim_params_.end(),
@@ -72,11 +72,13 @@ const InlinedHashMap<std::string, AllowedRecomputeNodeConfig>& GetAllowedRecompu
{"Add", AllowedRecomputeNodeConfig{{0, 1}}},
{"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}},
{"Div", AllowedRecomputeNodeConfig{{0, 1}}},
{"Equal", AllowedRecomputeNodeConfig{{0, 1}}},
{"Mul", AllowedRecomputeNodeConfig{{0, 1}}},
{"Sub", AllowedRecomputeNodeConfig{{0, 1}}},

// Data layout
/// The Shape input is trivial, whether or not it exists in backward.
{"Shape", AllowedRecomputeNodeConfig{{0}}},
{"Reshape", AllowedRecomputeNodeConfig{{0}}},
{"Squeeze", AllowedRecomputeNodeConfig{{0}}},
{"Transpose", AllowedRecomputeNodeConfig{{0}}},
@@ -92,6 +94,7 @@ const InlinedHashMap<std::string, AllowedRecomputeNodeConfig>& GetAllowedRecompu
{"Expand", AllowedRecomputeNodeConfig{{0}}},
{"FastGelu", AllowedRecomputeNodeConfig{{0}}},
{"Gelu", AllowedRecomputeNodeConfig{{0}}},
{"QuickGelu", AllowedRecomputeNodeConfig{{0}}},

// Ternary elementwise
{"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}},
@@ -86,6 +86,51 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase {

std::string GetNodesInTopoOrderStr() const;

std::string GetMemorySavingSymbolicString() const override {
std::string saving_str;
for (auto output_index : GetActivationOutputIndices()) {
// If the output reuses another node's buffer, then there is no memory saving.
std::string cur_output_saving_str;

bool is_reused = reuse_buffers.find(output_index) != reuse_buffers.end();
bool is_src_node_in_cur_node_subgraph = false;
if (is_reused) {
// Here we assume the src_node is the real owner of the buffer, so we don't need to trace further.
const auto* src_node = reuse_buffers.at(output_index).first;
is_src_node_in_cur_node_subgraph = std::find(nodes_in_topological_order_.begin(),
nodes_in_topological_order_.end(),
src_node) != nodes_in_topological_order_.end();
}

if (!is_reused || is_src_node_in_cur_node_subgraph) {
// When is_src_node_in_cur_node_subgraph is true, still use the output to calculate the saving, because
// the reused buffer has the same size.
const auto& output_def = node->OutputDefs()[output_index];
MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto());
ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ",
DataTypeImpl::ToString(ml_data_type));
const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType();
ORT_ENFORCE(nullptr != tensor_type_base);
MLDataType elt_type = tensor_type_base->GetElementType();
const auto byte_count_per_element = elt_type->Size();
cur_output_saving_str = GetActivationOutputDimParamString(output_index) + " * " +
std::to_string(byte_count_per_element) + " * " +
std::to_string(GetSaveRatio());
} else {
cur_output_saving_str = "0";
}

if (!saving_str.empty()) {
saving_str += " + ";
}

saving_str += "(" + cur_output_saving_str + ")";
}

ORT_ENFORCE(!saving_str.empty(), "saving_str should not be empty for node: ", node->OpType(), " ", node->Name());
return "(" + saving_str + ")";
}

private:
bool compromise_recompute_;
InlinedVector<const Node*> nodes_in_topological_order_;
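As a worked example of the symbolic string this override produces, consider a recomputed node with two stashed outputs where the second reuses a buffer owned outside the recomputed subgraph, so it contributes "0". The dim params, byte counts, and save ratio below are made-up illustration values:

# Illustrative composition of the saving string, mirroring the loop above.
outputs = [
    {"dim_param": "batch*seq_len*hidden", "bytes_per_elem": 2, "counts": True},
    {"dim_param": "batch*seq_len",        "bytes_per_elem": 4, "counts": False},
]
save_ratio = 1.0
terms = []
for o in outputs:
    if o["counts"]:
        terms.append(f"({o['dim_param']} * {o['bytes_per_elem']} * {save_ratio})")
    else:
        terms.append("(0)")
print("(" + " + ".join(terms) + ")")
# -> ((batch*seq_len*hidden * 2 * 1.0) + (0))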
@@ -10,7 +10,7 @@

import torch
import torch.utils.checkpoint
from onnx import ModelProto
from onnx import ModelProto, helper
from packaging import version
from torch.onnx import symbolic_helper

@@ -393,6 +393,31 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model
node.name = f"{op_name_prefix}_id_{index}"
index += 1

from onnxruntime.training.utils.hooks._mem_statistics_subscriber import _InspectMemoryUsage
from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation
from onnxruntime.training.utils.hooks._subscriber_manager import _IncrementStep

_allowed_unsafe_run_python_op_names = [
get_fully_qualified_class_name(_InspectMemoryUsage),
get_fully_qualified_class_name(_IncrementStep),
get_fully_qualified_class_name(_InspectActivation),
]

for node in exported_model.graph.node:
if node.op_type == "PythonOp":
func_name = None
safe_run_mode_attr = None
for attr in node.attribute:
if attr.name == "func_name":
func_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
if attr.name == "safe_run_mode":
safe_run_mode_attr = attr

if func_name in _allowed_unsafe_run_python_op_names:
if safe_run_mode_attr:
node.attribute.remove(safe_run_mode_attr)
node.attribute.append(helper.make_attribute("safe_run_mode", 0))

return exported_model
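A quick way to sanity-check the rewritten attribute on a processed model (illustrative helper, not part of this PR):

# Illustrative check: PythonOp nodes wrapping the allowed subscriber classes should
# carry safe_run_mode == 0 after post-processing.
from onnx import helper

def python_op_safe_run_modes(model):
    modes = {}
    for node in model.graph.node:
        if node.op_type != "PythonOp":
            continue
        attrs = {a.name: a for a in node.attribute}
        func_name = attrs["func_name"].s.decode("utf-8")
        mode = None
        if "safe_run_mode" in attrs:
            mode = helper.get_attribute_value(attrs["safe_run_mode"])
        modes[func_name] = mode
    return modes

# modes = python_op_safe_run_modes(exported_model)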


@@ -14,7 +14,7 @@
from sympy import Symbol, simplify
from sympy.parsing.sympy_parser import parse_expr

from onnxruntime.training.utils import PTable
from onnxruntime.training.utils import PTable, log_memory_usage

from ._execution_agent import TrainingAgent
from .options import _MemoryOptimizationLevel, _RuntimeOptions
@@ -509,6 +509,8 @@ def __init__(self, m: torch.nn.Module, logger: Logger):

self._is_first_inspect = True

self._m = m

def is_enabled(self) -> bool:
"""Check if memory inspector is enabled."""
return self._is_enabled
@@ -621,29 +623,13 @@ def inspect_memory(self, cur_phase: Phase):
need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0)

if need_print:
cur_mem_allocated = self._normalize(torch.cuda.memory_allocated())
max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated())
cur_mem_cached = self._normalize(torch.cuda.memory_reserved())
max_mem_cached = self._normalize(torch.cuda.max_memory_reserved())
torch_mem_stat = torch.cuda.memory_stats()
cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0))
max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0))

mem_stats = [
["phase", _convert_phase_to_string(cur_phase)],
["allocated", cur_mem_allocated], # current memory allocated for tensors
["max allocated", max_mem_allocated], # peak memory allocated for tensors
["cached", cur_mem_cached], # current memory cached for the caching allocator
["max cached", max_mem_cached], # peak memory cached for caching allocator.
["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory
["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory
]

summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})"
for stat in mem_stats:
summ += f" | {stat[0]}: {stat[1]}"

self._logger.info(summ)
log_memory_usage(
_convert_phase_to_string(cur_phase),
rank_0_only=True,
step_info=f"step {self._current_step}",
logger=self._logger,
module=self._m,
)

if cur_phase == self._last_phase:
self._increase_step()
@@ -655,9 +641,6 @@ def inspect_memory(self, cur_phase: Phase):
def _increase_step(self):
self._current_step += 1

def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str:
return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}"
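The deleted block above spells out what the shared helper now has to report. A minimal sketch of a log_memory_usage-style helper under that assumption (this is not the actual implementation in torch_profile_utils; the MiB normalization and rank filtering simply mirror the removed code):

import torch
import torch.distributed as dist

# Sketch only: report the same torch.cuda counters the removed code gathered,
# normalized to MiB. `module` is accepted to mirror the new call site but unused here.
def log_memory_usage_sketch(tag, logger, step_info="", rank_0_only=True, module=None):
    if rank_0_only and dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    mib = lambda b: f"{float(b) / (1024 * 1024):.0f}"
    stats = torch.cuda.memory_stats()
    logger.info(
        f"{step_info} memory (MiB) | phase: {tag}"
        f" | allocated: {mib(torch.cuda.memory_allocated())}"
        f" | max allocated: {mib(torch.cuda.max_memory_allocated())}"
        f" | cached: {mib(torch.cuda.memory_reserved())}"
        f" | max cached: {mib(torch.cuda.max_memory_reserved())}"
        f" | inactive: {mib(stats.get('inactive_split_bytes.all.current', 0))}"
        f" | max inactive: {mib(stats.get('inactive_split_bytes.all.peak', 0))}"
    )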

def display_memory_optimization_plans(self, memory_optimizer_config, details=False) -> Tuple[List[str], PTable]:
mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map)

@@ -237,7 +237,7 @@ def forward(self, *inputs, **kwargs):
# Only change this after the first time a warning is issued.
self._first_skip_check_warning = False
self._logger.info(
"Fast path enabled - skipping checks.Rebuild graph: %s, Execution agent: %s, Device check: %s",
"Fast path enabled - skipping checks. Rebuild graph: %s, Execution agent: %s, Device check: %s",
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE),
@@ -179,16 +179,13 @@ def _get_func_name(node: NodeProto) -> Optional[str]:
exported_model.graph.node.insert(0, weight_pull_node)

# Update safe_run_mode attribute for PythonOp.
from onnxruntime.training.utils.hooks._subscriber_manager import _IncrementStep

_allowed_unsafe_run_python_op_names = [
get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction),
get_fully_qualified_class_name(ORTZeROOffloadPostForwardFunction),
func_full_qual_name,
DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME,
DEEPSPEED_POST_BACKWARD_FUNCTION_NAME,
DEEPSPEED_LINEAR_FUNCTION_NAME,
get_fully_qualified_class_name(_IncrementStep),
]

for node in exported_model.graph.node:
@@ -181,7 +181,6 @@ py::object finalize_training_mode_forward(
}

if (kernel_info.is_first_run) {
std::cout << "666666666666666666666666. py_fn->materialize_grads:" << py_fn->materialize_grads << std::endl;
get_materialize_grads_once(forward_output_tensors, py_fn->materialize_grads, kernel_info);

if (kernel_info.safe_run_enabled) {
2 changes: 2 additions & 0 deletions orttraining/orttraining/python/training/utils/__init__.py
@@ -12,6 +12,7 @@
unflatten_data_using_schema,
)
from onnxruntime.training.utils.torch_profile_utils import (
log_memory_usage,
nvtx_function_decorator,
torch_nvtx_range_pop,
torch_nvtx_range_push,
@@ -31,6 +32,7 @@
"torch_nvtx_range_push",
"torch_nvtx_range_pop",
"nvtx_function_decorator",
"log_memory_usage",
"pytorch_type_to_onnx_dtype",
"onnx_dtype_to_pytorch_dtype",
"pytorch_scalar_type_to_pytorch_dtype",
@@ -8,12 +8,14 @@

__all__ = [
"StatisticsSubscriber",
"MemoryStatisticsSubscriber",
"GlobalSubscriberManager",
"inspect_activation",
"ZeROOffloadSubscriber",
"configure_ort_compatible_zero_stage3",
]

from ._mem_statistics_subscriber import MemoryStatisticsSubscriber
from ._statistics_subscriber import StatisticsSubscriber, _InspectActivation
from ._subscriber_manager import SubscriberManager
from ._zero_offload_subscriber import ZeROOffloadSubscriber, configure_ort_compatible_zero_stage3
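For context, a hedged usage sketch of the newly exported subscriber; the registration entry point follows the existing subscriber pattern, and MemoryStatisticsSubscriber's constructor arguments here are an assumption:

import torch

from onnxruntime.training.utils.hooks import GlobalSubscriberManager, MemoryStatisticsSubscriber

# Register the memory-statistics subscriber on a module before wrapping it with ORTModule.
# Constructor arguments (if any) are assumptions; adjust to the signature added by this PR.
model = torch.nn.Linear(4, 4)
GlobalSubscriberManager.subscribe(model, [MemoryStatisticsSubscriber()])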