Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check padding density by input of embedding module #19821

Merged
merged 6 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion orttraining/orttraining/core/graph/training_op_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
namespace {

// TODO(pengwa): remove this once customized PythonOp shape inference is supported.
constexpr const char* kInspectActivationFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation";
constexpr const char* kInspectActivationFuncName = "onnxruntime.training.utils.hooks._statistics_subscriber._InspectActivation";

Check warning on line 23 in orttraining/orttraining/core/graph/training_op_defs.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: orttraining/orttraining/core/graph/training_op_defs.cc:23: Lines should be <= 120 characters long [whitespace/line_length] [2]
constexpr const char* kIncrementStepFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep";

std::array<TensorShapeProto::Dimension, 6> GetRNNDimensions(InferenceContext& ctx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ namespace onnxruntime {
namespace {

// TODO(pengwa): remove this once customized PythonOp shape inference is supported.
constexpr const char* kInspectActivationFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation";
constexpr const char* kIncrementStepFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep";
constexpr const char* kInspectActivationFuncName =
"onnxruntime.training.utils.hooks._statistics_subscriber._InspectActivation";
constexpr const char* kIncrementStepFuncName =
"onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep";
guyang3532 marked this conversation as resolved.
Show resolved Hide resolved
constexpr const char* kFlagPaddingEliminationFuncName =
"onnxruntime.training.ortmodule._runtime_inspector.FlagPaddingElimination";

void PushAllOutputNode(Graph& graph, std::queue<Node*>& q, Node* node, std::unordered_set<Node*>& visited) {
for (auto iter = node->OutputNodesBegin(); iter != node->OutputNodesEnd(); ++iter) {
Expand Down Expand Up @@ -316,7 +320,7 @@ void IterateSubgraphFromNode(Graph& graph,
candidate_outputs.insert(cur);
continue;
}
auto func_name = static_cast<std::string>(cur->GetAttributes().at("name").s());
auto func_name = static_cast<std::string>(cur->GetAttributes().at("func_name").s());
if (func_name == kInspectActivationFuncName || func_name == kIncrementStepFuncName) {
subgraph.insert(cur->MutableOutputDefs()[1]);
PushAllOutputNode(graph, to_visit, cur, visited);
Expand Down Expand Up @@ -358,11 +362,6 @@ void IterateSubgraphFromNode(Graph& graph,
Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
LOG_DEBUG_INFO(logger, "Enter PaddingElimination");

if (sparse_embedding_input_names_.size() == 0) {
LOG_DEBUG_INFO(logger, "Exit PaddingElimination, no sparse embedding input names.");
return Status::OK();
}

GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
Node* embedding_node = nullptr;
Expand Down Expand Up @@ -391,13 +390,31 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
node.InputDefs()[2]->Exists() &&
graph_utils::IsConstantInitializer(graph, node.InputDefs()[2]->Name()) &&
node.InputDefs()[1]->Exists() &&
graph_utils::IsGraphInput(graph, node.InputDefs()[1]) &&
node.InputDefs()[1]->Shape() &&
node.InputDefs()[1]->Shape()->dim_size() >= 2) {
if (std::find(sparse_embedding_input_names_.begin(), sparse_embedding_input_names_.end(),
node.InputDefs()[1]->Name()) == sparse_embedding_input_names_.end()) {
LOG_DEBUG_INFO(logger, "Skip node " + node.Name() + "(" + node.OpType() +
") due to embedding input is not in the sparse embedding input list.");
const auto outputNodeCount = std::distance(node.OutputEdgesBegin(), node.OutputEdgesEnd());
if (outputNodeCount != 1) {
continue;
}
auto embedding_output_node = graph.GetNode(node.OutputNodesBegin()->Index());
if (embedding_output_node == nullptr ||
!graph_utils::IsSupportedOptypeVersionAndDomain(*embedding_output_node, "PythonOp", {1}, kMSDomain) ||
static_cast<std::string>(embedding_output_node->GetAttributes().at("func_name").s()) !=
kFlagPaddingEliminationFuncName) {
LOG_DEBUG_INFO(logger, "not find PythonOp of flagPaddingElimination after embedding node");
continue;
}
if (graph_utils::CanRemoveNode(graph, *embedding_output_node, logger)) {
if (graph_utils::RemoveNode(graph, *embedding_output_node)) {
modified = true;
} else {
LOG_DEBUG_INFO(logger, "Failed to remove node " + embedding_output_node->Name() +
"(" + embedding_output_node->OpType() + ")");
continue;
}
} else {
LOG_DEBUG_INFO(logger, "Can not remove node " + embedding_output_node->Name() +
"(" + embedding_output_node->OpType() + ")");
continue;
}
const ONNX_NAMESPACE::TensorProto* padding_initializer =
Expand Down Expand Up @@ -484,7 +501,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
// to flattern the shape of [batch_size, seqlen, ...] to [valid_token_count, ...]
InsertFlattenPatternForInput(graph, *embedding_node, 1, squeeze_out_arg, logger);
handled_input_count++;
modified = true;
for (auto& node : candidate_inputs) {
for (uint32_t i = 0; i < node->InputDefs().size(); ++i) {
if (subgraph.find(node->MutableInputDefs()[i]) == subgraph.end()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,10 @@ namespace onnxruntime {
*/
class PaddingElimination : public GraphTransformer {
public:
PaddingElimination(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const std::vector<std::string>& sparse_embedding_input_names = {}) noexcept
: GraphTransformer("PaddingElimination", compatible_execution_providers),
sparse_embedding_input_names_{sparse_embedding_input_names} {}
explicit PaddingElimination(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("PaddingElimination", compatible_execution_providers) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

private:
std::vector<std::string> sparse_embedding_input_names_;
};

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat
// Enable compute optimizer.
bool enable_compute_optimizer{false};

// Enable embedding sparsity compute optimization for the input names in the below list.
std::vector<std::string> sparse_embedding_input_names;

// Enable label sparsity compute optimization for the input names in the below list.
std::vector<std::string> sparse_label_input_names;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
#if defined(USE_CUDA) || defined(USE_ROCM)
// Put this under CUDA/ROCM guard as it depends on PadAndUnflatten CUDA/ROCM kernel.
// Once we have a CPU kernel for PadAndUnflatten, we can remove the guard.
transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps,
config.sparse_embedding_input_names));
transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps));
transformers.emplace_back(std::make_unique<Conv1dReplacement>(compatible_eps));
#endif
}
Expand Down
1 change: 0 additions & 1 deletion orttraining/orttraining/python/orttraining_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,6 @@ void addObjectMethodsForTraining(py::module& m) {
.def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers)
.def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer)
.def_readwrite("sparse_embedding_input_names", &TrainingGraphTransformerConfiguration::sparse_embedding_input_names)
.def_readwrite("sparse_label_input_names", &TrainingGraphTransformerConfiguration::sparse_label_input_names)
.def_readwrite("optimized_pre_grad_filepath", &TrainingGraphTransformerConfiguration::optimized_pre_grad_filepath)
.def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ._graph_execution_interface import GraphExecutionInterface
from ._io import _FlattenedModule, _InputInfo
from ._logger import LogColor
from ._runtime_inspector import RuntimeInspector
from ._runtime_inspector import FlagPaddingElimination, RuntimeInspector
from ._utils import check_function_has_param, get_rank
from .options import DebugOptions, LogLevel, _MemoryOptimizationLevel, _RuntimeOptions
from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension
Expand Down Expand Up @@ -306,11 +306,16 @@ def _export_model(self, *inputs, **kwargs) -> bool:
# All required models have already been exported previously
return False
self._set_device_from_module(inputs, kwargs)
# TODO: move it into runtime_inspector
embedding_hook_handles = self._add_check_embedding_sparsity_hook()

from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step

with export_context(), no_increase_global_step():
self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs)

for hook in embedding_hook_handles:
hook.remove()
if self._debug_options.save_onnx_models.save:
self._onnx_models.save_exported_model(
self._debug_options.save_onnx_models.path,
Expand Down Expand Up @@ -671,6 +676,58 @@ def __setstate__(self, state):

_utils.reinitialize_graph_execution_manager(self)

def _add_check_embedding_sparsity_hook(self):
"""
Add hook to check embedding sparsity and enable padding elimination if applicable.
1. Iterate through all modules to find Embedding modules with padding_idx >= 0.
2. Register forward hook to the Embedding module and the hook will check sparsity of the embedding input.
3. If the sparsity is below a threshold, enable padding elimination by adding FlagPaddingElimination after the
output. GraphTransformer of PaddingElimination will check the FlagPaddingElimination and do the actual
padding elimination graph modification.
4. Return the hook handles for later removal.

"""
if (
pengwa marked this conversation as resolved.
Show resolved Hide resolved
not self._runtime_options.enable_sparse_optimizer
or not self._runtime_options.enable_embedding_sparse_optimizer
or self._device.type != "cuda"
):
return []

def _embedding_hook(module, args, output):
ebd_input = args[0]
if ebd_input is None or not isinstance(ebd_input, torch.Tensor):
self._logger.warning("Embedding input is not a tensor.")
return None

valid_token = torch.count_nonzero(ebd_input - module.padding_idx)
total_token = ebd_input.numel()
embed_density = float(valid_token) / float(total_token) * 100
if embed_density < 90:
self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density)
if module not in self._runtime_inspector._embedding_module_to_padding_density_map:
self._logger.warning("Found Embedding module not in the map. %s", module)
return None
if self._runtime_inspector._embedding_module_to_padding_density_map[module][1] != -1:
self._logger.warning(
"Found duplicate Embedding module. %s",
self._runtime_inspector._embedding_module_to_padding_density_map[module][0],
)
self._runtime_inspector._embedding_module_to_padding_density_map[module][1] = embed_density
return FlagPaddingElimination.apply(output)
else:
self._logger.info("Embedding sparsity-based optimization is OFF for density: %.0f%%", embed_density)
return None

embedding_hook_handles = []
for name, sub_module in self._flattened_module.named_modules():
if isinstance(sub_module, torch.nn.modules.sparse.Embedding):
if sub_module.padding_idx is not None and sub_module.padding_idx >= 0:
self._runtime_inspector._embedding_module_to_padding_density_map[sub_module] = [name, -1]
embedding_hook_handles.append(sub_module.register_forward_hook(_embedding_hook))

return embedding_hook_handles

@_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION)
def _enable_conditional_optimizations(
self, graph_transformer_config: C.TrainingGraphTransformerConfiguration, inputs: Tuple, kwargs: Dict
Expand Down Expand Up @@ -709,7 +766,7 @@ def _enable_conditional_optimizations(
else:
param_to_append_as_onnx_graph_inputs = self._graph_initializers

_, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers(
_, _, label_sparsity_results = _io._combine_input_buffers_initializers(
param_to_append_as_onnx_graph_inputs,
self._graph_builder.get_graph_info().user_input_names,
self._input_info,
Expand All @@ -729,16 +786,13 @@ def _enable_conditional_optimizations(
[f"{k}:{v:.0f}%" for k, v in label_sparsity_results.items()]
)

if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0:
if detected_device.type == "cuda":
# Embedding sparsity optimization is only supported on CUDA devices.
graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys())
self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results)
self._runtime_options.embed_sparsity_ratio = ",".join(
[f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()]
)
else:
self._logger.info("Embedding sparsity-based optimization is not supported on non-CUDA devices.")
if self._runtime_inspector._embedding_module_to_padding_density_map:
self._runtime_options.embed_sparsity_ratio = ",".join(
[
f"{v[0]}:{v[1]:.0f}%"
for v in self._runtime_inspector._embedding_module_to_padding_density_map.values()
]
)

# If users don't want to print input density, disable the input density observer to avoid overhead
# when looping through inputs during training.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(self, logger: Logger, module: torch.nn.Module, training: bool):

self.input_density_ob: Union[InputDensityObserver, None] = None
self.memory_ob = MemoryObserver(module, self._logger, training)
self._embedding_module_to_padding_density_map = {}

def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None:
"""Initialize input inspector from the given ONNX model and user input names.
Expand Down Expand Up @@ -747,3 +748,33 @@ def _get_user_config_without_freq(configs: str):
return notes, mem_tbl

return [], None


class FlagPaddingElimination(torch.autograd.Function):
pengwa marked this conversation as resolved.
Show resolved Hide resolved
"""
FlagPaddingElimination is a PyTorch autograd function that does nothing in forward pass and backward pass.
It is used as a flag to tell the GraphTransformer of PaddingElimination to modify the graph to eliminate
the embedding padding.
"""

@staticmethod
def forward(ctx, input):
return input

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
return grad_output

@staticmethod
def infer_shape(
node: onnx.NodeProto,
tensor_input_shapes: List[Optional[List[Union[int, str]]]],
tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
return tensor_input_shapes, tensor_input_dtypes

@staticmethod
def alias_input(node_proto_str: str):
fw_alias_map = [0]
bw_alias_map = [0]
return fw_alias_map, bw_alias_map
Loading