diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index d42af92c7c66d..cc997df967bb5 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -19,6 +19,7 @@ namespace { // TODO(pengwa): remove this once customized PythonOp shape inference is supported. constexpr const char* kInspectActivationFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation"; constexpr const char* kIncrementStepFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep"; +constexpr const char* kFlagPaddingEliminationFuncName = "onnxruntime.training.ortmodule._graph_execution_manager._FlagPaddingElimination"; void PushAllOutputNode(Graph& graph, std::queue& q, Node* node, std::unordered_set& visited) { for (auto iter = node->OutputNodesBegin(); iter != node->OutputNodesEnd(); ++iter) { @@ -311,7 +312,7 @@ void IterateSubgraphFromNode(Graph& graph, candidate_outputs.insert(cur); continue; } - auto func_name = static_cast(cur->GetAttributes().at("name").s()); + auto func_name = static_cast(cur->GetAttributes().at("func_name").s()); if (func_name == kInspectActivationFuncName || func_name == kIncrementStepFuncName) { subgraph.insert(cur->MutableOutputDefs()[1]); PushAllOutputNode(graph, to_visit, cur, visited); @@ -353,11 +354,6 @@ void IterateSubgraphFromNode(Graph& graph, Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { LOG_DEBUG_INFO(logger, "Enter PaddingElimination"); - if (sparse_embedding_input_names_.size() == 0) { - LOG_DEBUG_INFO(logger, "Exit PaddingElimination, no sparse embedding input names."); - return Status::OK(); - } - GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); Node* embedding_node = nullptr; @@ -386,13 +382,28 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev node.InputDefs()[2]->Exists() && graph_utils::IsConstantInitializer(graph, node.InputDefs()[2]->Name()) && node.InputDefs()[1]->Exists() && - graph_utils::IsGraphInput(graph, node.InputDefs()[1]) && node.InputDefs()[1]->Shape() && node.InputDefs()[1]->Shape()->dim_size() >= 2) { - if (std::find(sparse_embedding_input_names_.begin(), sparse_embedding_input_names_.end(), - node.InputDefs()[1]->Name()) == sparse_embedding_input_names_.end()) { - LOG_DEBUG_INFO(logger, "Skip node " + node.Name() + "(" + node.OpType() + - ") due to embedding input is not in the sparse embedding input list."); + const auto outputNodeCount = std::distance(node.OutputEdgesBegin(), node.OutputEdgesEnd()); + if (outputNodeCount != 1) { + continue; + } + auto embedding_output_node = graph.GetNode(node.OutputNodesBegin()->Index()); + if (embedding_output_node == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*embedding_output_node, "PythonOp", {1}, kMSDomain) || + static_cast(embedding_output_node->GetAttributes().at("func_name").s()) != kFlagPaddingEliminationFuncName) { + LOG_DEBUG_INFO(logger, "not find PythonOp of flagPaddingElimination after embedding node"); + continue; + } + if (graph_utils::CanRemoveNode(graph, *embedding_output_node, logger)) { + if (graph_utils::RemoveNode(graph, *embedding_output_node)) { + modified = true; + } else { + LOG_DEBUG_INFO(logger, "Failed to remove node " + embedding_output_node->Name() + "(" + embedding_output_node->OpType() + ")"); + continue; + } + } else { + LOG_DEBUG_INFO(logger, "Can not remove node " + embedding_output_node->Name() + "(" + embedding_output_node->OpType() + ")"); continue; } const ONNX_NAMESPACE::TensorProto* padding_initializer = @@ -479,7 +490,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev // to flattern the shape of [batch_size, seqlen, ...] to [valid_token_count, ...] InsertFlattenPatternForInput(graph, *embedding_node, 1, squeeze_out_arg, logger); handled_input_count++; - modified = true; for (auto& node : candidate_inputs) { for (uint32_t i = 0; i < node->InputDefs().size(); ++i) { if (subgraph.find(node->MutableInputDefs()[i]) == subgraph.end()) { diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h index c4f283c30fddc..607e059d54f82 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h @@ -127,15 +127,10 @@ namespace onnxruntime { */ class PaddingElimination : public GraphTransformer { public: - PaddingElimination(const InlinedHashSet& compatible_execution_providers = {}, - const std::vector& sparse_embedding_input_names = {}) noexcept - : GraphTransformer("PaddingElimination", compatible_execution_providers), - sparse_embedding_input_names_{sparse_embedding_input_names} {} + PaddingElimination(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("PaddingElimination", compatible_execution_providers) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; - - private: - std::vector sparse_embedding_input_names_; }; } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_config.h b/orttraining/orttraining/core/optimizer/graph_transformer_config.h index cc3edfb016a15..f6c14503978e1 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_config.h +++ b/orttraining/orttraining/core/optimizer/graph_transformer_config.h @@ -25,9 +25,6 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat // Enable compute optimizer. bool enable_compute_optimizer{false}; - // Enable embedding sparsity compute optimization for the input names in the below list. - std::vector sparse_embedding_input_names; - // Enable label sparsity compute optimization for the input names in the below list. std::vector sparse_label_input_names; diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 5d527369a1b75..0c8ef7d9d69da 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -195,8 +195,7 @@ std::vector> GeneratePreTrainingTransformers( #if defined(USE_CUDA) || defined(USE_ROCM) // Put this under CUDA/ROCM guard as it depends on PadAndUnflatten CUDA/ROCM kernel. // Once we have a CPU kernel for PadAndUnflatten, we can remove the guard. - transformers.emplace_back(std::make_unique(compatible_eps, - config.sparse_embedding_input_names)); + transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); #endif } diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 4ab8db8565bf9..2d2a3db1be2f9 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -488,7 +488,6 @@ void addObjectMethodsForTraining(py::module& m) { .def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute) .def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers) .def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer) - .def_readwrite("sparse_embedding_input_names", &TrainingGraphTransformerConfiguration::sparse_embedding_input_names) .def_readwrite("sparse_label_input_names", &TrainingGraphTransformerConfiguration::sparse_label_input_names) .def_readwrite("optimized_pre_grad_filepath", &TrainingGraphTransformerConfiguration::optimized_pre_grad_filepath) .def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config); diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index c67b05758c5aa..b5480eabea19f 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -10,7 +10,7 @@ import os from abc import ABC, abstractmethod # noqa: F401 from hashlib import md5 as hash_fn -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import onnx import torch @@ -48,6 +48,28 @@ def __init__(self, state, output_info: List[Tuple[torch.Size, torch.device, torc self.state = state self.output_info = output_info +class _FlagPaddingElimination(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + return grad_output + + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + return tensor_input_shapes, tensor_input_dtypes + + @staticmethod + def alias_input(node_proto_str: str): + fw_alias_map = [0] + bw_alias_map = [0] + return fw_alias_map, bw_alias_map class GraphExecutionManager(GraphExecutionInterface): def __init__( @@ -91,6 +113,7 @@ def __init__( # Inspector for runtime information, for example input data, memory usage, etc. self._runtime_inspector = RuntimeInspector(self._logger, self._original_module) self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step) + self._embedding_module_to_padding_density_map = {} # Tracker for ORTModule model export, session creation overhead. self.time_tracker = _logger.TimeTracker() @@ -622,6 +645,38 @@ def __setstate__(self, state): _utils.reinitialize_graph_execution_manager(self) + def _check_embedding_sparsity(self): + if not self._runtime_options.enable_embedding_sparse_optimizer or self._device.type != "cuda": + return + def embedding_hook(module, args, output): + ebd_input = args[0] + if ebd_input is None or not isinstance(ebd_input, torch.Tensor): + self._logger.warning("Embedding input is not a tensor.") + return None + + valid_token = torch.count_nonzero(ebd_input - module.padding_idx) + total_token = ebd_input.numel() + embed_density = float(valid_token) / float(total_token) * 100 + if embed_density < 90: + self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density) + if module not in self._embedding_module_to_padding_density_map: + self._logger.warning("Found Embedding module not in the map. %s", module) + return None + if module in self._embedding_module_to_padding_density_map and self._embedding_module_to_padding_density_map[module][1] != -1: + self._logger.warning( + "Found duplicate Embedding module. %s", + self._embedding_module_to_padding_density_map[module][0] + ) + self._embedding_module_to_padding_density_map[module][1] = embed_density + return _FlagPaddingElimination.apply(output) + return None + + for name, sub_module in self._flattened_module.named_modules(): + if isinstance(sub_module, torch.nn.modules.sparse.Embedding): + if sub_module.padding_idx >= 0: + self._embedding_module_to_padding_density_map[sub_module] = [name, -1] + sub_module.register_forward_hook(embedding_hook) + @_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION) def _enable_conditional_optimizations( self, graph_transformer_config: C.TrainingGraphTransformerConfiguration, inputs: Tuple, kwargs: Dict @@ -680,16 +735,10 @@ def _enable_conditional_optimizations( [f"{k}:{v:.0f}%" for k, v in label_sparsity_results.items()] ) - if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0: - if detected_device.type == "cuda": - # Embedding sparsity optimization is only supported on CUDA devices. - graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys()) - self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results) - self._runtime_options.embed_sparsity_ratio = ",".join( - [f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()] - ) - else: - self._logger.info("Embedding sparsity-based optimization is not supported on non-CUDA devices.") + if self._embedding_module_to_padding_density_map: + self._runtime_options.embed_sparsity_ratio = ",".join( + [f"{v[0]}:{v[1]:.0f}%" for v in self._embedding_module_to_padding_density_map.values()] + ) # If users don't want to print input density, disable the input density observer to avoid overhead # when looping through inputs during training. diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 7534cc46a21f1..e1b2620fa6fef 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -245,10 +245,8 @@ def _expand_inputs(current_input, non_none_inputs, name=""): if PrimitiveType.is_primitive_type(inp): inp = PrimitiveType.get_tensor(inp, device) - found, embedding_density, label_density = rt_inspector.inspect_input(name, inp) + found, _, label_density = rt_inspector.inspect_input(name, inp) if found: - if embedding_density < 100: - embed_sparsity_results[name] = embedding_density if label_density < 100: label_sparsity_results[name] = label_density result.append(inp) diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index 22e31466887a6..56ade5f1f194c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -136,7 +136,6 @@ def initialize(self, model: ModelProto, user_input_names: List[str]) -> None: if output_name != "": self._tensor_to_node_map[output_name] = node - self._initialize_embedding_padding_inspector(model, user_input_names) self._initialize_loss_label_padding_inspector(model, user_input_names) self._is_initialized = True @@ -354,33 +353,6 @@ def _inspect_embed_label_input(self, name, data): found = False min_embed_density = 100 min_label_density = 100 - if ( - len(self._embedding_graph_input_to_padding_idx_map) > 0 - and name in self._embedding_graph_input_to_padding_idx_map - and isinstance(data, torch.Tensor) - ): - for padding_idx in self._embedding_graph_input_to_padding_idx_map[name]: - valid_token = torch.count_nonzero(data - padding_idx) - valid_token_per_batch = "N/A" - if data.dim() > 1: - valid_token_per_batch = str(torch.count_nonzero(data - padding_idx, dim=1).tolist()) - total_token = data.numel() - embed_density = float(valid_token) / float(total_token) * 100 - if embed_density < 90: - min_embed_density = min(min_embed_density, embed_density) - self._stats.append( - [ - self._current_step, - "EMBED", - name, - padding_idx, - embed_density, - valid_token, - total_token, - valid_token_per_batch, - ] - ) - found = True if ( len(self._loss_label_graph_input_to_ignore_idx_map) > 0 diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 73c32a2f51e41..d3fef3d322579 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -253,6 +253,8 @@ def forward(self, *inputs, **kwargs): ): self.time_tracker.start(ORTModuleInitPhase.EndToEnd) + self._check_embedding_sparsity() + build_gradient_graph = self._export_model(*inputs, **kwargs) if build_gradient_graph: