Improve perf for mem efficient grad mgmt (microsoft#20480)
### Improve perf for mem efficient grad mgmt

When the memory efficient gradient management feature is enabled, the weight-retrieval PythonOps for all layers are launched at the beginning of the forward pass, which leaves the GPU stream idle for a few milliseconds. The root cause is that the ReverseDFS ordering cannot always handle this kind of input branching well, so we introduce a distance-to-input-leaf concept into the ReverseDFS. This moves not only the problematic PythonOps but also the Cast ops that follow the weight retrieval to the places where their outputs are actually needed.
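
For illustration, here is a minimal sketch (in Python, using hypothetical graph containers rather than ONNX Runtime's `Graph` API) of the distance computation that the new `NodeCompareByMaxDistance` comparator in `graph.cc` relies on: the longest distance of each node from any in-degree-0 input leaf is computed with a Kahn-style topological pass, and ordering ties during the reverse DFS are then broken by that distance instead of the bare node index.

```python
from collections import deque

def max_distance_from_input_leaves(nodes, in_edges, out_edges):
    """Longest path length from any in-degree-0 node (input leaf) to each node.

    Hypothetical containers: `nodes` is an iterable of node ids, and
    `in_edges[n]` / `out_edges[n]` list the predecessors / successors of n.
    Mirrors the Kahn-style pass added to SortForwardNodesByReverseDFS in this PR.
    """
    in_degree = {n: len(in_edges[n]) for n in nodes}
    distance = {n: 0 for n in nodes}
    ready = deque(n for n in nodes if in_degree[n] == 0)
    while ready:
        cur = ready.popleft()
        for succ in out_edges[cur]:
            # A node is at least one step farther from the inputs than its farthest predecessor.
            distance[succ] = max(distance[succ], distance[cur] + 1)
            in_degree[succ] -= 1
            if in_degree[succ] == 0:
                ready.append(succ)
    return distance
```

Because the weight-pull PythonOps sit right next to the graph inputs, plain node-index ordering tended to schedule them all up front; breaking ties by this distance keeps them (and the Cast ops that follow them) next to their consumers.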

Main branch: 102.19s - 26.35s = 75.84s for 260 steps (4627 samples), 61.04 samples/second.
This PR: 100.28s - 25.10s = 75.18s for 260 steps, 61.54 samples/second (+0.8% gain).

Main branch:


![image](https://github.com/microsoft/onnxruntime/assets/10530022/75c4131e-dade-49b0-aa8b-ee1c637ad9a8)


This PR:


![image](https://github.com/microsoft/onnxruntime/assets/10530022/e590a536-3b80-4f51-b89f-f25a55ddd7e2)


### Motivation and Context
pengwa authored May 10, 2024
1 parent 5a18818 commit 56f7035
Showing 4 changed files with 128 additions and 102 deletions.
62 changes: 57 additions & 5 deletions onnxruntime/core/graph/graph.cc
@@ -1932,6 +1932,22 @@ struct GroupNode {
InlinedVector<const Node*> nodes;
};

struct NodeCompareByMaxDistance {
explicit NodeCompareByMaxDistance(const InlinedHashMap<NodeIndex, int>& max_distance)
: max_distance_(max_distance) {}
bool operator()(const Node* n1, const Node* n2) const {
if (max_distance_.at(n1->Index()) != max_distance_.at(n2->Index())) {
// The longer distance node should be executed first.
return max_distance_.at(n1->Index()) < max_distance_.at(n2->Index());
}

return n1->Index() < n2->Index();
}

private:
const InlinedHashMap<NodeIndex, int>& max_distance_;
};

void SortForwardNodesByReverseDFS(const Graph* graph,
const InlinedVector<const Node*>& forward_output_nodes,
const InlinedHashMap<NodeIndex, InlinedVector<NodeIndex>>& shape_size_parents,
@@ -1941,6 +1957,38 @@ void SortForwardNodesByReverseDFS(const Graph* graph,
// Note 2: While it is also possible some nodes not contributing to the forward output nodes will be
// executed before YieldOp, for example, if one forward node's output is used by Shape/Size, then
// the Shape/Size node should be executed before YieldOp to release the memory as soon as possible.
InlinedVector<size_t> nodes_in_degree;
std::queue<const Node*> to_visit;
nodes_in_degree.resize(graph->MaxNodeIndex(), 0);
for (auto& node : graph->Nodes()) {
size_t input_edge_count = node.GetInputEdgesCount();
nodes_in_degree[node.Index()] = input_edge_count;
if (input_edge_count == 0) {
to_visit.push(&node);
}
}

InlinedHashMap<NodeIndex, int> max_distance;
max_distance.reserve(graph->NumberOfNodes());
while (!to_visit.empty()) {
const Node* current = to_visit.front();
to_visit.pop();

if (!current) continue;

for (auto output_edge_it = current->OutputEdgesBegin();
output_edge_it != current->OutputEdgesEnd();
++output_edge_it) {
const Node* out_node = &output_edge_it->GetNode();
max_distance[out_node->Index()] = std::max(max_distance[out_node->Index()],
max_distance[current->Index()] + 1);
auto& node_in_degree = nodes_in_degree[out_node->Index()];
node_in_degree--;
if (node_in_degree == 0) {
to_visit.push(out_node);
}
}
}

// Reverse DFS from forward output nodes to find all "forward" nodes.
// The forward nodes are ordered by Reverse DFS traversal.
@@ -1951,7 +1999,7 @@ void SortForwardNodesByReverseDFS(const Graph* graph,
nodes_to_execute_before_yieldop.insert(n);
node_orders.push_back(n->Index());
},
NodeCompare());
NodeCompareByMaxDistance(max_distance));

for (const auto& parent_to_children_pair : shape_size_parents) {
const NodeIndex& parent_index = parent_to_children_pair.first;
@@ -1976,13 +2024,13 @@ void SortForwardNodesByReverseDFS(const Graph* graph,
}

void PrepareToFindBranchGraph(const Graph* graph,
const InlinedHashSet<const Node*>& nodes_to_execute_before_yieldop,
std::function<bool(const Node*)> is_forward_node,
InlinedVector<const Node*>& branch_graph_input_nodes,
InlinedVector<size_t>& backward_node_in_degree,
std::queue<const Node*>& to_visit) {
for (auto& node : graph->Nodes()) {
// Ignore forward.
if (nodes_to_execute_before_yieldop.find(&node) != nodes_to_execute_before_yieldop.end()) {
if (is_forward_node(&node)) {
continue;
}

Expand All @@ -2004,7 +2052,7 @@ void PrepareToFindBranchGraph(const Graph* graph,
for (auto input_edge_it = node.InputEdgesBegin(); input_edge_it != node.InputEdgesEnd(); ++input_edge_it) {
const Node* input_node = &input_edge_it->GetNode();
// If the input edge connect to forward nodes, then we remove the in_degree of the node.
if (nodes_to_execute_before_yieldop.find(input_node) != nodes_to_execute_before_yieldop.end()) {
if (is_forward_node(input_node)) {
input_edge_count--;
}
}
@@ -2203,11 +2251,15 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op,
topo_order.reserve(num_of_backward_nodes);
std::queue<const Node*> to_visit;

auto is_forward_op = [&nodes_to_execute_before_yieldop](const Node* n) -> bool {
return nodes_to_execute_before_yieldop.find(n) != nodes_to_execute_before_yieldop.end();
};

InlinedVector<const Node*> branch_graph_input_nodes;
branch_graph_input_nodes.reserve(num_of_backward_nodes);

PrepareToFindBranchGraph(this,
nodes_to_execute_before_yieldop,
is_forward_op,
branch_graph_input_nodes,
backward_node_in_degree,
to_visit);
@@ -308,6 +308,7 @@ def _export_model(self, *inputs, **kwargs) -> bool:
):
# All required models have already been exported previously
return False

self._set_device_from_module(inputs, kwargs)
# TODO: move it into runtime_inspector
embedding_hook_handles = self._add_check_embedding_sparsity_hook()
@@ -319,6 +320,7 @@ def _export_model(self, *inputs, **kwargs) -> bool:

for hook in embedding_hook_handles:
hook.remove()

if self._debug_options.save_onnx_models.save:
self._onnx_models.save_exported_model(
self._debug_options.save_onnx_models.path,
@@ -574,10 +576,11 @@ def _initialize_graph_builder(self):
from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training

# Override the options if model is not modified.
(
self._mem_efficient_grad_management_is_enabled,
exported_model,
) = post_processing_enable_mem_efficient_training(exported_model, self._flattened_module.named_parameters())
(self._mem_efficient_grad_management_is_enabled, exported_model, self._param_trigger_grad) = (
post_processing_enable_mem_efficient_training(
exported_model, self._flattened_module.named_parameters(), self._device
)
)

if self._runtime_options.run_symbolic_shape_infer:
exported_model = SymbolicShapeInference.infer_shapes(
@@ -10,9 +10,8 @@
import torch
from onnx import ModelProto, NodeProto, TensorProto, helper

from onnxruntime.training.utils import pytorch_type_to_onnx_dtype

from ._pythonop_helper import make_pythonop_node
from onnxruntime.training.ortmodule._pythonop_helper import make_pythonop_node
from onnxruntime.training.utils import onnx_dtype_to_pytorch_dtype, pytorch_type_to_onnx_dtype

MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME = "mem_efficient_pull_weight_trigger"
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE = TensorProto.FLOAT
@@ -38,54 +37,30 @@ def get_params_not_connected_to_pull_param_trigger(
def post_processing_enable_mem_efficient_training(
exported_model: ModelProto,
named_params: dict[str, torch.nn.parameter.Parameter],
) -> tuple[bool, ModelProto]:
"""This function is used to enable zero stage3 compatibility.
device: torch.device,
) -> tuple[bool, ModelProto, torch.Tensor]:
"""This function is used to enable memory efficient gradient management.
Args:
exported_model (ModelProto): The exported model.
named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The full parameter map.
exported_model: The exported model.
named_params: The full parameter map.
device: The device to run the model.
Returns:
tuple[bool, ModelProto]: A tuple of bool and ModelProto. The bool indicates whether the model is modified.
A tuple of bool, ModelProto and param_trigger_grad tensor. The bool indicates whether the model is modified.
"""
trainable_named_params = get_params_connected_to_pull_param_trigger(named_params, exported_model)
if len(trainable_named_params) == 0:
return False, exported_model
return False, exported_model, None

# Create weight retrieving function using trainable_named_params.
param_pull_trigger_func_class = _create_param_trigger_function(trainable_named_params)
param_retrieve_func_class = _create_param_retrieval_function(trainable_named_params)

def _get_param_pull_trigger_name(param_name: str) -> str:
return f"pull_{param_name}"

# Create weight retrieving PythonOp.
inputs = [
helper.make_tensor_value_info(
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, # Use the same data type with output for the input
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
)
]

outputs = [
helper.make_tensor_value_info(
_get_param_pull_trigger_name(pname),
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
)
for pname in trainable_named_params
]

weight_pull_node = make_pythonop_node(
"weight_pull_trigger",
inputs,
outputs,
param_pull_trigger_func_class,
training_mode=1,
safe_run_mode=0,
param_trigger_grad = torch.zeros(
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE),
device=device,
)
param_retrieve_func_class = _create_param_retrieval_function(trainable_named_params, param_trigger_grad)

graph_inputs_to_remove = []
input_offset = 0
@@ -98,7 +73,7 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
# Create the param retrieval function for this parameter.
node_inputs = [
helper.make_tensor_value_info(
_get_param_pull_trigger_name(graph_input.name),
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
),
@@ -141,75 +116,43 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
if input.name in named_params:
break
offset += 1
exported_model.graph.input.insert(offset, inputs[0])
exported_model.graph.node.insert(0, weight_pull_node)
exported_model.graph.input.insert(
offset,
helper.make_tensor_value_info(
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,  # Use the same data type as the output for the input
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
),
)

return True, exported_model
return True, exported_model, param_trigger_grad


_PARAM_FUNCTION_INDEX = [0]


def _create_param_trigger_function(trainable_named_params: dict[str, torch.nn.parameter.Parameter]):
"""This function is used to create a weight retrieving function using trainable_named_params."""

@staticmethod
def forward(ctx, weight_in_trigger):
params = list(trainable_named_params.values())
ctx.params = params
ctx.dtype = weight_in_trigger.dtype
ctx.device = weight_in_trigger.device
ctx.shape = weight_in_trigger.shape
return (torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype),) * len(params)

@staticmethod
def backward(ctx, *grad_outputs):
return torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype)

@staticmethod
def infer_shape(
node: NodeProto,
tensor_input_shapes: list[list[int | str] | None],
tensor_input_dtypes: list[torch.onnx.TensorProtoDataType],
) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]:
param_count = len(trainable_named_params.values())
tensor_output_shapes = [
tensor_input_shapes[0],
] * param_count
tensor_output_dtypes = [
tensor_input_dtypes[0],
] * param_count

return tensor_output_shapes, tensor_output_dtypes

_PARAM_FUNCTION_INDEX[0] += 1

return type(
f"ParamTriggerFunction_{_PARAM_FUNCTION_INDEX[0]}",
(torch.autograd.Function,),
{
"forward": forward,
"backward": backward,
"infer_shape": infer_shape,
},
)
def _create_param_retrieval_function(
trainable_named_params: dict[str, torch.nn.parameter.Parameter], param_trigger: torch.Tensor
):
"""This function is used to create a weight retrieving function using trainable_named_params.
Args:
trainable_named_params: The trainable named parameters.
param_trigger: The trigger tensor for pulling the weights. param_trigger is pre-allocated only once
before model execution and reused in every iteration, which avoids the overhead of allocating a
new tensor for each iteration run.
def _create_param_retrieval_function(trainable_named_params: dict[str, torch.nn.parameter.Parameter]):
"""This function is used to create a weight retrieving function using trainable_named_params."""
"""

@staticmethod
def forward(ctx, param_trigger, param_name):
ctx.param_name = param_name
ctx.dtype = param_trigger.dtype
ctx.device = param_trigger.device
ctx.shape = param_trigger.shape
return trainable_named_params[param_name]

@staticmethod
def backward(ctx, *grad_outputs):
trainable_named_params[ctx.param_name].backward(grad_outputs[0])
return torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype), None
return param_trigger, None

@staticmethod
def infer_shape(
@@ -235,6 +178,8 @@ def infer_shape(

return tensor_output_shapes, tensor_output_dtypes

_PARAM_FUNCTION_INDEX[0] += 1

return type(
f"ParamRetrievalFunction_{_PARAM_FUNCTION_INDEX[0]}",
(torch.autograd.Function,),
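To make the interleaved old/new lines above easier to read, here is a minimal self-contained sketch (simplified signatures, `infer_shape` omitted, factory name hypothetical) of the pattern the updated `_create_param_retrieval_function` follows: backward routes the incoming gradient onto the real parameter and returns the shared, pre-allocated `param_trigger` tensor instead of allocating a fresh zero tensor every step.

```python
import torch

def make_param_retrieval_function(trainable_named_params, param_trigger):
    """Sketch of the ParamRetrievalFunction pattern shown in the diff above."""

    class ParamRetrieval(torch.autograd.Function):
        @staticmethod
        def forward(ctx, trigger, param_name):
            ctx.param_name = param_name
            # Hand the real parameter tensor to the graph.
            return trainable_named_params[param_name]

        @staticmethod
        def backward(ctx, *grad_outputs):
            # Route the incoming gradient onto the real parameter.
            trainable_named_params[ctx.param_name].backward(grad_outputs[0])
            # Reuse the shared pre-allocated trigger tensor as the gradient of
            # the trigger input; no per-step torch.zeros allocation is needed.
            return param_trigger, None

    return ParamRetrieval
```

Since `param_trigger_grad` is created once on the target device in `post_processing_enable_mem_efficient_training` and handed to this factory, the per-iteration zero-tensor allocation of the old `backward` is no longer needed.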
