diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 20e8161ee79fd..904a2a45ac8f9 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -172,10 +172,7 @@ static bool RemoveNodeWithSingleNodeInSingleUsedOutput(Graph& graph, Node& node) return true; } -/** Move the input edges that src_node has to target_node. -After the move is complete src_node will have no input edges. -*/ -static void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) { +void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) { auto target_idx = target_node.Index(); auto input_edges = GraphEdge::GetNodeInputEdges(src_node); @@ -387,6 +384,18 @@ std::vector GraphEdge::GetNodeInputEdges(const Node& node) { return input_edges; } +/** Returns a vector of the input GraphEdges of a node for the provided input index. */ +std::vector GraphEdge::GetNodeInputEdges(const Node& node, size_t index) { + std::vector input_edges; + for (auto it = node.InputEdgesBegin(), end = node.InputEdgesEnd(); it != end; ++it) { + if (static_cast(it->GetDstArgIndex()) == index) { + input_edges.push_back(GraphEdge::CreateGraphEdge(node, *it, true)); + } + } + + return input_edges; +} + /** Returns a vector of the output GraphEdges of a node. */ std::vector GraphEdge::GetNodeOutputEdges(const Node& node) { std::vector output_edges; diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index cf76ec785fda5..87025c233b71c 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -59,6 +59,11 @@ const std::string& GetNodeOutputName(const Node& node, int index); */ const Node::EdgeEnd* GetInputEdge(const Node& node, int arg_index); +/** Move the input edges that src_node has to target_node. +After the move is complete src_node will have no input edges. +*/ +void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node); + /** Removes all output edges from the given Node of the Graph. This should probably be elevated to the Graph API eventually. */ size_t RemoveNodeOutputEdges(Graph& graph, Node& node); @@ -89,6 +94,9 @@ struct GraphEdge { /** Returns a vector of the input GraphEdges of a node. */ static std::vector GetNodeInputEdges(const Node& node); + /** Returns a vector of the input GraphEdges of a node for the provided input index. */ + static std::vector GetNodeInputEdges(const Node& node, size_t index); + /** Returns a vector of the output GraphEdges of a node. */ static std::vector GetNodeOutputEdges(const Node& node); diff --git a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc index a52517d23db86..d1f862460dae7 100644 --- a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc @@ -87,6 +87,8 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m new_gemm_node.AddAttribute("alpha", gemm_node.GetAttributes().at("alpha").f()); new_gemm_node.AddAttribute("beta", gemm_node.GetAttributes().at("beta").f()); + new_gemm_node.SetExecutionProviderType(gemm_node.GetExecutionProviderType()); + graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, new_gemm_node); modified = RewriteRuleEffect::kRemovedCurrentNode; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 51fbbf0b79980..9341ade0a2f1d 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -138,6 +138,7 @@ InlinedVector> GenerateRewriteRules( break; case TransformerLevel::Level2: + rules.push_back(std::make_unique()); // No level2 rules available today break; @@ -253,6 +254,11 @@ InlinedVector> GenerateTransformers( } break; case TransformerLevel::Level2: { + auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}); + if (rule_transformer != nullptr) { + transformers.emplace_back(std::move(rule_transformer)); + } + // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2). transformers.emplace_back(std::make_unique(std::move(cpu_allocator), kCpuExecutionProvider)); diff --git a/onnxruntime/core/optimizer/propagate_cast_ops.cc b/onnxruntime/core/optimizer/propagate_cast_ops.cc index e4f34e066851f..b129410c82914 100644 --- a/onnxruntime/core/optimizer/propagate_cast_ops.cc +++ b/onnxruntime/core/optimizer/propagate_cast_ops.cc @@ -171,7 +171,7 @@ static bool IsFP16Allow(const Node* node, size_t level, const FP16AllowOps& fp16 using OpsSetType = InlinedHashSet; static const OpsSetType level1_fp16_allow_set = - {"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu"}; + {"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu", "Slice", "PadAndUnflatten"}; static const OpsSetType level2_fp16_allow_set = { "Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "LayerNormalization", "Where"}; diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index 7055882961e17..c7e11de34858a 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -281,7 +281,7 @@ constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNorm // (plus ShrunkenGather for training) are considered deterministic. #ifdef ENABLE_TRAINING_OPS constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear", - "ConcatTraining"}; + "ConcatTraining", "PadAndUnflatten"}; #else constexpr std::array kMSDomainDeterministicOps{"QuantizeLinear", "DequantizeLinear"}; #endif diff --git a/orttraining/orttraining/core/optimizer/cast_sce_loss_fusion.cc b/orttraining/orttraining/core/optimizer/cast_sce_loss_fusion.cc new file mode 100644 index 0000000000000..cb692d347d706 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/cast_sce_loss_fusion.cc @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/core/optimizer/cast_sce_loss_fusion.h" + +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +Status CastSceLossFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* node_ptr = graph.GetNode(node_index); + if (!node_ptr) continue; // Node was removed. + + auto& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + bool is_internal_sce = graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLossInternal", {1}, + kMSDomain); + + if (!is_internal_sce || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + continue; + } + + Node* input_node = graph.GetMutableProducerNode(node.MutableInputDefs()[0]->Name()); + + if (!(graph_utils::IsSupportedOptypeVersionAndDomain(*input_node, "Cast", {9, 13, 19}))) { + continue; + } + + if (input_node->GetOutputEdgesCount() != 1 || graph.IsOutput(input_node->OutputDefs()[0])) { + continue; + } + + if (input_node->MutableInputDefs()[0]->TypeAsProto()->tensor_type().elem_type() == onnx::TensorProto_DataType_FLOAT16 && + input_node->MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type() == onnx::TensorProto_DataType_FLOAT) { + std::vector input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node, 0); + graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edges); + node.MutableInputDefs()[0] = input_node->MutableInputDefs()[0]; + graph_utils::MoveAllNodeInputEdges(graph, *input_node, node); + graph.RemoveNode(input_node->Index()); + + if (node.GetAttributes().count("output_type") == 0) { + node.AddAttribute("output_type", static_cast(onnx::TensorProto_DataType_FLOAT)); + } + modified = true; + } + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/cast_sce_loss_fusion.h b/orttraining/orttraining/core/optimizer/cast_sce_loss_fusion.h new file mode 100644 index 0000000000000..1447bbea9774c --- /dev/null +++ b/orttraining/orttraining/core/optimizer/cast_sce_loss_fusion.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@Class CastSceLossFusion +Fuse Cast + SoftmaxCrossEntropyLossInternal to SoftmaxCrossEntropyLossInternal. +*/ +class CastSceLossFusion : public GraphTransformer { + public: + explicit CastSceLossFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("CastSceLossFusion", compatible_execution_providers) { + } + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 9e9261fef6ca1..5a3db9454cc32 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -52,6 +52,7 @@ #include "orttraining/core/framework/distributed_run_context.h" #include "orttraining/core/optimizer/batchnorm_replacement.h" #include "orttraining/core/optimizer/bitmask_dropout_replacement.h" +#include "orttraining/core/optimizer/cast_sce_loss_fusion.h" #include "orttraining/core/optimizer/concat_replacement.h" #include "orttraining/core/optimizer/graph_transformer_registry.h" #include "orttraining/core/optimizer/gru_replacement.h" @@ -188,6 +189,7 @@ std::vector> GeneratePreTrainingTransformers( config.propagate_cast_ops_config.allow, cuda_execution_provider)); } + transformers.emplace_back(std::make_unique(compatible_eps)); if (config.enable_compute_optimizer) { transformers.emplace_back(std::make_unique(compatible_eps)); diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 109937ff96d1d..4ab035a171430 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -25,6 +25,7 @@ #include "test/util/include/asserts.h" #include "orttraining/test/optimizer/horizontal_parallel_test_utils.h" #include "orttraining/core/session/training_session.h" +#include "orttraining/core/optimizer/cast_sce_loss_fusion.h" #include "orttraining/core/optimizer/loss_rewriter.h" #include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h" #include "orttraining/core/optimizer/qdq_fusion.h" @@ -518,6 +519,22 @@ TEST_F(GraphTransformationTests, SceLossGradBiasFusion_Invalid) { } } +TEST_F(GraphTransformationTests, CastSceLossFusion) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "computation_reduction/reshape/mlm_bert_e2e.onnx"; + std::shared_ptr model; + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK()); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Cast"], 10); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Cast"], 9); +} + Node* GetNodeByName(Graph& graph, std::string node_name) { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();