diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index adfa680878945..ec9881184a426 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -42,6 +42,7 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { const std::string drop_action_name{"drop"}; const std::string drop_action_no_int16_name{"drop_no_int16_support"}; const std::string drop_action_no_int16_and_positive_scale_name{"drop_no_int16_support_and_positive_scale"}; + const std::string drop_action_resize_nearest_name{"drop_resize_nearest"}; NTO::NodeLocation dq{NTO::NodeType::kInput, 0}; NTO::NodeLocation q{NTO::NodeType::kOutput, 0}; @@ -55,6 +56,8 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { std::vector(moves)); // Copy before std::move(moves) std::unique_ptr drop_action_no_int16_and_positive_scale = std::make_unique( std::vector(moves)); // Copy before std::move(moves) + std::unique_ptr drop_action_resize_nearest = std::make_unique( + std::vector(moves)); // Copy before std::move(moves) std::unique_ptr drop_action = std::make_unique(std::move(moves)); #if !defined(ORT_MINIMAL_BUILD) @@ -67,14 +70,11 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { // And cannot eliminate the QDQ for MaxPool if the scale is not positive, as a negative // scale will change the ordering of the elements between quantized & de-quantized values. std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider}; - std::unique_ptr selector_no_16bit = std::make_unique(false, - false, - true, - providers); - qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name, + std::unique_ptr selector_resize_nearest = std::make_unique(false, false, true, providers); + qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_resize_nearest_name, {{"Resize", {}}}, - std::move(selector_no_16bit), - std::move(drop_action_no_int16)); + std::move(selector_resize_nearest), + std::move(drop_action_resize_nearest)); std::unique_ptr selector_no_16bit_and_positive_scale = std::make_unique(false, true, false, providers); @@ -85,6 +85,8 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { std::move(selector_no_16bit_and_positive_scale), std::move(drop_action_no_int16_and_positive_scale)); + + std::unique_ptr selector = std::make_unique(true, false, true, providers); // DepthToSpace and SpaceToDepth not included because there are no integer implementations. // https://github.com/microsoft/onnxruntime/issues/21287 diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 203aba2c3dd91..8ac0bfcc16a0e 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -160,6 +160,20 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); } + +bool DropQDQNodeResizeNearestSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + if (!DropQDQNodeGroupSelector::Check(graph_viewer, node, dq_nodes, q_nodes)) { + return false; + } + + const onnx::AttributeProto* mode = graph_utils::GetNodeAttribute(node, "mode"); + // default mode is 'nearest' + return mode == nullptr || mode->s() == "nearest"; +} + bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 0ba5436e69e81..5b2ca7abb7b11 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -52,16 +52,28 @@ class DropQDQNodeGroupSelector : public NodeGroupSelector { bool allow_nonpositive_scale = true) : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit), allow_nonpositive_scale_(allow_nonpositive_scale) {} - private: + protected: bool Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; + private: bool allow_16bit_; bool allow_4bit_; bool allow_nonpositive_scale_; }; +class DropQDQNodeResizeNearestSelector : public DropQDQNodeGroupSelector { + public: + using DropQDQNodeGroupSelector::DropQDQNodeGroupSelector; + + private: + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; + +}; + // Single DQ -> node. class DropDQNodeGroupSelector : public NodeGroupSelector { public: @@ -309,6 +321,15 @@ class DropQDQNodesSelector : public BaseSelector { compatible_providers) {} }; +class DropQDQNodesResizeNearestSelector : public BaseSelector { + public: + explicit DropQDQNodesResizeNearestSelector(bool allow_16bit = false, bool allow_4bit = false, bool allow_nonpositive_scale = true, + gsl::span compatible_providers = {}) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit, allow_nonpositive_scale), + compatible_providers) {} +}; + + class DropDQNodesSelector : public BaseSelector { public: explicit DropDQNodesSelector(bool allow_16bit = false, diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index d07977d4b97b8..2b4c317094e35 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -1924,6 +1924,33 @@ TEST(QDQTransformerTests, Resize) { test_case({2, 13, 12, 37}, rand_gen.Uniform(std::vector{4}, 1, 16), true /*use_contrib_qdq*/); } +TEST(QDQTransformerTests, ResizeLinearNoFusion) { + auto test_case = [&](bool use_contrib_qdq = false) { + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["Resize"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + TransformerTester(BuildQDQResizeTestCase({1, 64, 64, 3}, + {1, 32, 32, 3}, + "linear", // mode + "half_pixel", // coordinate_transformation_mode + "round_prefer_floor", // nearest_mode + false, // add_dq_output_float + use_contrib_qdq), + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2); + }; + + RandomValueGenerator rand_gen{optional{2345}}; + test_case(); + test_case(true /*use_contrib_qdq*/); +} + TEST(QDQTransformerTests, Resize_No_Fusion) { auto test_case = [&](const std::vector& input_shape, const std::vector& sizes_shape,