# Check padding density by input of embedding module (#19821)
### Description
The PaddingElimination optimization should be enabled only when the density of
valid (non-padding) tokens in the embedding input is below 90%, so we need to
measure that density to decide whether to enable the optimization.

Before this PR, we only checked the graph inputs and correlated one of them
with the embedding node by iterating the graph backward from the embedding
node to a graph input. This is hard to generalize because arbitrarily
complicated patterns may sit between the graph input and the embedding node.

This PR instead checks padding density on the direct input of the embedding
module, during the first graph execution when exporting the ONNX graph.
If the density is below 90%, a flag PythonOp is inserted after the embedding
node:
```
Embedding
    |
PythonOp (func_name: FlagPaddingElimination)   (inserted if density < 90%)
    |
Following graph
```

When PaddingElimination is invoked, it checks whether this flag PythonOp
(func_name: FlagPaddingElimination) follows the Embedding node; if it does, it
removes the flag node and performs the padding elimination optimization.
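
For intuition, here is a minimal sketch (not part of the diff) of the density check the new forward hook performs; `padding_density_percent` is a hypothetical helper, and the 90% threshold matches the one in this PR:

```python
import torch

# Hypothetical helper mirroring the check in the new forward hook (see diff below):
# density is the percentage of tokens that are not the padding index.
def padding_density_percent(input_ids: torch.Tensor, padding_idx: int) -> float:
    valid_token = torch.count_nonzero(input_ids - padding_idx)
    return float(valid_token) / float(input_ids.numel()) * 100

ids = torch.tensor([[5, 7, 2, 0, 0, 0, 0, 0, 0, 0]])  # padding_idx == 0
# 3 of 10 tokens are valid -> 30% density, sparse enough to insert the flag PythonOp.
assert padding_density_percent(ids, padding_idx=0) < 90
```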
guyang3532 authored Apr 10, 2024
1 parent 0acde11 commit 471e969
Showing 8 changed files with 131 additions and 40 deletions.
2 changes: 1 addition & 1 deletion orttraining/orttraining/core/graph/training_op_defs.cc
@@ -20,7 +20,7 @@ using namespace ONNX_NAMESPACE;
namespace {

// TODO(pengwa): remove this once customized PythonOp shape inference is supported.
-constexpr const char* kInspectActivationFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation";
+constexpr const char* kInspectActivationFuncName = "onnxruntime.training.utils.hooks._statistics_subscriber._InspectActivation";
constexpr const char* kIncrementStepFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep";

std::array<TensorShapeProto::Dimension, 6> GetRNNDimensions(InferenceContext& ctx) {
@@ -17,8 +17,12 @@ namespace onnxruntime {
namespace {

// TODO(pengwa): remove this once customized PythonOp shape inference is supported.
-constexpr const char* kInspectActivationFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation";
-constexpr const char* kIncrementStepFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep";
+constexpr const char* kInspectActivationFuncName =
+    "onnxruntime.training.utils.hooks._statistics_subscriber._InspectActivation";
+constexpr const char* kIncrementStepFuncName =
+    "onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep";
+constexpr const char* kFlagPaddingEliminationFuncName =
+    "onnxruntime.training.ortmodule._runtime_inspector.FlagPaddingElimination";

void PushAllOutputNode(Graph& graph, std::queue<Node*>& q, Node* node, std::unordered_set<Node*>& visited) {
for (auto iter = node->OutputNodesBegin(); iter != node->OutputNodesEnd(); ++iter) {
@@ -316,7 +320,7 @@ void IterateSubgraphFromNode(Graph& graph,
        candidate_outputs.insert(cur);
        continue;
      }
-      auto func_name = static_cast<std::string>(cur->GetAttributes().at("name").s());
+      auto func_name = static_cast<std::string>(cur->GetAttributes().at("func_name").s());
      if (func_name == kInspectActivationFuncName || func_name == kIncrementStepFuncName) {
        subgraph.insert(cur->MutableOutputDefs()[1]);
        PushAllOutputNode(graph, to_visit, cur, visited);
@@ -358,11 +362,6 @@ void IterateSubgraphFromNode(Graph& graph,
Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
LOG_DEBUG_INFO(logger, "Enter PaddingElimination");

-  if (sparse_embedding_input_names_.size() == 0) {
-    LOG_DEBUG_INFO(logger, "Exit PaddingElimination, no sparse embedding input names.");
-    return Status::OK();
-  }
-
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
Node* embedding_node = nullptr;
@@ -391,13 +390,31 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
        node.InputDefs()[2]->Exists() &&
        graph_utils::IsConstantInitializer(graph, node.InputDefs()[2]->Name()) &&
        node.InputDefs()[1]->Exists() &&
-        graph_utils::IsGraphInput(graph, node.InputDefs()[1]) &&
        node.InputDefs()[1]->Shape() &&
        node.InputDefs()[1]->Shape()->dim_size() >= 2) {
-      if (std::find(sparse_embedding_input_names_.begin(), sparse_embedding_input_names_.end(),
-                    node.InputDefs()[1]->Name()) == sparse_embedding_input_names_.end()) {
-        LOG_DEBUG_INFO(logger, "Skip node " + node.Name() + "(" + node.OpType() +
-                                   ") due to embedding input is not in the sparse embedding input list.");
+      const auto outputNodeCount = std::distance(node.OutputEdgesBegin(), node.OutputEdgesEnd());
+      if (outputNodeCount != 1) {
+        continue;
+      }
+      auto embedding_output_node = graph.GetNode(node.OutputNodesBegin()->Index());
+      if (embedding_output_node == nullptr ||
+          !graph_utils::IsSupportedOptypeVersionAndDomain(*embedding_output_node, "PythonOp", {1}, kMSDomain) ||
+          static_cast<std::string>(embedding_output_node->GetAttributes().at("func_name").s()) !=
+              kFlagPaddingEliminationFuncName) {
+        LOG_DEBUG_INFO(logger, "not find PythonOp of flagPaddingElimination after embedding node");
+        continue;
+      }
+      if (graph_utils::CanRemoveNode(graph, *embedding_output_node, logger)) {
+        if (graph_utils::RemoveNode(graph, *embedding_output_node)) {
+          modified = true;
+        } else {
+          LOG_DEBUG_INFO(logger, "Failed to remove node " + embedding_output_node->Name() +
+                                     "(" + embedding_output_node->OpType() + ")");
+          continue;
+        }
+      } else {
+        LOG_DEBUG_INFO(logger, "Can not remove node " + embedding_output_node->Name() +
+                                   "(" + embedding_output_node->OpType() + ")");
+        continue;
+      }
      const ONNX_NAMESPACE::TensorProto* padding_initializer =
@@ -484,7 +501,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
    // to flatten the shape of [batch_size, seqlen, ...] to [valid_token_count, ...]
    InsertFlattenPatternForInput(graph, *embedding_node, 1, squeeze_out_arg, logger);
    handled_input_count++;
-    modified = true;
    for (auto& node : candidate_inputs) {
      for (uint32_t i = 0; i < node->InputDefs().size(); ++i) {
        if (subgraph.find(node->MutableInputDefs()[i]) == subgraph.end()) {
@@ -127,15 +127,10 @@ namespace onnxruntime {
*/
class PaddingElimination : public GraphTransformer {
public:
-  PaddingElimination(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
-                     const std::vector<std::string>& sparse_embedding_input_names = {}) noexcept
-      : GraphTransformer("PaddingElimination", compatible_execution_providers),
-        sparse_embedding_input_names_{sparse_embedding_input_names} {}
+  explicit PaddingElimination(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
+      : GraphTransformer("PaddingElimination", compatible_execution_providers) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

- private:
-  std::vector<std::string> sparse_embedding_input_names_;
};

} // namespace onnxruntime
@@ -25,9 +25,6 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfiguration {
// Enable compute optimizer.
bool enable_compute_optimizer{false};

-  // Enable embedding sparsity compute optimization for the input names in the below list.
-  std::vector<std::string> sparse_embedding_input_names;

// Enable label sparsity compute optimization for the input names in the below list.
std::vector<std::string> sparse_label_input_names;

@@ -197,8 +197,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
#if defined(USE_CUDA) || defined(USE_ROCM)
// Put this under CUDA/ROCM guard as it depends on PadAndUnflatten CUDA/ROCM kernel.
// Once we have a CPU kernel for PadAndUnflatten, we can remove the guard.
-    transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps,
-                                                                   config.sparse_embedding_input_names));
+    transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps));
transformers.emplace_back(std::make_unique<Conv1dReplacement>(compatible_eps));
#endif
}
1 change: 0 additions & 1 deletion orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -488,7 +488,6 @@ void addObjectMethodsForTraining(py::module& m) {
.def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers)
.def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer)
.def_readwrite("sparse_embedding_input_names", &TrainingGraphTransformerConfiguration::sparse_embedding_input_names)
.def_readwrite("sparse_label_input_names", &TrainingGraphTransformerConfiguration::sparse_label_input_names)
.def_readwrite("optimized_pre_grad_filepath", &TrainingGraphTransformerConfiguration::optimized_pre_grad_filepath)
.def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);
@@ -34,7 +34,7 @@
from ._graph_execution_interface import GraphExecutionInterface
from ._io import _FlattenedModule, _InputInfo
from ._logger import LogColor
-from ._runtime_inspector import RuntimeInspector
+from ._runtime_inspector import FlagPaddingElimination, RuntimeInspector
from ._utils import check_function_has_param, get_rank
from .options import DebugOptions, LogLevel, _MemoryOptimizationLevel, _RuntimeOptions
from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension
@@ -306,11 +306,16 @@ def _export_model(self, *inputs, **kwargs) -> bool:
            # All required models have already been exported previously
            return False
        self._set_device_from_module(inputs, kwargs)
+        # TODO: move it into runtime_inspector
+        embedding_hook_handles = self._add_check_embedding_sparsity_hook()

        from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step

        with export_context(), no_increase_global_step():
            self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs)

+        for hook in embedding_hook_handles:
+            hook.remove()
        if self._debug_options.save_onnx_models.save:
            self._onnx_models.save_exported_model(
                self._debug_options.save_onnx_models.path,
@@ -671,6 +676,58 @@ def __setstate__(self, state):

        _utils.reinitialize_graph_execution_manager(self)

+    def _add_check_embedding_sparsity_hook(self):
+        """
+        Add a hook to check embedding sparsity and enable padding elimination if applicable.
+        1. Iterate through all modules to find Embedding modules with padding_idx >= 0.
+        2. Register a forward hook on each such Embedding module; the hook checks the sparsity of the embedding input.
+        3. If the density is below a threshold, enable padding elimination by applying FlagPaddingElimination to the
+           output. The PaddingElimination graph transformer will detect FlagPaddingElimination and do the actual
+           padding elimination graph modification.
+        4. Return the hook handles for later removal.
+        """
+        if (
+            not self._runtime_options.enable_sparse_optimizer
+            or not self._runtime_options.enable_embedding_sparse_optimizer
+            or self._device.type != "cuda"
+        ):
+            return []
+
+        def _embedding_hook(module, args, output):
+            ebd_input = args[0]
+            if ebd_input is None or not isinstance(ebd_input, torch.Tensor):
+                self._logger.warning("Embedding input is not a tensor.")
+                return None
+
+            valid_token = torch.count_nonzero(ebd_input - module.padding_idx)
+            total_token = ebd_input.numel()
+            embed_density = float(valid_token) / float(total_token) * 100
+            if embed_density < 90:
+                self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density)
+                if module not in self._runtime_inspector._embedding_module_to_padding_density_map:
+                    self._logger.warning("Found Embedding module not in the map. %s", module)
+                    return None
+                if self._runtime_inspector._embedding_module_to_padding_density_map[module][1] != -1:
+                    self._logger.warning(
+                        "Found duplicate Embedding module. %s",
+                        self._runtime_inspector._embedding_module_to_padding_density_map[module][0],
+                    )
+                self._runtime_inspector._embedding_module_to_padding_density_map[module][1] = embed_density
+                return FlagPaddingElimination.apply(output)
+            else:
+                self._logger.info("Embedding sparsity-based optimization is OFF for density: %.0f%%", embed_density)
+                return None
+
+        embedding_hook_handles = []
+        for name, sub_module in self._flattened_module.named_modules():
+            if isinstance(sub_module, torch.nn.modules.sparse.Embedding):
+                if sub_module.padding_idx is not None and sub_module.padding_idx >= 0:
+                    self._runtime_inspector._embedding_module_to_padding_density_map[sub_module] = [name, -1]
+                    embedding_hook_handles.append(sub_module.register_forward_hook(_embedding_hook))
+
+        return embedding_hook_handles

    @_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION)
    def _enable_conditional_optimizations(
        self, graph_transformer_config: C.TrainingGraphTransformerConfiguration, inputs: Tuple, kwargs: Dict
@@ -709,7 +766,7 @@ def _enable_conditional_optimizations(
        else:
            param_to_append_as_onnx_graph_inputs = self._graph_initializers

-        _, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers(
+        _, _, label_sparsity_results = _io._combine_input_buffers_initializers(
            param_to_append_as_onnx_graph_inputs,
            self._graph_builder.get_graph_info().user_input_names,
            self._input_info,
@@ -729,16 +786,13 @@ def _enable_conditional_optimizations(
[f"{k}:{v:.0f}%" for k, v in label_sparsity_results.items()]
)

if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0:
if detected_device.type == "cuda":
# Embedding sparsity optimization is only supported on CUDA devices.
graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys())
self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results)
self._runtime_options.embed_sparsity_ratio = ",".join(
[f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()]
)
else:
self._logger.info("Embedding sparsity-based optimization is not supported on non-CUDA devices.")
if self._runtime_inspector._embedding_module_to_padding_density_map:
self._runtime_options.embed_sparsity_ratio = ",".join(
[
f"{v[0]}:{v[1]:.0f}%"
for v in self._runtime_inspector._embedding_module_to_padding_density_map.values()
]
)

        # If users don't want to print input density, disable the input density observer to avoid overhead
        # when looping through inputs during training.
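
The hook above relies on a PyTorch behavior worth calling out: a forward hook registered with `register_forward_hook` may return a value, and a non-None return value replaces the module's output. A minimal standalone sketch of that pattern (not from the diff; the identity hook stands in for `FlagPaddingElimination.apply`):

```python
import torch

emb = torch.nn.Embedding(10, 4, padding_idx=0)

def hook(module, args, output):
    # Returning a non-None value from a forward hook replaces the module's
    # output; the PR's hook returns FlagPaddingElimination.apply(output) here.
    return output

handle = emb.register_forward_hook(hook)
out = emb(torch.tensor([[1, 2, 0]]))
handle.remove()  # as in _export_model above, hooks are removed right after export
```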
@@ -58,6 +58,7 @@ def __init__(self, logger: Logger, module: torch.nn.Module, training: bool):

        self.input_density_ob: Union[InputDensityObserver, None] = None
        self.memory_ob = MemoryObserver(module, self._logger, training)
+        self._embedding_module_to_padding_density_map = {}

    def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None:
        """Initialize input inspector from the given ONNX model and user input names.
@@ -747,3 +748,33 @@ def _get_user_config_without_freq(configs: str):
        return notes, mem_tbl

    return [], None


+class FlagPaddingElimination(torch.autograd.Function):
+    """
+    FlagPaddingElimination is a PyTorch autograd function that is a no-op in both the forward and backward passes.
+    It is used as a flag to tell the PaddingElimination graph transformer to modify the graph to eliminate
+    the embedding padding.
+    """
+
+    @staticmethod
+    def forward(ctx, input):
+        return input
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        return grad_output
+
+    @staticmethod
+    def infer_shape(
+        node: onnx.NodeProto,
+        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
+        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
+    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
+        return tensor_input_shapes, tensor_input_dtypes
+
+    @staticmethod
+    def alias_input(node_proto_str: str):
+        fw_alias_map = [0]
+        bw_alias_map = [0]
+        return fw_alias_map, bw_alias_map
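
Since `FlagPaddingElimination` is an identity in both directions, it cannot change training numerics; its only effect is the PythonOp marker it leaves in the exported graph for the PaddingElimination transformer to detect and strip. A quick sanity check illustrating this (a sketch, assuming the class above is in scope):

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
y = FlagPaddingElimination.apply(x)  # forward is an identity
y.sum().backward()                   # backward passes gradients through unchanged

assert torch.equal(y, x)                        # values are untouched
assert torch.equal(x.grad, torch.ones_like(x))  # d(sum)/dx flows through as ones
```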
