Check padding density by input of embedding module
guyang3532 committed Mar 7, 2024
1 parent e93a860 commit fded099
Showing 9 changed files with 88 additions and 67 deletions.
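This commit moves embedding sparsity detection out of the graph-input inspector and into a runtime check on the embedding module itself: a forward hook on every torch.nn.Embedding that has a padding index measures the padding density of the actual input, and when fewer than 90% of tokens are valid it wraps the embedding output in a _FlagPaddingElimination identity function. The exported graph then carries a PythonOp marker that the C++ PaddingElimination transformer looks for, which lets the old sparse_embedding_input_names plumbing be removed.

A minimal sketch of the density computation the hook performs; padding_density is an illustrative helper, not part of this commit:

    import torch

    def padding_density(input_ids: torch.Tensor, padding_idx: int) -> float:
        """Percentage of tokens that differ from the padding index."""
        valid_token = torch.count_nonzero(input_ids - padding_idx)
        return float(valid_token) / float(input_ids.numel()) * 100

    # Two sequences padded with 0 to length 8: 8 of 16 tokens are real.
    ids = torch.tensor([[5, 7, 2, 0, 0, 0, 0, 0],
                        [3, 9, 4, 6, 1, 0, 0, 0]])
    print(padding_density(ids, padding_idx=0))  # 50.0, below the 90% threshold, so the flag op would be inserted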
orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc
@@ -19,6 +19,7 @@ namespace {
// TODO(pengwa): remove this once customized PythonOp shape inference is supported.
constexpr const char* kInspectActivationFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation";
constexpr const char* kIncrementStepFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep";
constexpr const char* kFlagPaddingEliminationFuncName = "onnxruntime.training.ortmodule._graph_execution_manager._FlagPaddingElimination";

Check warning on line 22 in orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc (GitHub Actions / Lint C++): [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]

void PushAllOutputNode(Graph& graph, std::queue<Node*>& q, Node* node, std::unordered_set<Node*>& visited) {
for (auto iter = node->OutputNodesBegin(); iter != node->OutputNodesEnd(); ++iter) {
@@ -311,7 +312,7 @@ void IterateSubgraphFromNode(Graph& graph,
candidate_outputs.insert(cur);
continue;
}
auto func_name = static_cast<std::string>(cur->GetAttributes().at("name").s());
auto func_name = static_cast<std::string>(cur->GetAttributes().at("func_name").s());
if (func_name == kInspectActivationFuncName || func_name == kIncrementStepFuncName) {
subgraph.insert(cur->MutableOutputDefs()[1]);
PushAllOutputNode(graph, to_visit, cur, visited);
@@ -353,11 +354,6 @@ void IterateSubgraphFromNode(Graph& graph,
Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
LOG_DEBUG_INFO(logger, "Enter PaddingElimination");

if (sparse_embedding_input_names_.size() == 0) {
LOG_DEBUG_INFO(logger, "Exit PaddingElimination, no sparse embedding input names.");
return Status::OK();
}

GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
Node* embedding_node = nullptr;
@@ -386,13 +382,28 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
node.InputDefs()[2]->Exists() &&
graph_utils::IsConstantInitializer(graph, node.InputDefs()[2]->Name()) &&
node.InputDefs()[1]->Exists() &&
graph_utils::IsGraphInput(graph, node.InputDefs()[1]) &&
node.InputDefs()[1]->Shape() &&
node.InputDefs()[1]->Shape()->dim_size() >= 2) {
if (std::find(sparse_embedding_input_names_.begin(), sparse_embedding_input_names_.end(),
node.InputDefs()[1]->Name()) == sparse_embedding_input_names_.end()) {
LOG_DEBUG_INFO(logger, "Skip node " + node.Name() + "(" + node.OpType() +
") due to embedding input is not in the sparse embedding input list.");
const auto outputNodeCount = std::distance(node.OutputEdgesBegin(), node.OutputEdgesEnd());
if (outputNodeCount != 1) {
continue;
}
auto embedding_output_node = graph.GetNode(node.OutputNodesBegin()->Index());
if (embedding_output_node == nullptr ||
!graph_utils::IsSupportedOptypeVersionAndDomain(*embedding_output_node, "PythonOp", {1}, kMSDomain) ||
static_cast<std::string>(embedding_output_node->GetAttributes().at("func_name").s()) != kFlagPaddingEliminationFuncName) {

Check warning on line 394 in orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc (GitHub Actions / Lint C++): [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]
LOG_DEBUG_INFO(logger, "not find PythonOp of flagPaddingElimination after embedding node");
continue;
}
if (graph_utils::CanRemoveNode(graph, *embedding_output_node, logger)) {
if (graph_utils::RemoveNode(graph, *embedding_output_node)) {
modified = true;
} else {
LOG_DEBUG_INFO(logger, "Failed to remove node " + embedding_output_node->Name() + "(" + embedding_output_node->OpType() + ")");

Check warning on line 402 in orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc (GitHub Actions / Lint C++): [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]
continue;
}
} else {
LOG_DEBUG_INFO(logger, "Can not remove node " + embedding_output_node->Name() + "(" + embedding_output_node->OpType() + ")");

Check warning on line 406 in orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc (GitHub Actions / Lint C++): [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]
continue;
}
const ONNX_NAMESPACE::TensorProto* padding_initializer =
@@ -479,7 +490,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
// to flatten the shape of [batch_size, seqlen, ...] to [valid_token_count, ...]
InsertFlattenPatternForInput(graph, *embedding_node, 1, squeeze_out_arg, logger);
handled_input_count++;
modified = true;
for (auto& node : candidate_inputs) {
for (uint32_t i = 0; i < node->InputDefs().size(); ++i) {
if (subgraph.find(node->MutableInputDefs()[i]) == subgraph.end()) {
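On the graph side, the updated PaddingElimination transformer above keys off this marker instead of a configured input-name list: it requires the candidate embedding node to have exactly one consumer, checks that this consumer is a com.microsoft PythonOp whose func_name attribute equals onnxruntime.training.ortmodule._graph_execution_manager._FlagPaddingElimination, removes that marker node, and only then rewrites the graph. A small sketch of how one could confirm the marker is present in an exported ONNX file (the file name is a placeholder and this check is illustrative, not part of the commit):

    import onnx

    FLAG_FUNC = "onnxruntime.training.ortmodule._graph_execution_manager._FlagPaddingElimination"

    model = onnx.load("exported_training_model.onnx")  # hypothetical path to an ORTModule-exported graph
    for node in model.graph.node:
        if node.op_type == "PythonOp" and node.domain == "com.microsoft":
            func_name = next((attr.s.decode() for attr in node.attribute if attr.name == "func_name"), "")
            if func_name == FLAG_FUNC:
                print("Padding-elimination flag found after:", node.input[0] if node.input else "<no input>")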
orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h
@@ -127,15 +127,10 @@ namespace onnxruntime {
*/
class PaddingElimination : public GraphTransformer {
public:
PaddingElimination(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const std::vector<std::string>& sparse_embedding_input_names = {}) noexcept
: GraphTransformer("PaddingElimination", compatible_execution_providers),
sparse_embedding_input_names_{sparse_embedding_input_names} {}
PaddingElimination(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept

Check warning on line 130 in orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h (GitHub Actions / Lint C++): [cpplint] Constructors callable with one argument should be marked explicit. [runtime/explicit] [5]
: GraphTransformer("PaddingElimination", compatible_execution_providers) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

private:
std::vector<std::string> sparse_embedding_input_names_;
};

} // namespace onnxruntime
@@ -25,9 +25,6 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat
// Enable compute optimizer.
bool enable_compute_optimizer{false};

// Enable embedding sparsity compute optimization for the input names in the below list.
std::vector<std::string> sparse_embedding_input_names;

// Enable label sparsity compute optimization for the input names in the below list.
std::vector<std::string> sparse_label_input_names;

@@ -195,8 +195,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
#if defined(USE_CUDA) || defined(USE_ROCM)
// Put this under CUDA/ROCM guard as it depends on PadAndUnflatten CUDA/ROCM kernel.
// Once we have a CPU kernel for PadAndUnflatten, we can remove the guard.
transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps,
config.sparse_embedding_input_names));
transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps));
transformers.emplace_back(std::make_unique<Conv1dReplacement>(compatible_eps));
#endif
}
1 change: 0 additions & 1 deletion orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -488,7 +488,6 @@ void addObjectMethodsForTraining(py::module& m) {
.def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers)
.def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer)
.def_readwrite("sparse_embedding_input_names", &TrainingGraphTransformerConfiguration::sparse_embedding_input_names)
.def_readwrite("sparse_label_input_names", &TrainingGraphTransformerConfiguration::sparse_label_input_names)
.def_readwrite("optimized_pre_grad_filepath", &TrainingGraphTransformerConfiguration::optimized_pre_grad_filepath)
.def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);
orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -10,7 +10,7 @@
import os
from abc import ABC, abstractmethod # noqa: F401
from hashlib import md5 as hash_fn
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import onnx
import torch
@@ -48,6 +48,28 @@ def __init__(self, state, output_info: List[Tuple[torch.Size, torch.device, torc
self.state = state
self.output_info = output_info

class _FlagPaddingElimination(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        return grad_output

    @staticmethod
    def infer_shape(
        node: onnx.NodeProto,
        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
        return tensor_input_shapes, tensor_input_dtypes

    @staticmethod
    def alias_input(node_proto_str: str):
        fw_alias_map = [0]
        bw_alias_map = [0]
        return fw_alias_map, bw_alias_map

class GraphExecutionManager(GraphExecutionInterface):
def __init__(
@@ -91,6 +113,7 @@ def __init__(
# Inspector for runtime information, for example input data, memory usage, etc.
self._runtime_inspector = RuntimeInspector(self._logger, self._original_module)
self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)
self._embedding_module_to_padding_density_map = {}

# Tracker for ORTModule model export, session creation overhead.
self.time_tracker = _logger.TimeTracker()
@@ -622,6 +645,38 @@ def __setstate__(self, state):

_utils.reinitialize_graph_execution_manager(self)

    def _check_embedding_sparsity(self):
        if not self._runtime_options.enable_embedding_sparse_optimizer or self._device.type != "cuda":
            return

        def embedding_hook(module, args, output):
            ebd_input = args[0]
            if ebd_input is None or not isinstance(ebd_input, torch.Tensor):
                self._logger.warning("Embedding input is not a tensor.")
                return None

            valid_token = torch.count_nonzero(ebd_input - module.padding_idx)
            total_token = ebd_input.numel()
            embed_density = float(valid_token) / float(total_token) * 100
            if embed_density < 90:
                self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density)
                if module not in self._embedding_module_to_padding_density_map:
                    self._logger.warning("Found Embedding module not in the map. %s", module)
                    return None
                if self._embedding_module_to_padding_density_map[module][1] != -1:
                    self._logger.warning(
                        "Found duplicate Embedding module. %s",
                        self._embedding_module_to_padding_density_map[module][0],
                    )
                self._embedding_module_to_padding_density_map[module][1] = embed_density
                return _FlagPaddingElimination.apply(output)
            return None

        for name, sub_module in self._flattened_module.named_modules():
            if isinstance(sub_module, torch.nn.modules.sparse.Embedding):
                if sub_module.padding_idx is not None and sub_module.padding_idx >= 0:
                    self._embedding_module_to_padding_density_map[sub_module] = [name, -1]
                    sub_module.register_forward_hook(embedding_hook)

@_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION)
def _enable_conditional_optimizations(
self, graph_transformer_config: C.TrainingGraphTransformerConfiguration, inputs: Tuple, kwargs: Dict
@@ -680,16 +735,10 @@ def _enable_conditional_optimizations(
[f"{k}:{v:.0f}%" for k, v in label_sparsity_results.items()]
)

if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0:
if detected_device.type == "cuda":
# Embedding sparsity optimization is only supported on CUDA devices.
graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys())
self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results)
self._runtime_options.embed_sparsity_ratio = ",".join(
[f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()]
)
else:
self._logger.info("Embedding sparsity-based optimization is not supported on non-CUDA devices.")
if self._embedding_module_to_padding_density_map:
self._runtime_options.embed_sparsity_ratio = ",".join(
[f"{v[0]}:{v[1]:.0f}%" for v in self._embedding_module_to_padding_density_map.values()]
)

# If users don't want to print input density, disable the input density observer to avoid overhead
# when looping through inputs during training.
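The hook mechanism above relies on a standard PyTorch pattern: _check_embedding_sparsity registers a forward hook on each Embedding with a padding index before export (it is called from ORTModule's forward, see the _training_manager.py hunk further down), and returning a value from a forward hook replaces the module's output, so the identity _FlagPaddingElimination op ends up in the traced graph without changing any values. A self-contained sketch of that pattern, with _Flag and hook as illustrative stand-ins rather than the commit's code:

    import torch

    class _Flag(torch.autograd.Function):
        # Identity marker: values pass through unchanged in forward and backward.
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    def hook(module, args, output):
        # Returning a non-None value from a forward hook replaces the module's output.
        return _Flag.apply(output)

    embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=0)
    handle = embedding.register_forward_hook(hook)
    out = embedding(torch.tensor([[1, 2, 0, 0]]))
    print(out.shape)  # torch.Size([1, 4, 4]); same numbers as without the hook
    handle.remove()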
4 changes: 1 addition & 3 deletions orttraining/orttraining/python/training/ortmodule/_io.py
@@ -245,10 +245,8 @@ def _expand_inputs(current_input, non_none_inputs, name=""):
if PrimitiveType.is_primitive_type(inp):
inp = PrimitiveType.get_tensor(inp, device)

found, embedding_density, label_density = rt_inspector.inspect_input(name, inp)
found, _, label_density = rt_inspector.inspect_input(name, inp)
if found:
if embedding_density < 100:
embed_sparsity_results[name] = embedding_density
if label_density < 100:
label_sparsity_results[name] = label_density
result.append(inp)
@@ -136,7 +136,6 @@ def initialize(self, model: ModelProto, user_input_names: List[str]) -> None:
if output_name != "":
self._tensor_to_node_map[output_name] = node

self._initialize_embedding_padding_inspector(model, user_input_names)
self._initialize_loss_label_padding_inspector(model, user_input_names)

self._is_initialized = True
@@ -354,33 +353,6 @@ def _inspect_embed_label_input(self, name, data):
found = False
min_embed_density = 100
min_label_density = 100
if (
len(self._embedding_graph_input_to_padding_idx_map) > 0
and name in self._embedding_graph_input_to_padding_idx_map
and isinstance(data, torch.Tensor)
):
for padding_idx in self._embedding_graph_input_to_padding_idx_map[name]:
valid_token = torch.count_nonzero(data - padding_idx)
valid_token_per_batch = "N/A"
if data.dim() > 1:
valid_token_per_batch = str(torch.count_nonzero(data - padding_idx, dim=1).tolist())
total_token = data.numel()
embed_density = float(valid_token) / float(total_token) * 100
if embed_density < 90:
min_embed_density = min(min_embed_density, embed_density)
self._stats.append(
[
self._current_step,
"EMBED",
name,
padding_idx,
embed_density,
valid_token,
total_token,
valid_token_per_batch,
]
)
found = True

if (
len(self._loss_label_graph_input_to_ignore_idx_map) > 0
@@ -253,6 +253,8 @@ def forward(self, *inputs, **kwargs):
):
self.time_tracker.start(ORTModuleInitPhase.EndToEnd)

self._check_embedding_sparsity()

build_gradient_graph = self._export_model(*inputs, **kwargs)

if build_gradient_graph:
