Fuse Cast + SoftmaxCrossEntropyLossInternal (#20334)
### Description
Fuse a float16-to-float Cast followed by SoftmaxCrossEntropyLossInternal into a single SoftmaxCrossEntropyLossInternal node that consumes the float16 input directly and emits float outputs via its `output_type` attribute.
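A minimal before/after sketch of the pattern this fuses (the tensor name `X` is illustrative, not from the commit):

    Before:  X (float16) -> Cast(to=float) -> SoftmaxCrossEntropyLossInternal -> outputs (float)
    After:   X (float16) -> SoftmaxCrossEntropyLossInternal(output_type=float) -> outputs (float)

The Cast is removed only when it converts float16 to float, feeds nothing but the loss node, and does not produce a graph output; setting `output_type` to float keeps the fused node's outputs at their original type.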
guyang3532 authored Apr 29, 2024
1 parent 923b0ef commit 3e4db2c
Showing 10 changed files with 132 additions and 6 deletions.
17 changes: 13 additions & 4 deletions onnxruntime/core/graph/graph_utils.cc
@@ -172,10 +172,7 @@ static bool RemoveNodeWithSingleNodeInSingleUsedOutput(Graph& graph, Node& node)
   return true;
 }
 
-/** Move the input edges that src_node has to target_node.
-    After the move is complete src_node will have no input edges.
-*/
-static void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) {
+void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) {
   auto target_idx = target_node.Index();
   auto input_edges = GraphEdge::GetNodeInputEdges(src_node);
 
@@ -387,6 +384,18 @@ std::vector<GraphEdge> GraphEdge::GetNodeInputEdges(const Node& node) {
   return input_edges;
 }
 
+/** Returns a vector of the input GraphEdges of a node for the provided input index. */
+std::vector<GraphEdge> GraphEdge::GetNodeInputEdges(const Node& node, size_t index) {
+  std::vector<GraphEdge> input_edges;
+  for (auto it = node.InputEdgesBegin(), end = node.InputEdgesEnd(); it != end; ++it) {
+    if (static_cast<size_t>(it->GetDstArgIndex()) == index) {
+      input_edges.push_back(GraphEdge::CreateGraphEdge(node, *it, true));
+    }
+  }
+
+  return input_edges;
+}
+
 /** Returns a vector of the output GraphEdges of a node. */
 std::vector<GraphEdge> GraphEdge::GetNodeOutputEdges(const Node& node) {
   std::vector<GraphEdge> output_edges;
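The indexed overload pairs with GraphEdge::RemoveGraphEdges to detach a single input slot before rewiring, which is how the new fusion pass below uses it. A minimal usage sketch, assuming `graph` and `node` are in scope:

    // Collect and drop only the edges feeding input slot 0 of `node`;
    // edges into its other input slots are left untouched.
    std::vector<graph_utils::GraphEdge> edges = graph_utils::GraphEdge::GetNodeInputEdges(node, /*index=*/0);
    graph_utils::GraphEdge::RemoveGraphEdges(graph, edges);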
8 changes: 8 additions & 0 deletions onnxruntime/core/graph/graph_utils.h
@@ -59,6 +59,11 @@ const std::string& GetNodeOutputName(const Node& node, int index);
 */
 const Node::EdgeEnd* GetInputEdge(const Node& node, int arg_index);
 
+/** Move the input edges that src_node has to target_node.
+    After the move is complete src_node will have no input edges.
+*/
+void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node);
+
 /** Removes all output edges from the given Node of the Graph.
     This should probably be elevated to the Graph API eventually. */
 size_t RemoveNodeOutputEdges(Graph& graph, Node& node);
@@ -89,6 +94,9 @@ struct GraphEdge {
   /** Returns a vector of the input GraphEdges of a node. */
   static std::vector<GraphEdge> GetNodeInputEdges(const Node& node);
 
+  /** Returns a vector of the input GraphEdges of a node for the provided input index. */
+  static std::vector<GraphEdge> GetNodeInputEdges(const Node& node, size_t index);
+
   /** Returns a vector of the output GraphEdges of a node. */
   static std::vector<GraphEdge> GetNodeOutputEdges(const Node& node);
 
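Promoting MoveAllNodeInputEdges into the public header enables the usual fusion rewiring recipe, in which a surviving node adopts the inputs of the producer it absorbs. A minimal sketch of that recipe as the new pass performs it, assuming `node` is the consumer and `input_node` its soon-to-be-removed producer:

    // `node` takes over input_node's first input def, inherits input_node's
    // input edges, and the now edge-free producer is removed from the graph.
    node.MutableInputDefs()[0] = input_node->MutableInputDefs()[0];
    graph_utils::MoveAllNodeInputEdges(graph, *input_node, node);
    graph.RemoveNode(input_node->Index());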
2 changes: 2 additions & 0 deletions onnxruntime/core/optimizer/gemm_transpose_fusion.cc
@@ -87,6 +87,8 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
   new_gemm_node.AddAttribute("alpha", gemm_node.GetAttributes().at("alpha").f());
   new_gemm_node.AddAttribute("beta", gemm_node.GetAttributes().at("beta").f());
 
+  new_gemm_node.SetExecutionProviderType(gemm_node.GetExecutionProviderType());
+
   graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, new_gemm_node);
 
   modified = RewriteRuleEffect::kRemovedCurrentNode;
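A plausible motivation, not spelled out in the commit message: GemmTransposeFusion is registered below as a Level2 rewrite rule, and Level2 runs after nodes have been assigned to execution providers, so the replacement node must inherit that assignment explicitly or be left unassigned. The general pattern, with `fused_node` and `original_node` as hypothetical in-scope Node references:

    // Any fusion that runs after partitioning should copy the EP assignment
    // from the node it replaces onto the node it creates.
    fused_node.SetExecutionProviderType(original_node.GetExecutionProviderType());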
6 changes: 6 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -138,6 +138,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
       break;
 
     case TransformerLevel::Level2:
+      rules.push_back(std::make_unique<GemmTransposeFusion>());
       // No level2 rules available today
       break;
 
@@ -253,6 +254,11 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
     } break;
 
     case TransformerLevel::Level2: {
+      auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {});
+      if (rule_transformer != nullptr) {
+        transformers.emplace_back(std::move(rule_transformer));
+      }
+
       // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be
       // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2).
       transformers.emplace_back(std::make_unique<TransposeOptimizer>(std::move(cpu_allocator), kCpuExecutionProvider));
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/propagate_cast_ops.cc
@@ -171,7 +171,7 @@ static bool IsFP16Allow(const Node* node, size_t level, const FP16AllowOps& fp16
 
 using OpsSetType = InlinedHashSet<std::string_view>;
 static const OpsSetType level1_fp16_allow_set =
-    {"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu"};
+    {"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu", "Slice", "PadAndUnflatten"};
 static const OpsSetType level2_fp16_allow_set = {
     "Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "LayerNormalization", "Where"};
 
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/utils.cc
@@ -281,7 +281,7 @@ constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNorm
 // (plus ShrunkenGather for training) are considered deterministic.
 #ifdef ENABLE_TRAINING_OPS
 constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear",
-                                               "ConcatTraining"};
+                                               "ConcatTraining", "PadAndUnflatten"};
 #else
 constexpr std::array kMSDomainDeterministicOps{"QuantizeLinear", "DequantizeLinear"};
 #endif
59 changes: 59 additions & 0 deletions orttraining/orttraining/core/optimizer/cast_sce_loss_fusion.cc
@@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "orttraining/core/optimizer/cast_sce_loss_fusion.h"

#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"

namespace onnxruntime {

Status CastSceLossFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
                                    const logging::Logger& logger) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

  for (auto node_index : node_topology_list) {
    auto* node_ptr = graph.GetNode(node_index);
    if (!node_ptr) continue;  // Node was removed.

    auto& node = *node_ptr;
    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

    bool is_internal_sce = graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLossInternal", {1},
                                                                          kMSDomain);

    if (!is_internal_sce || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
      continue;
    }

    // The producer may be null if the logits come from a graph input or initializer.
    Node* input_node = graph.GetMutableProducerNode(node.MutableInputDefs()[0]->Name());

    if (input_node == nullptr || !graph_utils::IsSupportedOptypeVersionAndDomain(*input_node, "Cast", {9, 13, 19})) {
      continue;
    }

    // The Cast must feed only this node and must not itself produce a graph output.
    if (input_node->GetOutputEdgesCount() != 1 || graph.IsOutput(input_node->OutputDefs()[0])) {
      continue;
    }

    if (input_node->MutableInputDefs()[0]->TypeAsProto()->tensor_type().elem_type() == onnx::TensorProto_DataType_FLOAT16 &&
        input_node->MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type() == onnx::TensorProto_DataType_FLOAT) {
      // Bypass the Cast: feed the float16 tensor to the loss node directly.
      std::vector<graph_utils::GraphEdge> input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node, 0);
      graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edges);
      node.MutableInputDefs()[0] = input_node->MutableInputDefs()[0];
      graph_utils::MoveAllNodeInputEdges(graph, *input_node, node);
      graph.RemoveNode(input_node->Index());

      // Keep the outputs in float, matching what the removed Cast produced.
      if (node.GetAttributes().count("output_type") == 0) {
        node.AddAttribute("output_type", static_cast<int64_t>(onnx::TensorProto_DataType_FLOAT));
      }
      modified = true;
    }
  }

  return Status::OK();
}

}  // namespace onnxruntime
23 changes: 23 additions & 0 deletions orttraining/orttraining/core/optimizer/cast_sce_loss_fusion.h
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/graph_transformer.h"

namespace onnxruntime {

/**
@Class CastSceLossFusion
Fuse Cast + SoftmaxCrossEntropyLossInternal into SoftmaxCrossEntropyLossInternal.
*/
class CastSceLossFusion : public GraphTransformer {
 public:
  explicit CastSceLossFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
      : GraphTransformer("CastSceLossFusion", compatible_execution_providers) {
  }

  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};

}  // namespace onnxruntime
2 changes: 2 additions & 0 deletions orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -52,6 +52,7 @@
#include "orttraining/core/framework/distributed_run_context.h"
#include "orttraining/core/optimizer/batchnorm_replacement.h"
#include "orttraining/core/optimizer/bitmask_dropout_replacement.h"
#include "orttraining/core/optimizer/cast_sce_loss_fusion.h"
#include "orttraining/core/optimizer/concat_replacement.h"
#include "orttraining/core/optimizer/graph_transformer_registry.h"
#include "orttraining/core/optimizer/gru_replacement.h"
@@ -188,6 +189,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
                                                          config.propagate_cast_ops_config.allow,
                                                          cuda_execution_provider));
     }
+    transformers.emplace_back(std::make_unique<CastSceLossFusion>(compatible_eps));
 
     if (config.enable_compute_optimizer) {
       transformers.emplace_back(std::make_unique<UpStreamGatherGraphTransformer>(compatible_eps));
17 changes: 17 additions & 0 deletions orttraining/orttraining/test/optimizer/graph_transform_test.cc
@@ -25,6 +25,7 @@
 #include "test/util/include/asserts.h"
 #include "orttraining/test/optimizer/horizontal_parallel_test_utils.h"
 #include "orttraining/core/session/training_session.h"
+#include "orttraining/core/optimizer/cast_sce_loss_fusion.h"
 #include "orttraining/core/optimizer/loss_rewriter.h"
 #include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h"
 #include "orttraining/core/optimizer/qdq_fusion.h"
@@ -518,6 +519,22 @@ TEST_F(GraphTransformationTests, SceLossGradBiasFusion_Invalid) {
   }
 }
 
+TEST_F(GraphTransformationTests, CastSceLossFusion) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "computation_reduction/reshape/mlm_bert_e2e.onnx";
+  std::shared_ptr<Model> model;
+  ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK());
+  Graph& graph = model->MainGraph();
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["Cast"], 10);
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<CastSceLossFusion>(), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
+
+  op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["Cast"], 9);
+}
+
 Node* GetNodeByName(Graph& graph, std::string node_name) {
   GraphViewer graph_viewer(graph);
   const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();