Skip to content

Commit

Permalink
Do not merge Resize with QDQ (fixes microsoft#21319)
Browse files Browse the repository at this point in the history
See microsoft#21319 for details.
This PR disables dropping the Q/DQ nodes around Resize unless the Resize mode is 'nearest', to avoid numerical issues.
  • Loading branch information
mgehre-amd committed Sep 13, 2024
1 parent 2cdc05f commit 8edf7b0
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
const std::string drop_action_name{"drop"};
const std::string drop_action_no_int16_name{"drop_no_int16_support"};
const std::string drop_action_no_int16_and_positive_scale_name{"drop_no_int16_support_and_positive_scale"};
const std::string drop_action_resize_nearest_name{"drop_resize_nearest"};
NTO::NodeLocation dq{NTO::NodeType::kInput, 0};
NTO::NodeLocation q{NTO::NodeType::kOutput, 0};

Expand All @@ -55,6 +56,8 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::vector<NodeAndMoveInfo>(moves)); // Copy before std::move(moves)
std::unique_ptr<Action> drop_action_no_int16_and_positive_scale = std::make_unique<MergeIntoTargetFixed>(
std::vector<NodeAndMoveInfo>(moves)); // Copy before std::move(moves)
std::unique_ptr<Action> drop_action_resize_nearest = std::make_unique<MergeIntoTargetFixed>(
std::vector<NodeAndMoveInfo>(moves)); // Copy before std::move(moves)
std::unique_ptr<Action> drop_action = std::make_unique<MergeIntoTargetFixed>(std::move(moves));

#if !defined(ORT_MINIMAL_BUILD)
Expand All @@ -67,14 +70,11 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
// And cannot eliminate the QDQ for MaxPool if the scale is not positive, as a negative
// scale will change the ordering of the elements between quantized & de-quantized values.
std::vector<const char*> providers = {kCpuExecutionProvider, kDmlExecutionProvider};
std::unique_ptr<NodeSelector> selector_no_16bit = std::make_unique<QDQ::DropQDQNodesSelector>(false,
false,
true,
providers);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name,
std::unique_ptr<NodeSelector> selector_resize_nearest = std::make_unique<QDQ::DropQDQNodesResizeNearestSelector>(false, false, true, providers);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_resize_nearest_name,
{{"Resize", {}}},
std::move(selector_no_16bit),
std::move(drop_action_no_int16));
std::move(selector_resize_nearest),
std::move(drop_action_resize_nearest));

std::unique_ptr<NodeSelector> selector_no_16bit_and_positive_scale =
std::make_unique<QDQ::DropQDQNodesSelector>(false, true, false, providers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath());
}

// Accepts a node group only if it passes the base DropQDQNodeGroupSelector checks
// and the target node's "mode" attribute is either absent or equal to "nearest".
bool DropQDQNodeResizeNearestSelector::Check(const GraphViewer& graph_viewer,
                                             const Node& node,
                                             const std::vector<const Node*>& dq_nodes,
                                             const std::vector<const Node*>& q_nodes) const {
  // Run the inherited checks first; a failure there rejects the group outright.
  const bool base_checks_pass = DropQDQNodeGroupSelector::Check(graph_viewer, node, dq_nodes, q_nodes);
  if (!base_checks_pass) {
    return false;
  }

  const onnx::AttributeProto* mode_attr = graph_utils::GetNodeAttribute(node, "mode");
  if (mode_attr == nullptr) {
    // No explicit attribute: the default mode is 'nearest'.
    return true;
  }

  return mode_attr->s() == "nearest";
}

bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,27 @@ class DropQDQNodeGroupSelector : public NodeGroupSelector {
bool allow_nonpositive_scale = true)
: allow_16bit_(allow_16bit), allow_4bit_(allow_4bit), allow_nonpositive_scale_(allow_nonpositive_scale) {}

private:
protected:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;

private:
bool allow_16bit_;
bool allow_4bit_;
bool allow_nonpositive_scale_;
};

// Node group selector derived from DropQDQNodeGroupSelector. It reuses the base class
// constructors and overrides Check; the override is defined out of line.
class DropQDQNodeResizeNearestSelector : public DropQDQNodeGroupSelector {
public:
// Inherit all DropQDQNodeGroupSelector constructors unchanged.
using DropQDQNodeGroupSelector::DropQDQNodeGroupSelector;

private:
// Overrides the base class check for the given target node and its DQ/Q neighbours.
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;
};

// Single DQ -> node.
class DropDQNodeGroupSelector : public NodeGroupSelector {
public:
Expand Down Expand Up @@ -309,6 +320,14 @@ class DropQDQNodesSelector : public BaseSelector {
compatible_providers) {}
};

// Selector built on BaseSelector that uses DropQDQNodeResizeNearestSelector as its
// node group selector. The three bool options are forwarded to that group selector;
// compatible_providers is passed through to BaseSelector.
class DropQDQNodesResizeNearestSelector : public BaseSelector {
 public:
  explicit DropQDQNodesResizeNearestSelector(bool allow_16bit = false,
                                             bool allow_4bit = false,
                                             bool allow_nonpositive_scale = true,
                                             gsl::span<const char*> compatible_providers = {})
      : BaseSelector(std::make_unique<DropQDQNodeResizeNearestSelector>(allow_16bit,
                                                                        allow_4bit,
                                                                        allow_nonpositive_scale),
                     compatible_providers) {}
};

class DropDQNodesSelector : public BaseSelector {
public:
explicit DropDQNodesSelector(bool allow_16bit = false,
Expand Down
27 changes: 27 additions & 0 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,33 @@ TEST(QDQTransformerTests, Resize) {
test_case({2, 13, 12, 37}, rand_gen.Uniform<int64_t>(std::vector<int64_t>{4}, 1, 16), true /*use_contrib_qdq*/);
}

// Checks that a Resize in "linear" mode keeps its surrounding Q/DQ nodes: after the
// transformers run, the graph must still contain exactly one Resize, one quantize op
// and one dequantize op. The test uses fixed shapes only, so no random data is needed.
TEST(QDQTransformerTests, ResizeLinearNoFusion) {
  auto test_case = [&](bool use_contrib_qdq = false) {
    // Graph validator: every op of the original DQ -> Resize -> Q pattern must survive.
    auto check_graph = [&](InferenceSessionWrapper& session) {
      auto op_to_count = CountOpsInGraph(session.GetGraph());
      const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
      EXPECT_EQ(op_to_count["Resize"], 1);
      EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1);
      EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1);
    };

    TransformerTester(BuildQDQResizeTestCase({1, 64, 64, 3},
                                             {1, 32, 32, 3},
                                             "linear",              // mode
                                             "half_pixel",          // coordinate_transformation_mode
                                             "round_prefer_floor",  // nearest_mode
                                             false,                 // add_dq_output_float
                                             use_contrib_qdq),
                      check_graph,
                      TransformerLevel::Level1,
                      TransformerLevel::Level2);
  };

  // Exercise both the default Q/DQ ops and the contrib Q/DQ ops.
  test_case();
  test_case(true /*use_contrib_qdq*/);
}

TEST(QDQTransformerTests, Resize_No_Fusion) {
auto test_case = [&](const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& sizes_shape,
Expand Down

0 comments on commit 8edf7b0

Please sign in to comment.