Check padding density by input of embedding module
guyang3532 committed Mar 11, 2024
1 parent 5479124 commit f3eb16d
Showing 10 changed files with 108 additions and 79 deletions.
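In short, this change drops the export-time sparse_embedding_input_names list and instead measures padding density at runtime: a forward hook registered on every nn.Embedding with a valid padding_idx computes the density of the actual input and, when it falls below 90%, wraps the embedding output in a _FlagPaddingElimination PythonOp that the PaddingElimination graph transformer then recognizes and strips. A minimal sketch of that density check (the helper function and sample batch below are illustrative, not part of the commit):

import torch

def padding_density(embedding_input: torch.Tensor, padding_idx: int) -> float:
    # Tokens equal to padding_idx become zero after the subtraction, so
    # count_nonzero yields the number of valid (non-padding) tokens.
    valid_tokens = torch.count_nonzero(embedding_input - padding_idx)
    return float(valid_tokens) / float(embedding_input.numel()) * 100

# Example: a [2, 4] batch with padding_idx == 0, where 5 of the 8 tokens are padding.
batch = torch.tensor([[5, 7, 0, 0], [3, 0, 0, 0]])
assert padding_density(batch, padding_idx=0) < 90  # 37.5% density -> optimization flagged ON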
@@ -17,8 +17,12 @@ namespace onnxruntime {
namespace {

// TODO(pengwa): remove this once customized PythonOp shape inference is supported.
constexpr const char* kInspectActivationFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation";
constexpr const char* kIncrementStepFuncName = "onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep";
constexpr const char* kInspectActivationFuncName =
"onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation";
constexpr const char* kIncrementStepFuncName =
"onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep";
constexpr const char* kFlagPaddingEliminationFuncName =
"onnxruntime.training.ortmodule._graph_execution_manager._FlagPaddingElimination";

void PushAllOutputNode(Graph& graph, std::queue<Node*>& q, Node* node, std::unordered_set<Node*>& visited) {
for (auto iter = node->OutputNodesBegin(); iter != node->OutputNodesEnd(); ++iter) {
@@ -311,7 +315,7 @@ void IterateSubgraphFromNode(Graph& graph,
candidate_outputs.insert(cur);
continue;
}
auto func_name = static_cast<std::string>(cur->GetAttributes().at("name").s());
auto func_name = static_cast<std::string>(cur->GetAttributes().at("func_name").s());
if (func_name == kInspectActivationFuncName || func_name == kIncrementStepFuncName) {
subgraph.insert(cur->MutableOutputDefs()[1]);
PushAllOutputNode(graph, to_visit, cur, visited);
@@ -353,11 +357,6 @@ void IterateSubgraphFromNode(Graph& graph,
Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
LOG_DEBUG_INFO(logger, "Enter PaddingElimination");

if (sparse_embedding_input_names_.size() == 0) {
LOG_DEBUG_INFO(logger, "Exit PaddingElimination, no sparse embedding input names.");
return Status::OK();
}

GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
Node* embedding_node = nullptr;
@@ -386,13 +385,31 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
node.InputDefs()[2]->Exists() &&
graph_utils::IsConstantInitializer(graph, node.InputDefs()[2]->Name()) &&
node.InputDefs()[1]->Exists() &&
graph_utils::IsGraphInput(graph, node.InputDefs()[1]) &&
node.InputDefs()[1]->Shape() &&
node.InputDefs()[1]->Shape()->dim_size() >= 2) {
if (std::find(sparse_embedding_input_names_.begin(), sparse_embedding_input_names_.end(),
node.InputDefs()[1]->Name()) == sparse_embedding_input_names_.end()) {
LOG_DEBUG_INFO(logger, "Skip node " + node.Name() + "(" + node.OpType() +
") due to embedding input is not in the sparse embedding input list.");
const auto outputNodeCount = std::distance(node.OutputEdgesBegin(), node.OutputEdgesEnd());
if (outputNodeCount != 1) {
continue;
}
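// The single consumer must be the _FlagPaddingElimination PythonOp inserted by the
// ORTModule forward hook; match it by func_name and remove it before rewriting the subgraph.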
auto embedding_output_node = graph.GetNode(node.OutputNodesBegin()->Index());
if (embedding_output_node == nullptr ||
!graph_utils::IsSupportedOptypeVersionAndDomain(*embedding_output_node, "PythonOp", {1}, kMSDomain) ||
static_cast<std::string>(embedding_output_node->GetAttributes().at("func_name").s()) !=
kFlagPaddingEliminationFuncName) {
LOG_DEBUG_INFO(logger, "not find PythonOp of flagPaddingElimination after embedding node");
continue;
}
if (graph_utils::CanRemoveNode(graph, *embedding_output_node, logger)) {
if (graph_utils::RemoveNode(graph, *embedding_output_node)) {
modified = true;
} else {
LOG_DEBUG_INFO(logger, "Failed to remove node " + embedding_output_node->Name() +
"(" + embedding_output_node->OpType() + ")");
continue;
}
} else {
LOG_DEBUG_INFO(logger, "Can not remove node " + embedding_output_node->Name() +
"(" + embedding_output_node->OpType() + ")");
continue;
}
const ONNX_NAMESPACE::TensorProto* padding_initializer =
@@ -479,7 +496,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
// to flatten the shape of [batch_size, seqlen, ...] to [valid_token_count, ...]
InsertFlattenPatternForInput(graph, *embedding_node, 1, squeeze_out_arg, logger);
handled_input_count++;
modified = true;
for (auto& node : candidate_inputs) {
for (uint32_t i = 0; i < node->InputDefs().size(); ++i) {
if (subgraph.find(node->MutableInputDefs()[i]) == subgraph.end()) {
@@ -127,15 +127,10 @@ namespace onnxruntime {
*/
class PaddingElimination : public GraphTransformer {
public:
PaddingElimination(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const std::vector<std::string>& sparse_embedding_input_names = {}) noexcept
: GraphTransformer("PaddingElimination", compatible_execution_providers),
sparse_embedding_input_names_{sparse_embedding_input_names} {}
explicit PaddingElimination(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("PaddingElimination", compatible_execution_providers) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

private:
std::vector<std::string> sparse_embedding_input_names_;
};

} // namespace onnxruntime
@@ -25,9 +25,6 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat
// Enable compute optimizer.
bool enable_compute_optimizer{false};

// Enable embedding sparsity compute optimization for the input names in the below list.
std::vector<std::string> sparse_embedding_input_names;

// Enable label sparsity compute optimization for the input names in the below list.
std::vector<std::string> sparse_label_input_names;

@@ -195,8 +195,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
#if defined(USE_CUDA) || defined(USE_ROCM)
// Put this under CUDA/ROCM guard as it depends on PadAndUnflatten CUDA/ROCM kernel.
// Once we have a CPU kernel for PadAndUnflatten, we can remove the guard.
transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps,
config.sparse_embedding_input_names));
transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps));
transformers.emplace_back(std::make_unique<Conv1dReplacement>(compatible_eps));
#endif
}
1 change: 0 additions & 1 deletion orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -488,7 +488,6 @@ void addObjectMethodsForTraining(py::module& m) {
.def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers)
.def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer)
.def_readwrite("sparse_embedding_input_names", &TrainingGraphTransformerConfiguration::sparse_embedding_input_names)
.def_readwrite("sparse_label_input_names", &TrainingGraphTransformerConfiguration::sparse_label_input_names)
.def_readwrite("optimized_pre_grad_filepath", &TrainingGraphTransformerConfiguration::optimized_pre_grad_filepath)
.def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);
@@ -10,7 +10,7 @@
import os
from abc import ABC, abstractmethod # noqa: F401
from hashlib import md5 as hash_fn
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import onnx
import torch
@@ -49,6 +49,30 @@ def __init__(self, state, output_info: List[Tuple[torch.Size, torch.device, torc
self.output_info = output_info


class _FlagPaddingElimination(torch.autograd.Function):
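"""Identity autograd.Function used to tag an embedding output for padding elimination.

Applied by the embedding forward hook in GraphExecutionManager when the measured
padding density is below 90%. The PythonOp this produces in the exported graph
carries the func_name
"onnxruntime.training.ortmodule._graph_execution_manager._FlagPaddingElimination",
which the PaddingElimination graph transformer matches and removes before
rewriting the embedding subgraph.
"""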
@staticmethod
def forward(ctx, input):
return input

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
return grad_output

@staticmethod
def infer_shape(
node: onnx.NodeProto,
tensor_input_shapes: List[Optional[List[Union[int, str]]]],
tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
return tensor_input_shapes, tensor_input_dtypes

@staticmethod
def alias_input(node_proto_str: str):
fw_alias_map = [0]
bw_alias_map = [0]
return fw_alias_map, bw_alias_map


class GraphExecutionManager(GraphExecutionInterface):
def __init__(
self,
@@ -91,6 +115,7 @@ def __init__(
# Inspector for runtime information, for example input data, memory usage, etc.
self._runtime_inspector = RuntimeInspector(self._logger, self._original_module)
self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)
self._embedding_module_to_padding_density_map = {}

# Tracker for ORTModule model export, session creation overhead.
self.time_tracker = _logger.TimeTracker()
@@ -622,6 +647,43 @@ def __setstate__(self, state):

_utils.reinitialize_graph_execution_manager(self)

def _check_embedding_sparsity(self):
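"""Register forward hooks on nn.Embedding modules to measure padding density.

Hooks are installed only when the embedding sparse optimizer is enabled and the
device is CUDA. Each hook records the observed density per module and, when the
density is below 90%, flags the embedding output with _FlagPaddingElimination.
"""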
if not self._runtime_options.enable_embedding_sparse_optimizer or self._device.type != "cuda":
return

def embedding_hook(module, args, output):
ebd_input = args[0]
if ebd_input is None or not isinstance(ebd_input, torch.Tensor):
self._logger.warning("Embedding input is not a tensor.")
return None

valid_token = torch.count_nonzero(ebd_input - module.padding_idx)
total_token = ebd_input.numel()
embed_density = float(valid_token) / float(total_token) * 100
if embed_density < 90:
self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density)
if module not in self._embedding_module_to_padding_density_map:
self._logger.warning("Found Embedding module not in the map. %s", module)
return None
if (
module in self._embedding_module_to_padding_density_map
and self._embedding_module_to_padding_density_map[module][1] != -1
):
self._logger.warning(
"Found duplicate Embedding module. %s", self._embedding_module_to_padding_density_map[module][0]
)
self._embedding_module_to_padding_density_map[module][1] = embed_density
return _FlagPaddingElimination.apply(output)
else:
self._logger.info("Embedding sparsity-based optimization is OFF for density: %.0f%%", embed_density)
return None

for name, sub_module in self._flattened_module.named_modules():
if isinstance(sub_module, torch.nn.modules.sparse.Embedding):
if sub_module.padding_idx is not None and sub_module.padding_idx >= 0:
self._embedding_module_to_padding_density_map[sub_module] = [name, -1]
sub_module.register_forward_hook(embedding_hook)

@_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION)
def _enable_conditional_optimizations(
self, graph_transformer_config: C.TrainingGraphTransformerConfiguration, inputs: Tuple, kwargs: Dict
@@ -680,16 +742,10 @@ def _enable_conditional_optimizations(
[f"{k}:{v:.0f}%" for k, v in label_sparsity_results.items()]
)

if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0:
if detected_device.type == "cuda":
# Embedding sparsity optimization is only supported on CUDA devices.
graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys())
self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results)
self._runtime_options.embed_sparsity_ratio = ",".join(
[f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()]
)
else:
self._logger.info("Embedding sparsity-based optimization is not supported on non-CUDA devices.")
if self._embedding_module_to_padding_density_map:
self._runtime_options.embed_sparsity_ratio = ",".join(
[f"{v[0]}:{v[1]:.0f}%" for v in self._embedding_module_to_padding_density_map.values()]
)

# If users don't want to print input density, disable the input density observer to avoid overhead
# when looping through inputs during training.
4 changes: 1 addition & 3 deletions orttraining/orttraining/python/training/ortmodule/_io.py
@@ -245,10 +245,8 @@ def _expand_inputs(current_input, non_none_inputs, name=""):
if PrimitiveType.is_primitive_type(inp):
inp = PrimitiveType.get_tensor(inp, device)

found, embedding_density, label_density = rt_inspector.inspect_input(name, inp)
found, _, label_density = rt_inspector.inspect_input(name, inp)
if found:
if embedding_density < 100:
embed_sparsity_results[name] = embedding_density
if label_density < 100:
label_sparsity_results[name] = label_density
result.append(inp)
@@ -136,7 +136,6 @@ def initialize(self, model: ModelProto, user_input_names: List[str]) -> None:
if output_name != "":
self._tensor_to_node_map[output_name] = node

self._initialize_embedding_padding_inspector(model, user_input_names)
self._initialize_loss_label_padding_inspector(model, user_input_names)

self._is_initialized = True
@@ -354,33 +353,6 @@ def _inspect_embed_label_input(self, name, data):
found = False
min_embed_density = 100
min_label_density = 100
if (
len(self._embedding_graph_input_to_padding_idx_map) > 0
and name in self._embedding_graph_input_to_padding_idx_map
and isinstance(data, torch.Tensor)
):
for padding_idx in self._embedding_graph_input_to_padding_idx_map[name]:
valid_token = torch.count_nonzero(data - padding_idx)
valid_token_per_batch = "N/A"
if data.dim() > 1:
valid_token_per_batch = str(torch.count_nonzero(data - padding_idx, dim=1).tolist())
total_token = data.numel()
embed_density = float(valid_token) / float(total_token) * 100
if embed_density < 90:
min_embed_density = min(min_embed_density, embed_density)
self._stats.append(
[
self._current_step,
"EMBED",
name,
padding_idx,
embed_density,
valid_token,
total_token,
valid_token_per_batch,
]
)
found = True

if (
len(self._loss_label_graph_input_to_ignore_idx_map) > 0
@@ -697,9 +669,11 @@ def _get_user_config_without_freq(configs: str):
[
f" - Plan {index}",
":",
"ON"
if all(cluster_id in user_configs_with_out_freq for cluster_id in cluster_ids_without_freq)
else "OFF",
(
"ON"
if all(cluster_id in user_configs_with_out_freq for cluster_id in cluster_ids_without_freq)
else "OFF"
),
":",
cluster_id,
saving_symbolic.freq if details else "",
@@ -253,6 +253,8 @@ def forward(self, *inputs, **kwargs):
):
self.time_tracker.start(ORTModuleInitPhase.EndToEnd)

self._check_embedding_sparsity()

build_gradient_graph = self._export_model(*inputs, **kwargs)

if build_gradient_graph:
@@ -5705,8 +5705,6 @@ def run_step(model, input, target):
@pytest.mark.parametrize("label_is_sparse", [False, True])
@pytest.mark.parametrize("rank", [1, 2])
def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, label_is_sparse, rank, caplog):
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"

class NeuralNetCrossEntropyLoss(torch.nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super().__init__()
@@ -5801,7 +5799,6 @@ def run_step(model, input, positions):
],
)
def test_ops_for_padding_elimination(test_cases):
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"
test_op = test_cases[0]
case = test_cases[1]

@@ -5962,11 +5959,8 @@ def find_input_node_type(model, arg):
else:
assert "ATen" in recover_pad_input_optypes

del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]


def test_e2e_padding_elimination():
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"
seed = 5033
random.seed(seed)
np.random.seed(seed)
@@ -6109,7 +6103,6 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model
assert "FlattenAndUnpad" in [node.op_type for node in training_model.graph.node]
assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node]
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]


@pytest.mark.skipif(
