Generalize label input sparsity check and refactor (#20636)
### Description
The InsertGatherBeforeSceLoss optimization is enabled when the label padding density is less than 90%. We therefore need to check the label padding density to decide whether to enable the optimization.

Before this PR, we only checked the graph inputs and correlated one of them with
the SCE node by iterating the graph backwards from the SCE node to a graph input.
This is hard to generalize because there may be complicated patterns
between the graph input and the SCE node.

This PR checks the padding density on the direct input of the SCE module, rather
than on the graph input, at the first graph execution when exporting the ONNX
graph.
If the density is less than 90%, a flag PythonOp is inserted after the SCE node, as shown below:
```
            SoftmaxCrossEntropy
                     |
            PythonOp (func_name: FlagAndPrintDensity)   (insert if density < 90%)
                     |
            Following graph
```

When InsertGatherBeforeSceLoss is invoked, it checks whether the flag
PythonOp (func_name: FlagAndPrintDensity) is present after the SCE node. If it
is, the flag is removed and the padding elimination optimization is applied.

If the environment variable ORTMODULE_PRINT_INPUT_DENSITY is set to 1, the input
density is printed at each step by the PythonOp (func_name: FlagAndPrintDensity). In
this case the PythonOp is not removed.
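
For reference, here is a minimal sketch of what such a pass-through flag op could look like on the Python side, assuming it is implemented as a `torch.autograd.Function` and that padding labels are marked with `ignore_index = -100`; the actual FlagAndPrintDensity in onnxruntime.training.ortmodule._runtime_inspector may differ in its details.

```python
# Illustrative sketch only, not the actual FlagAndPrintDensity implementation.
# Assumptions: padding labels use ignore_index == -100, and the op is a pure
# pass-through so the graph transformers can safely remove it later.
import torch


class FlagAndPrintDensitySketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, label: torch.Tensor) -> torch.Tensor:
        valid = (label != -100).sum().item()
        density = 100.0 * valid / label.numel()
        print(f"label density: {density:.1f}%")
        return label  # identity output: only marks the spot in the exported graph

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output  # labels normally carry no gradient
```

When exported through ORTModule, an autograd.Function like this surfaces as a PythonOp node whose func_name attribute carries the function's qualified name, which is the attribute the transformers below match on (and remove once the density information has been consumed).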
guyang3532 authored May 10, 2024
1 parent e124cf8 commit cfe830b
Showing 16 changed files with 303 additions and 607 deletions.
24 changes: 11 additions & 13 deletions docs/ORTModule_Training_Guidelines.md
@@ -208,19 +208,6 @@ debugging).
export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=0 # Disable
```

#### ORTMODULE_ENABLE_SPARSE_OPTIMIZER

- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the input data sparsity
based performance optimizations, including embedding sparsity and label sparsity.
This optimization is applicable when using optimum, which has an implementation of the ModuleWithLoss class that wraps the HuggingFace Training that allows loss computation inside ONNX Runtime (ORT).
If you're not using optimum but want to implement a similar wrapper in your codebase to compute the loss inside ONNX Runtime (ORT), you can refer to this [Link](ORTModule_ModuleWithLoss_Wrapper.md) for detailed steps and guidelines on how to achieve this.

```bash
export ORTMODULE_ENABLE_SPARSE_OPTIMIZER=1 # Enable
export ORTMODULE_ENABLE_SPARSE_OPTIMIZER=0 # Disable
```

#### ORTMODULE_PRINT_INPUT_DENSITY

- **Feature Area**: *ORTMODULE/RuntimeInspector*
@@ -254,6 +241,17 @@ data sparsity based performance optimizations.
export ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER=0 # Disable
```

#### ORTMODULE_ENABLE_LABEL_SPARSE_OPTIMIZER

- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the label input
data sparsity based performance optimizations.

```bash
export ORTMODULE_ENABLE_LABEL_SPARSE_OPTIMIZER=1 # Enable
export ORTMODULE_ENABLE_LABEL_SPARSE_OPTIMIZER=0 # Disable
```

#### ORTMODULE_CACHE_DIR

- **Feature Area**: *ORTMODULE/RuntimeOptions*
@@ -21,8 +21,8 @@ constexpr const char* kInspectActivationFuncName =
"onnxruntime.training.utils.hooks._statistics_subscriber._InspectActivation";
constexpr const char* kIncrementStepFuncName =
"onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep";
constexpr const char* kFlagPaddingEliminationFuncName =
"onnxruntime.training.ortmodule._runtime_inspector.FlagPaddingElimination";
constexpr const char* kFlagAndPrintDensityFuncName =
"onnxruntime.training.ortmodule._runtime_inspector.FlagAndPrintDensity";

void PushAllOutputNode(Graph& graph, std::queue<Node*>& q, Node* node, std::unordered_set<Node*>& visited) {
for (auto iter = node->OutputNodesBegin(); iter != node->OutputNodesEnd(); ++iter) {
@@ -396,26 +396,28 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
if (outputNodeCount != 1) {
continue;
}
auto embedding_output_node = graph.GetNode(node.OutputNodesBegin()->Index());
if (embedding_output_node == nullptr ||
!graph_utils::IsSupportedOptypeVersionAndDomain(*embedding_output_node, "PythonOp", {1}, kMSDomain) ||
static_cast<std::string>(embedding_output_node->GetAttributes().at("func_name").s()) !=
kFlagPaddingEliminationFuncName) {
Node* embedding_input_node = graph.GetMutableProducerNode(node.MutableInputDefs()[1]->Name());
if (embedding_input_node == nullptr ||
!graph_utils::IsSupportedOptypeVersionAndDomain(*embedding_input_node, "PythonOp", {1}, kMSDomain) ||
static_cast<std::string>(embedding_input_node->GetAttributes().at("func_name").s()) !=
kFlagAndPrintDensityFuncName) {
LOG_DEBUG_INFO(logger, "not find PythonOp of flagPaddingElimination after embedding node");
continue;
}
if (graph_utils::CanRemoveNode(graph, *embedding_output_node, logger)) {
if (graph_utils::RemoveNode(graph, *embedding_output_node)) {
modified = true;
if (!print_density_) {
if (graph_utils::CanRemoveNode(graph, *embedding_input_node, logger)) {
if (graph_utils::RemoveNode(graph, *embedding_input_node)) {
modified = true;
} else {
LOG_DEBUG_INFO(logger, "Failed to remove node " + embedding_input_node->Name() +
"(" + embedding_input_node->OpType() + ")");
continue;
}
} else {
LOG_DEBUG_INFO(logger, "Failed to remove node " + embedding_output_node->Name() +
"(" + embedding_output_node->OpType() + ")");
LOG_DEBUG_INFO(logger, "Can not remove node " + embedding_input_node->Name() +
"(" + embedding_input_node->OpType() + ")");
continue;
}
} else {
LOG_DEBUG_INFO(logger, "Can not remove node " + embedding_output_node->Name() +
"(" + embedding_output_node->OpType() + ")");
continue;
}
const ONNX_NAMESPACE::TensorProto* padding_initializer =
graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
@@ -127,10 +127,15 @@ namespace onnxruntime {
*/
class PaddingElimination : public GraphTransformer {
public:
explicit PaddingElimination(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("PaddingElimination", compatible_execution_providers) {}
explicit PaddingElimination(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const bool print_input_density = false) noexcept
: GraphTransformer("PaddingElimination", compatible_execution_providers),
print_density_(print_input_density) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

private:
bool print_density_ = false;
};

} // namespace onnxruntime
@@ -16,15 +16,16 @@

namespace onnxruntime {

namespace {

constexpr const char* kFlagAndPrintDensityFuncName =
"onnxruntime.training.ortmodule._runtime_inspector.FlagAndPrintDensity";
} // namespace

Status InsertGatherBeforeSceLoss::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/,
const logging::Logger& logger) const {
LOG_DEBUG_INFO(logger, "Enter InsertGatherBeforeSceLoss");

if (sparse_label_input_names_.size() == 0) {
LOG_DEBUG_INFO(logger, "Exit InsertGatherBeforeSceLoss, no sparse label input names.");
return Status::OK();
}

GraphViewer graph_viewer(graph);
[[maybe_unused]] size_t handled_sce_node_count = 0; // For summary
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
@@ -48,7 +49,7 @@ Status InsertGatherBeforeSceLoss::ApplyImpl(Graph& graph, bool& modified, int /*
const NodeArg* label_input_arg = node.InputDefs()[1];

// Check whether this SCE node is handled or not.
const Node* labels_producer = graph.GetProducerNode(label_input_arg->Name());
Node* labels_producer = graph.GetMutableProducerNode(label_input_arg->Name());
// Skip if already inserted a ShrunkenGather node.
if (labels_producer && graph_utils::IsSupportedOptypeVersionAndDomain(
*labels_producer, "ShrunkenGather", {1}, kMSDomain)) {
@@ -57,18 +58,28 @@ Status InsertGatherBeforeSceLoss::ApplyImpl(Graph& graph, bool& modified, int /*
continue;
}

// Label input can be a graph input or from a Reshape node taking a graph input as its data input.
if (labels_producer && graph_utils::IsSupportedOptypeVersionAndDomain(
*labels_producer, "Reshape", {1, 5, 13, 14}, kOnnxDomain)) {
label_input_arg = labels_producer->InputDefs()[0];
}
// Then check if the label input is graph input and in the sparse label input list.
if (!graph.IsInputsIncludingInitializers(label_input_arg) ||
std::find(sparse_label_input_names_.begin(), sparse_label_input_names_.end(),
label_input_arg->Name()) == sparse_label_input_names_.end()) {
if (labels_producer == nullptr ||
!graph_utils::IsSupportedOptypeVersionAndDomain(*labels_producer, "PythonOp", {1}, kMSDomain) ||
static_cast<std::string>(labels_producer->GetAttributes().at("func_name").s()) !=
kFlagAndPrintDensityFuncName) {
LOG_DEBUG_INFO(logger, "Skip node " + node.Name() + "(" + node.OpType() +
") due to labels input is not a graph input or not in the sparse label input list.");
") due to labels input is not produced by a PythonOp node with flag " +
kFlagAndPrintDensityFuncName + ".");
continue;
} else if (!print_density_) {
if (graph_utils::CanRemoveNode(graph, *labels_producer, logger)) {
if (graph_utils::RemoveNode(graph, *labels_producer)) {
modified = true;
} else {
LOG_DEBUG_INFO(logger, "Failed to remove node " + labels_producer->Name() +
"(" + labels_producer->OpType() + ")");
continue;
}
} else {
LOG_DEBUG_INFO(logger, "Can not remove node " + labels_producer->Name() +
"(" + labels_producer->OpType() + ")");
continue;
}
}

// Check shape requirements.
@@ -68,9 +68,9 @@ namespace onnxruntime {
class InsertGatherBeforeSceLoss : public GraphTransformer {
public:
InsertGatherBeforeSceLoss(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const std::vector<std::string>& sparse_label_input_names = {}) noexcept
const bool print_input_density = false) noexcept
: GraphTransformer("InsertGatherBeforeSceLoss", compatible_execution_providers),
sparse_label_input_names_{sparse_label_input_names} {
print_density_(print_input_density) {
}

/**
Expand All @@ -79,7 +79,7 @@ class InsertGatherBeforeSceLoss : public GraphTransformer {
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

private:
std::vector<std::string> sparse_label_input_names_;
bool print_density_ = false;
};

} // namespace onnxruntime
@@ -25,8 +25,7 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat
// Enable compute optimizer.
bool enable_compute_optimizer{false};

// Enable label sparsity compute optimization for the input names in the below list.
std::vector<std::string> sparse_label_input_names;
bool print_input_density{false};

// Path for serialization of the transformed optimized model. If empty, serialization is disabled.
std::string optimized_pre_grad_filepath;
@@ -195,11 +195,12 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
transformers.emplace_back(std::make_unique<UpStreamGatherGraphTransformer>(compatible_eps));
transformers.emplace_back(std::make_unique<UpStreamReshapeGraphTransformer>(compatible_eps));
transformers.emplace_back(std::make_unique<InsertGatherBeforeSceLoss>(compatible_eps,
config.sparse_label_input_names));
config.print_input_density));
#if defined(USE_CUDA) || defined(USE_ROCM)
// Put this under CUDA/ROCM guard as it depends on PadAndUnflatten CUDA/ROCM kernel.
// Once we have a CPU kernel for PadAndUnflatten, we can remove the guard.
transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps));
transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps,
config.print_input_density));
transformers.emplace_back(std::make_unique<Conv1dReplacement>(compatible_eps));
#endif
}
2 changes: 1 addition & 1 deletion orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -487,7 +487,7 @@ void addObjectMethodsForTraining(py::module& m) {
.def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers)
.def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer)
.def_readwrite("sparse_label_input_names", &TrainingGraphTransformerConfiguration::sparse_label_input_names)
.def_readwrite("print_input_density", &TrainingGraphTransformerConfiguration::print_input_density)
.def_readwrite("optimized_pre_grad_filepath", &TrainingGraphTransformerConfiguration::optimized_pre_grad_filepath)
.def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);
