Some Shape Related Fusions #19832

Merged: 5 commits, Mar 12, 2024
55 changes: 55 additions & 0 deletions onnxruntime/core/optimizer/common_subexpression_elimination.cc
@@ -4,6 +4,7 @@
#include "common_subexpression_elimination.h"
#include "core/optimizer/utils.h"
#include "core/graph/graph_utils.h"
#include "core/framework/tensorprotoutils.h"

#include <memory>
#include <type_traits>
@@ -170,6 +171,32 @@ bool AreRangesEqual(const Range& lhs, const Range& rhs) {
std::equal(lhs.begin(), lhs.end(), rhs.begin());
}

// Check if two tensor attributes are equal scalar tensors, mainly to support the ConstantOfShape op.
// Currently supports float, float16 and int64 data types, and requires the data to be stored as raw data in the TensorProto.
bool AreScalarTensorAttributeEqual(const ONNX_NAMESPACE::TensorProto& lhs_t, const ONNX_NAMESPACE::TensorProto& rhs_t) {
if (!(utils::HasDataType(lhs_t) && utils::HasDataType(rhs_t) && lhs_t.data_type() == rhs_t.data_type() &&
(lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT ||
lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT16 ||
lhs_t.data_type() == onnx::TensorProto_DataType_INT64) &&
lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1 &&
utils::HasRawData(lhs_t) && utils::HasRawData(rhs_t))) {
return false;
}
const void* lhs_value = lhs_t.raw_data().data();
const void* rhs_value = rhs_t.raw_data().data();
switch (lhs_t.data_type()) {
case onnx::TensorProto_DataType_FLOAT:
return *reinterpret_cast<const float*>(lhs_value) == *reinterpret_cast<const float*>(rhs_value);
case onnx::TensorProto_DataType_FLOAT16:
return *reinterpret_cast<const MLFloat16*>(lhs_value) == *reinterpret_cast<const MLFloat16*>(rhs_value);
case onnx::TensorProto_DataType_INT64:
return *reinterpret_cast<const int64_t*>(lhs_value) == *reinterpret_cast<const int64_t*>(rhs_value);
default:
break;
}
return false;
}

bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::AttributeProto& rhs) {
if (&lhs == &rhs) {
return true;
@@ -193,6 +220,7 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
case onnx::AttributeProto_AttributeType_STRINGS:
return AreRangesEqual(lhs.strings(), rhs.strings());
case onnx::AttributeProto_AttributeType_TENSOR:
return AreScalarTensorAttributeEqual(lhs.t(), rhs.t());
case onnx::AttributeProto_AttributeType_GRAPH:
case onnx::AttributeProto_AttributeType_SPARSE_TENSOR:
case onnx::AttributeProto_AttributeType_TYPE_PROTO:
@@ -207,6 +235,31 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
return false;
}

// Supports scalar float/float16/int64 tensor attributes only for now, and requires the data to be stored as raw data in the TensorProto.
std::size_t GetTensorAttributeHash(const ONNX_NAMESPACE::TensorProto& attr_t) {
std::size_t hash = 0;
if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1 && utils::HasRawData(attr_t)) {
int data_type = attr_t.data_type();
switch (data_type) {
case onnx::TensorProto_DataType_FLOAT:
UpdateHash(data_type, hash);
UpdateHash(*reinterpret_cast<const float*>(attr_t.raw_data().data()), hash);
break;
case onnx::TensorProto_DataType_FLOAT16:
UpdateHash(data_type, hash);
UpdateHash(static_cast<float>(*reinterpret_cast<const MLFloat16*>(attr_t.raw_data().data())), hash);
break;
case onnx::TensorProto_DataType_INT64:
UpdateHash(data_type, hash);
UpdateHash(*reinterpret_cast<const int64_t*>(attr_t.raw_data().data()), hash);
break;
default:
break;
}
}
return hash;
}

std::size_t GetAttributeHash(const ONNX_NAMESPACE::AttributeProto& attr) {
std::size_t hash = 0;
UpdateHash(
@@ -233,6 +286,8 @@ std::size_t GetAttributeHash(const ONNX_NAMESPACE::AttributeProto& attr) {
UpdateHashWithContainer(attr.strings(), hash);
break;
case onnx::AttributeProto_AttributeType_TENSOR:
UpdateHash(attr.t(), &GetTensorAttributeHash, hash);
break;
case onnx::AttributeProto_AttributeType_GRAPH:
case onnx::AttributeProto_AttributeType_SPARSE_TENSOR:
case onnx::AttributeProto_AttributeType_TYPE_PROTO:
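The new helpers above only recognize single-element tensors whose payload is stored in raw_data, which matches how ConstantOfShape "value" attributes are typically built. Below is a minimal standalone sketch (not part of this PR) that constructs two such attributes with the ONNX protobuf API and restates the equality condition used by AreScalarTensorAttributeEqual; the include path and the main() harness are assumptions for illustration only.

#include <iostream>
#include "onnx/onnx_pb.h"  // assumed include for the generated ONNX protobuf types

// Build a 1-element float TensorProto with its payload in raw_data,
// the only layout the new CSE helpers accept.
static onnx::TensorProto MakeScalarFloatTensor(float value) {
  onnx::TensorProto t;
  t.set_data_type(onnx::TensorProto_DataType_FLOAT);
  t.add_dims(1);
  t.set_raw_data(&value, sizeof(float));
  return t;
}

int main() {
  onnx::TensorProto lhs = MakeScalarFloatTensor(2.333f);
  onnx::TensorProto rhs = MakeScalarFloatTensor(2.333f);
  // Restatement of the helper's check: same data type, a single element,
  // and raw data present whose reinterpreted value compares equal.
  bool equal = lhs.data_type() == rhs.data_type() &&
               lhs.dims_size() == 1 && rhs.dims_size() == 1 &&
               lhs.dims(0) == 1 && rhs.dims(0) == 1 &&
               !lhs.raw_data().empty() && !rhs.raw_data().empty() &&
               *reinterpret_cast<const float*>(lhs.raw_data().data()) ==
                   *reinterpret_cast<const float*>(rhs.raw_data().data());
  std::cout << (equal ? "attributes match" : "attributes differ") << std::endl;
  return 0;
}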
9 changes: 5 additions & 4 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -69,6 +69,7 @@
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rocm_blas_alt_impl.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/shape_input_merge.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/transpose_optimizer.h"
@@ -211,17 +212,17 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<DoubleQDQPairsRemover>());
}

// Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
// CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
// default, CSE will not merge them, because the different initializers are represented by different NodeArg.
// Put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination by intention as they can create
// more opportunities for CSE. For example, if nodes A and B consume different NodeArgs that carry the same shape
// value, or consume different initializers with the same value, by default CSE will not merge them.
InlinedHashSet<std::string> excluded_initializers;
excluded_initializers.reserve(session_options.initializers_to_share_map.size());
for (const auto& p : session_options.initializers_to_share_map) {
excluded_initializers.insert(p.first);
}
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
transformers.emplace_back(std::make_unique<ConstantSharing>(no_limit_empty_ep_list, excluded_initializers));

transformers.emplace_back(std::make_unique<ShapeInputMerge>());
transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq,
session_options.config_options));
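For intuition on the ordering: CSE treats two nodes as identical only when their op type, attributes, and input NodeArgs all match, so two Shape nodes whose inputs merely share the same symbolic shape are not merged until ShapeInputMerge rewires them to a single NodeArg. A toy sketch of that keying idea follows; the struct below is an illustration, not the real CSE implementation.

#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for the equivalence key CSE uses: op type plus the names of the
// input NodeArgs. Equal shapes alone do not make the keys equal.
struct NodeKey {
  std::string op_type;
  std::vector<std::string> input_names;
  bool operator==(const NodeKey& other) const {
    return op_type == other.op_type && input_names == other.input_names;
  }
};

int main() {
  NodeKey shape_a{"Shape", {"neg_out"}};
  NodeKey shape_b{"Shape", {"cast_out"}};  // same symbolic shape, different NodeArg
  std::cout << "mergeable before ShapeInputMerge: " << (shape_a == shape_b) << std::endl;  // 0
  shape_b.input_names = shape_a.input_names;  // what ShapeInputMerge effectively does
  std::cout << "mergeable after ShapeInputMerge: " << (shape_a == shape_b) << std::endl;  // 1
  return 0;
}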
78 changes: 78 additions & 0 deletions onnxruntime/core/optimizer/shape_input_merge.cc
@@ -0,0 +1,78 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/optimizer/shape_input_merge.h"

#include "core/graph/graph_utils.h"

namespace onnxruntime {

namespace {
std::string GetShapeString(const NodeArg* input_arg) {
auto shape = input_arg->Shape();
if (!shape) return "";
std::stringstream ss;
ss << "[";
for (int i = 0; i < shape->dim_size(); ++i) {
if (i != 0) ss << ",";
auto dim = shape->dim(i);
if (dim.has_dim_value()) {
ss << std::to_string(dim.dim_value());
} else if (dim.has_dim_param()) {
ss << "'" << dim.dim_param() << "'";
} else {
return "";
}
}
ss << "]";
return ss.str();
}

} // namespace

Status ShapeInputMerge::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
InlinedHashMap<std::string, InlinedVector<Node*>> input_hash_to_nodes;
for (auto node_index : node_topology_list) {
auto* p_node = graph.GetNode(node_index);
if (!p_node) continue; // we removed the node as part of an earlier fusion
ORT_RETURN_IF_ERROR(Recurse(*p_node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*p_node, "Shape", {1, 13, 15, 19, 21}) ||
!graph_utils::IsSupportedProvider(*p_node, GetCompatibleExecutionProviders())) {
continue;
}
std::string shape_str = GetShapeString(p_node->InputDefs()[0]);

[cpplint] Check warning on line 45 in onnxruntime/core/optimizer/shape_input_merge.cc: Add #include <string> for string [build/include_what_you_use] [4]
if (shape_str.empty()) continue;
if (input_hash_to_nodes.find(shape_str) == input_hash_to_nodes.end()) {
input_hash_to_nodes[shape_str] = InlinedVector<Node*>();
}
input_hash_to_nodes[shape_str].emplace_back(p_node);
}

// All Shape nodes are processed in topological order, so we can safely merge the inputs to the first node's input.
for (auto& kv : input_hash_to_nodes) {
if (kv.second.size() < 2) continue;
NodeArg* first_input_arg = kv.second[0]->MutableInputDefs()[0];
bool is_first_input_arg_graph_input = graph.IsInputsIncludingInitializers(first_input_arg);
for (size_t i = 1; i < kv.second.size(); ++i) {
Node* p_node = kv.second[i];
const NodeArg* input_arg = p_node->InputDefs()[0];
if (p_node->InputDefs()[0]->Name() == first_input_arg->Name()) continue;
if (!graph.IsInputsIncludingInitializers(input_arg)) {
const Node::EdgeEnd& input_edge = *p_node->InputEdgesBegin();
graph.RemoveEdge(input_edge.GetNode().Index(), p_node->Index(), input_edge.GetSrcArgIndex(), 0);
}
graph_utils::ReplaceNodeInput(*p_node, 0, *first_input_arg);
if (!is_first_input_arg_graph_input) {
const Node::EdgeEnd& first_input_edge = *kv.second[0]->InputEdgesBegin();
graph.AddEdge(first_input_edge.GetNode().Index(), p_node->Index(), first_input_edge.GetSrcArgIndex(), 0);
}
modified = true;
}
}

return Status::OK();
}

} // namespace onnxruntime
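GetShapeString above turns a fully known symbolic shape into a canonical key such as ['dim0',512,16,'dim3'], and returns an empty string as soon as any dimension has neither a value nor a symbolic name. A small self-contained sketch of that key format follows; std::variant stands in for TensorShapeProto dimensions here and is not the actual ORT type.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <variant>
#include <vector>

// One dimension is either a concrete value or a symbolic dim_param name.
using Dim = std::variant<int64_t, std::string>;

// Mirrors GetShapeString: concrete dims are printed as numbers, symbolic dims are
// quoted, and the whole key is wrapped in brackets.
std::string MakeShapeKey(const std::vector<Dim>& dims) {
  std::stringstream ss;
  ss << "[";
  for (size_t i = 0; i < dims.size(); ++i) {
    if (i != 0) ss << ",";
    if (std::holds_alternative<int64_t>(dims[i])) {
      ss << std::get<int64_t>(dims[i]);
    } else {
      ss << "'" << std::get<std::string>(dims[i]) << "'";
    }
  }
  ss << "]";
  return ss.str();
}

int main() {
  std::vector<Dim> dims{std::string("dim0"), int64_t{512}, int64_t{16}, std::string("dim3")};
  std::cout << MakeShapeKey(dims) << std::endl;  // prints ['dim0',512,16,'dim3']
  return 0;
}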
23 changes: 23 additions & 0 deletions onnxruntime/core/optimizer/shape_input_merge.h
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/graph_transformer.h"

namespace onnxruntime {

/**
@Class ShapeInputMerge
Merge all shape inputs having the same shape value to a single shape input.
This change will not affect performance, but it opens opportunities for CSE to merge nodes.
*/
class ShapeInputMerge : public GraphTransformer {
public:
ShapeInputMerge(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept

[cpplint] Check warning on line 17 in onnxruntime/core/optimizer/shape_input_merge.h: Constructors callable with one argument should be marked explicit. [runtime/explicit] [5]
: GraphTransformer("ShapeInputMerge", compatible_execution_providers) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};

} // namespace onnxruntime
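In a normal session the transformer is registered for Level1 by GenerateTransformers (see graph_transformer_utils.cc above). Below is a hedged sketch of driving it on its own through GraphTransformerManager, assuming the usual onnxruntime headers and an already-loaded Graph plus Logger; treat the exact signatures as illustrative rather than authoritative.

#include <memory>

#include "core/optimizer/graph_transformer_mgr.h"
#include "core/optimizer/shape_input_merge.h"

namespace onnxruntime {

// Apply only ShapeInputMerge to a graph, allowing up to 5 optimization passes.
Status RunShapeInputMerge(Graph& graph, const logging::Logger& logger) {
  GraphTransformerManager transformer_mgr{5};
  ORT_RETURN_IF_ERROR(transformer_mgr.Register(std::make_unique<ShapeInputMerge>(),
                                               TransformerLevel::Level1));
  return transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger);
}

}  // namespace onnxruntime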
5 changes: 3 additions & 2 deletions onnxruntime/core/optimizer/utils.cc
@@ -272,15 +272,16 @@ int32_t IndexOfNodeOutput(const Node& node, const NodeArg& node_arg) {
// We could also allow other known domains (kMSDomain, kMSNchwcDomain, kMSFeaturizersDomain),
// as long as we verify which of their operations are non-deterministic and add them in the map below.
constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNormal", "RandomUniformLike",
"RandomNormalLike", "Multinomial"};
"RandomNormalLike", "Multinomial", "Dropout"};

// List of deterministic MS domain operators. Currently used for constant folding and common subexpression elimination.
//
// TODO(adrianlizarraga): Investigate converting to lists of *non-deterministic* MS domain operators to be consistent
// with the above ONNX list. With the current approach, only MS domain Q/DQ operators
// (plus ShrunkenGather for training) are considered deterministic.
#ifdef ENABLE_TRAINING_OPS
constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear"};
constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear",
"ConcatTraining"};
#else
constexpr std::array kMSDomainDeterministicOps{"QuantizeLinear", "DequantizeLinear"};
#endif
122 changes: 122 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -65,6 +65,7 @@
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/shape_input_merge.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/unsqueeze_elimination.h"
#include "core/optimizer/utils.h"
@@ -4879,6 +4880,53 @@ TEST_F(GraphTransformationTests, FastGeluFusionWithCastsTest3) {
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
}

TEST_F(GraphTransformationTests, CseWithConstantOfShape) {
auto build_test_case = [&](ModelTestBuilder& builder) {
std::vector<std::variant<int64_t, std::string>> input_shape;
input_shape.reserve(4);
input_shape.emplace_back("dim0");
input_shape.emplace_back(512);
input_shape.emplace_back(16);
input_shape.emplace_back("dim3");
auto* input_arg = builder.MakeSymbolicInput<float>(input_shape);
auto* shape_out_1 = builder.MakeIntermediate();
auto* shape_out_2 = builder.MakeIntermediate();
auto* constant_of_shape_out_1 = builder.MakeIntermediate();
auto* constant_of_shape_out_2 = builder.MakeIntermediate();
auto* mul_out_1 = builder.MakeIntermediate();
auto* mul_out_2 = builder.MakeOutput();
builder.AddNode("Shape", {input_arg}, {shape_out_1});
builder.AddNode("Shape", {input_arg}, {shape_out_2});
TensorProto value_tensor;
value_tensor.add_dims(1);
float value = 2.333f;
value_tensor.set_raw_data(reinterpret_cast<const char*>(&value), sizeof(float));
value_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
builder.AddNode("ConstantOfShape", {shape_out_1}, {constant_of_shape_out_1}).AddAttribute("value", value_tensor);
builder.AddNode("ConstantOfShape", {shape_out_2}, {constant_of_shape_out_2}).AddAttribute("value", value_tensor);
builder.AddNode("Mul", {input_arg, constant_of_shape_out_1}, {mul_out_1});
builder.AddNode("Mul", {mul_out_1, constant_of_shape_out_2}, {mul_out_2});
};

auto pre_graph_checker = [&](Graph& graph) {
auto op_count_map = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_count_map["Shape"] == 2);
TEST_RETURN_IF_NOT(op_count_map["ConstantOfShape"] == 2);
return Status::OK();
};

auto post_graph_checker = [&](Graph& graph) {
auto op_count_map = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_count_map["Shape"] == 1);
TEST_RETURN_IF_NOT(op_count_map["ConstantOfShape"] == 1);
return Status::OK();
};

std::unique_ptr<GraphTransformer> transformer = std::make_unique<CommonSubexpressionElimination>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
}

TEST_F(GraphTransformationTests, QuickGelu) {
// Sigmoid(x*alpha)*x, float
{
@@ -7543,5 +7591,79 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) {
}
}

TEST_F(GraphTransformationTests, ShapeInputMerge) {
auto build_test_case = [&](ModelTestBuilder& builder) {
std::vector<std::variant<int64_t, std::string>> input_shape;
input_shape.reserve(5);
input_shape.emplace_back("dim0");
input_shape.emplace_back(512);
input_shape.emplace_back(1);
input_shape.emplace_back(1536);
input_shape.emplace_back("dim4");
auto* input_arg = builder.MakeSymbolicInput<float>(input_shape);
auto* neg_out = builder.MakeIntermediate();
auto* axes_initializer = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(2)});
auto* squeeze_out = builder.MakeIntermediate();
auto* cast_out = builder.MakeIntermediate();
auto* unsqueeze_out = builder.MakeOutput();
auto* shape_1_out = builder.MakeOutput();
auto* shape_2_out = builder.MakeOutput();
auto* shape_3_out = builder.MakeOutput();
auto* shape_4_out = builder.MakeOutput();
auto* shape_5_out = builder.MakeOutput();
builder.AddNode("Neg", {input_arg}, {neg_out});
builder.AddNode("Squeeze", {neg_out, axes_initializer}, {squeeze_out});
builder.AddNode("Cast", {squeeze_out}, {cast_out}).AddAttribute("to", static_cast<int64_t>(10));
builder.AddNode("Unsqueeze", {cast_out, axes_initializer}, {unsqueeze_out});
builder.AddNode("Shape", {input_arg}, {shape_1_out});
builder.AddNode("Shape", {neg_out}, {shape_2_out});
builder.AddNode("Shape", {squeeze_out}, {shape_3_out});
builder.AddNode("Shape", {cast_out}, {shape_4_out});
builder.AddNode("Shape", {unsqueeze_out}, {shape_5_out});
};

auto pre_graph_checker = [&](Graph& graph) {
InlinedHashMap<std::string, int> ref_count;
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Shape") {
std::string name = node.InputDefs()[0]->Name();
if (ref_count.find(name) == ref_count.end()) {
ref_count[name] = 1;
} else {
ref_count[name]++;
}
}
}
TEST_RETURN_IF_NOT(ref_count.size() == 5);
return Status::OK();
};

auto post_graph_checker = [&](Graph& graph) {
InlinedHashMap<std::string, int> ref_count;
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Shape") {
std::string name = node.InputDefs()[0]->Name();
if (ref_count.find(name) == ref_count.end()) {
ref_count[name] = 1;
} else {
ref_count[name]++;
}
}
}
TEST_RETURN_IF_NOT(ref_count.size() == 2);
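// After the merge the five Shape nodes read from only two distinct NodeArgs: the input/Neg/Unsqueeze
// tensors all have shape ('dim0', 512, 1, 1536, 'dim4') and the Squeeze/Cast tensors have
// ('dim0', 512, 1536, 'dim4'), giving reference counts of 3 and 2 (sum 5, product 6).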
int sum = 0, mul = 1;
for (auto& entry : ref_count) {
sum += entry.second;
mul *= entry.second;
}
TEST_RETURN_IF_NOT(sum == 5 && mul == 6);
return Status::OK();
};

std::unique_ptr<GraphTransformer> transformer = std::make_unique<ShapeInputMerge>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
}

} // namespace test
} // namespace onnxruntime