
Commit

Merge branch 'main' of https://www.github.com/microsoft/onnxruntime into zhanyi/exception
Yi Zhang committed Jun 27, 2024
2 parents eef0542 + 3c0b407 commit 2b62920
Showing 11 changed files with 54 additions and 223 deletions.
10 changes: 5 additions & 5 deletions js/web/docs/webnn-operators.md
@@ -61,16 +61,16 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtime
| Pow | ai.onnx(7-11, 12, 13-14, 15+) | pow | ✓ | ✓ | |
| PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | ✓ | ✓ | WebNN CPU backend restricts the last dimension of input and slope to be same (Chromium issue: https://issues.chromium.org/issues/335517470) |
| Reciprocal | ai.onnx(7-12, 13+) | reciprocal | ✓ | ✓ | |
- | ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | | ✓ | Input 'axes' if present should be a constant |
- | ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | | ✓ | Input 'axes' if present should be a constant |
- | ReduceLogSum| ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSum| | ✓ | Input 'axes' if present should be a constant |
- | ReduceLogSumExp | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSumExp | | ✓ | Input 'axes' if present should be a constant |
+ | ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | ✓ | ✓ | Input 'axes' if present should be a constant |
+ | ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | ✓ | ✓ | Input 'axes' if present should be a constant |
+ | ReduceLogSum| ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSum| ✓ | ✓ | Input 'axes' if present should be a constant |
+ | ReduceLogSumExp | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSumExp | ✓ | ✓ | Input 'axes' if present should be a constant |
| ReduceMax | ai.onnx(7-10, 11, 12, 13-17, 18-19, 20+) | reduceMax | ✓ | ✓ | Input 'axes' if present should be a constant |
| ReduceMean | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceMean | ✓ | ✓ | Input 'axes' if present should be a constant |
| ReduceMin | ai.onnx(7-10, 11, 12, 13-17, 18-19, 20+) | reduceMin | ✓ | ✓ | Input 'axes' if present should be a constant |
| ReduceProd | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceProduct | ✓ | ✓ | Input 'axes' if present should be a constant |
| ReduceSum | ai.onnx(7-10, 11-12, 13+) | reduceSum | ✓ | ✓ | Input 'axes' if present should be a constant |
- | ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare | | ✓ | Input 'axes' if present should be a constant |
+ | ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare | ✓ | ✓ | Input 'axes' if present should be a constant |
| Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | |
| Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported |
| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, exclude_outside != 0, input 'scales' and 'sizes' if present must be a constant, 'linear' and 'nearest' modes |
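The "Input 'axes' if present should be a constant" notes exist because the WebNN EP must know the reduction axes while it is still building the WebNN graph. A minimal sketch of that kind of check, assuming the EP resolves 'axes' against the graph's initializer list (the names and the map type here are illustrative, not the EP's actual helpers):

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for the ONNX graph's initializer list: name -> constant int64 data.
using InitializerMap = std::unordered_map<std::string, std::vector<int64_t>>;

// 'axes' is usable when it is absent or backed by a constant initializer,
// because the axis values have to be read at graph-build time.
bool AxesInputIsConstant(const std::string& axes_input_name,
                         const InitializerMap& initializers) {
  if (axes_input_name.empty()) return true;  // optional 'axes' not provided
  return initializers.count(axes_input_name) > 0;
}
```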
8 changes: 3 additions & 5 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -71,7 +71,6 @@
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rocm_blas_alt_impl.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/shape_input_merge.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/transpose_optimizer.h"
@@ -215,17 +214,16 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<DoubleQDQPairsRemover>());
}

- // Put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination by intention as it can create
- // more opportunities for CSE. For example, if A and B nodes consume same different args but produce same output
- // or consume different initializers with same value, by default, CSE will not merge them.
+ // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
+ // CSE. For example, if A and B nodes consume different initializers with same value, by default,
+ // CSE will not merge them.
InlinedHashSet<std::string> excluded_initializers;
excluded_initializers.reserve(session_options.initializers_to_share_map.size());
for (const auto& p : session_options.initializers_to_share_map) {
excluded_initializers.insert(p.first);
}
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
transformers.emplace_back(std::make_unique<ConstantSharing>(no_limit_empty_ep_list, excluded_initializers));
- transformers.emplace_back(std::make_unique<ShapeInputMerge>());
transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq,
session_options.config_options));
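The ordering comment above can be illustrated with a toy, self-contained sketch of the two passes (a model of the idea, not ORT's implementation): until equal-valued initializers are renamed to one canonical name by ConstantSharing, the two Add nodes below have different input names, so CSE sees no common subexpression to merge.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// A toy graph node: op type plus input value names.
struct Node {
  std::string op;
  std::vector<std::string> inputs;
};

int main() {
  // Two initializers with different names but identical values.
  const std::map<std::string, std::vector<float>> inits = {
      {"w0", {1.f, 2.f}}, {"w1", {1.f, 2.f}}};
  std::vector<Node> nodes = {{"Add", {"x", "w0"}}, {"Add", {"x", "w1"}}};

  // ConstantSharing (toy): pick one canonical name per distinct value and
  // rewrite consumers to use it.
  std::map<std::vector<float>, std::string> canonical;
  for (const auto& [name, value] : inits) canonical.emplace(value, name);
  for (auto& node : nodes)
    for (auto& input : node.inputs)
      if (auto it = inits.find(input); it != inits.end())
        input = canonical[it->second];

  // CSE (toy): nodes with identical op type and inputs compute the same value.
  const bool mergeable =
      nodes[0].op == nodes[1].op && nodes[0].inputs == nodes[1].inputs;
  std::cout << (mergeable ? "CSE can now merge the two Adds\n"
                          : "CSE cannot merge them\n");
}
```

Run before the renaming step, the comparison fails; run after it, both nodes read {"x", "w0"} and merge, which is exactly why ConstantSharing is placed first.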
78 changes: 0 additions & 78 deletions onnxruntime/core/optimizer/shape_input_merge.cc

This file was deleted.

23 changes: 0 additions & 23 deletions onnxruntime/core/optimizer/shape_input_merge.h

This file was deleted.

48 changes: 31 additions & 17 deletions onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
@@ -140,30 +140,44 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span<c

namespace {
void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type,
- std::optional<gsl::span<const int64_t>> shape) {
+ std::optional<gsl::span<const int64_t>> shape, bool convert_scalar = false) {
tensor_type.set_datatype(data_type);
if (shape) {
- tensor_type.set_rank(shape->size());
- for (const auto& dim : *shape) {
- if (dim >= 0) {
- tensor_type.add_dimensions()->mutable_constant()->set_size(narrow<int32_t>(dim));
- } else {
- tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
+ auto rank = shape->size();
+ if (convert_scalar && rank == 0) {
+ // CoreML scalar has shape {1}
+ tensor_type.set_rank(1);
+ tensor_type.add_dimensions()->mutable_constant()->set_size(1);
+ } else {
+ tensor_type.set_rank(rank);
+ for (const auto& dim : *shape) {
+ if (dim >= 0) {
+ tensor_type.add_dimensions()->mutable_constant()->set_size(narrow<int32_t>(dim));
+ } else {
+ tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
}
}
}
}
}

void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type,
- const ONNX_NAMESPACE::TensorShapeProto* shape) {
+ const ONNX_NAMESPACE::TensorShapeProto* shape, bool convert_scalar = false) {
tensor_type.set_datatype(data_type);
if (shape) {
- tensor_type.set_rank(shape->dim_size());
- for (const auto& dim : shape->dim()) {
- if (dim.has_dim_value()) {
- tensor_type.add_dimensions()->mutable_constant()->set_size(narrow<int32_t>(dim.dim_value()));
- } else {
- tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
+ auto rank = shape->dim_size();
+ if (convert_scalar && rank == 0) {
+ // CoreML scalar has shape {1}
+ tensor_type.set_rank(1);
+ tensor_type.add_dimensions()->mutable_constant()->set_size(1);
+ } else {
+ tensor_type.set_rank(rank);
+ for (const auto& dim : shape->dim()) {
+ if (dim.has_dim_value()) {
+ tensor_type.add_dimensions()->mutable_constant()->set_size(narrow<int32_t>(dim.dim_value()));
+ } else {
+ tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
}
}
}
}
@@ -281,13 +295,13 @@ template MILSpec::Value CreateScalarTensorValue(const int32_t& data);
template MILSpec::Value CreateScalarTensorValue(const std::string& data);
template MILSpec::Value CreateScalarTensorValue(const bool& data);

- COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg) {
+ COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg, bool convert_scalar) {
MILSpec::NamedValueType nvt;
nvt.set_name(node_arg.Name());
MILSpec::TensorType& tensor_type = *nvt.mutable_type()->mutable_tensortype();

SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(node_arg.TypeAsProto()->tensor_type().elem_type()),
- node_arg.Shape());
+ node_arg.Shape(), convert_scalar);

return nvt;
}
@@ -308,7 +322,7 @@ void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output
MILSpec::TensorType& tensor_type = *value.mutable_tensortype();

SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(output.TypeAsProto()->tensor_type().elem_type()),
- output.Shape());
+ output.Shape(), /*convert_scalar*/ true);
}

void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type,
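The convert_scalar behavior added above boils down to a small shape rule. A hedged sketch with plain vectors instead of the MILSpec protobuf types (ToCoreMLShape is an illustrative name, not part of the codebase): an ONNX scalar, rank 0 with shape {}, becomes a CoreML tensor of shape {1}; every other shape passes through unchanged.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative counterpart of the convert_scalar branch in SetTensorTypeInfo.
std::vector<int64_t> ToCoreMLShape(const std::vector<int64_t>& onnx_shape,
                                   bool convert_scalar) {
  if (convert_scalar && onnx_shape.empty()) {
    return {1};  // CoreML scalar has shape {1}
  }
  return onnx_shape;  // unknown (negative) dims stay for the caller to mark
}

int main() {
  assert((ToCoreMLShape({}, /*convert_scalar*/ true) == std::vector<int64_t>{1}));
  assert((ToCoreMLShape({}, /*convert_scalar*/ false).empty()));
  assert((ToCoreMLShape({2, 3}, /*convert_scalar*/ true) ==
          std::vector<int64_t>{2, 3}));
}
```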
onnxruntime/core/providers/coreml/builders/impl/builder_utils.h
@@ -114,8 +114,10 @@ template <typename T>
COREML_SPEC::MILSpec::Value CreateScalarTensorValue(const T& data);

/// <summary>Create a NamedValueType from an ONNX tensor NodeArg.</summary>
/// <param name="node_arg">NodeArg to create NamedValueType from.</param>
/// <param name="convert_scalar">If true, scalar shapes are converted to 1D.</param>
/// <remarks>Used to create inputs for the 'main' function in an ML Program.</remarks>
- COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg);
+ COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg, bool convert_scalar = false);

/// <summary>
/// Add an input argument to a MILSpec::Operation
7 changes: 1 addition & 6 deletions onnxruntime/core/providers/coreml/builders/model_builder.cc
@@ -838,13 +838,8 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_input
if (create_ml_program_) {
if (is_input) {
// the model inputs need to be wired up as args to the 'main' function.
- auto tensor_value_type = CreateNamedTensorValueType(node_arg);
+ auto tensor_value_type = CreateNamedTensorValueType(node_arg, /*convert_scalar*/ true);
tensor_value_type.set_name(name);
- if (node_arg.Shape()->dim_size() == 0) {
- // update shape from {} to {1} (same change we made at the model input level above).
- tensor_value_type.mutable_type()->mutable_tensortype()->set_rank(1);
- tensor_value_type.mutable_type()->mutable_tensortype()->add_dimensions()->mutable_constant()->set_size(1);
- }

mlprogram_main_fn_->mutable_inputs()->Add(std::move(tensor_value_type));
} else {
10 changes: 5 additions & 5 deletions onnxruntime/core/providers/webnn/builders/helper.h
@@ -209,16 +209,16 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"Pow", {"pow", true}},
{"PRelu", {"prelu", true}},
{"Reciprocal", {"reciprocal", true}},
{"ReduceL1", {"reduceL1", false}},
{"ReduceL2", {"reduceL2", false}},
{"ReduceLogSum", {"reduceLogSum", false}},
{"ReduceLogSumExp", {"reduceLogSumExp", false}},
{"ReduceL1", {"reduceL1", true}},
{"ReduceL2", {"reduceL2", true}},
{"ReduceLogSum", {"reduceLogSum", true}},
{"ReduceLogSumExp", {"reduceLogSumExp", true}},
{"ReduceMax", {"reduceMax", true}},
{"ReduceMean", {"reduceMean", true}},
{"ReduceMin", {"reduceMin", true}},
{"ReduceProd", {"reduceProduct", true}},
{"ReduceSum", {"reduceSum", true}},
{"ReduceSumSquare", {"reduceSumSquare", false}},
{"ReduceSumSquare", {"reduceSumSquare", true}},
{"Relu", {"relu", true}},
{"Reshape", {"reshape", true}},
{"Resize", {"resample2d", true}},
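The false -> true flips above line up with the WebNN CPU checkmarks added in webnn-operators.md, so the second field is read here as "supported by the WebNN CPU backend"; that meaning is an assumption, since the WebnnOpInfo definition sits outside this diff. A self-contained sketch of the lookup pattern:

```cpp
#include <string>
#include <unordered_map>

// Mirrors the shape of the op_map entries above; the field meaning is assumed.
struct WebnnOpInfo {
  std::string op_name;  // WebNN operator name
  bool cpu_supported;   // assumed: usable on the WebNN CPU backend
};

const std::unordered_map<std::string, WebnnOpInfo> op_map = {
    {"ReduceL1", {"reduceL1", true}},  // flipped from false in this commit
    {"ReduceSumSquare", {"reduceSumSquare", true}},
    {"ReduceMax", {"reduceMax", true}},
};

// The kind of check an EP might make before claiming a node for a device.
bool IsOpSupported(const std::string& onnx_op, bool device_is_cpu) {
  const auto it = op_map.find(onnx_op);
  if (it == op_map.end()) return false;
  return !device_is_cpu || it->second.cpu_supported;
}
```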
75 changes: 0 additions & 75 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -70,7 +70,6 @@
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/shape_input_merge.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/unsqueeze_elimination.h"
#include "core/optimizer/utils.h"
@@ -7691,80 +7690,6 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) {
}
}

TEST_F(GraphTransformationTests, ShapeInputMerge) {
auto build_test_case = [&](ModelTestBuilder& builder) {
std::vector<std::variant<int64_t, std::string>> input_shape;
input_shape.reserve(5);
input_shape.emplace_back("dim0");
input_shape.emplace_back(512);
input_shape.emplace_back(1);
input_shape.emplace_back(1536);
input_shape.emplace_back("dim4");
auto* input_arg = builder.MakeSymbolicInput<float>(input_shape);
auto* neg_out = builder.MakeIntermediate();
auto* axes_initializer = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(2)});
auto* squeeze_out = builder.MakeIntermediate();
auto* cast_out = builder.MakeIntermediate();
auto* unsqueeze_out = builder.MakeOutput();
auto* shape_1_out = builder.MakeOutput();
auto* shape_2_out = builder.MakeOutput();
auto* shape_3_out = builder.MakeOutput();
auto* shape_4_out = builder.MakeOutput();
auto* shape_5_out = builder.MakeOutput();
builder.AddNode("Neg", {input_arg}, {neg_out});
builder.AddNode("Squeeze", {neg_out, axes_initializer}, {squeeze_out});
builder.AddNode("Cast", {squeeze_out}, {cast_out}).AddAttribute("to", static_cast<int64_t>(10));
builder.AddNode("Unsqueeze", {cast_out, axes_initializer}, {unsqueeze_out});
builder.AddNode("Shape", {input_arg}, {shape_1_out});
builder.AddNode("Shape", {neg_out}, {shape_2_out});
builder.AddNode("Shape", {squeeze_out}, {shape_3_out});
builder.AddNode("Shape", {cast_out}, {shape_4_out});
builder.AddNode("Shape", {unsqueeze_out}, {shape_5_out});
};

auto pre_graph_checker = [&](Graph& graph) {
InlinedHashMap<std::string, int> ref_count;
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Shape") {
std::string name = node.InputDefs()[0]->Name();
if (ref_count.find(name) == ref_count.end()) {
ref_count[name] = 1;
} else {
ref_count[name]++;
}
}
}
TEST_RETURN_IF_NOT(ref_count.size() == 5);
return Status::OK();
};

auto post_graph_checker = [&](Graph& graph) {
InlinedHashMap<std::string, int> ref_count;
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Shape") {
std::string name = node.InputDefs()[0]->Name();
if (ref_count.find(name) == ref_count.end()) {
ref_count[name] = 1;
} else {
ref_count[name]++;
}
}
}
TEST_RETURN_IF_NOT(ref_count.size() == 2);
int sum = 0, mul = 1;
for (auto& entry : ref_count) {
sum += entry.second;
mul *= entry.second;
}
TEST_RETURN_IF_NOT(sum == 5 && mul == 6);
return Status::OK();
};

std::unique_ptr<GraphTransformer> transformer = std::make_unique<ShapeInputMerge>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
}

#if !defined(DISABLE_CONTRIB_OPS)

TEST_F(GraphTransformationTests, MatMulNBitsBiasFusion) {
orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -44,7 +44,6 @@
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/shape_input_merge.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/unsqueeze_elimination.h"
@@ -117,11 +116,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique<PythonOpRewriter>()));
#endif

- // Put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination by intention as it can create
- // more opportunities for CSE. For example, if A and B nodes consume same different args but produce same output
- // or consume different initializers with same value, by default, CSE will not merge them.
+ // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
+ // CSE. For example, if A and B nodes consume different initializers with same value, by default,
+ // CSE will not merge them.
transformers.emplace_back(std::make_unique<ConstantSharing>(compatible_eps));
- transformers.emplace_back(std::make_unique<ShapeInputMerge>(compatible_eps));
// LayerNormFusion must be applied before CommonSubexpressionElimination as the latter will break the pattern when 2 LayerNormFusion share the same input.
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
// Remove duplicate nodes. Must be applied before any recompute transformations.