Addressed PR feedback
sumitsays committed Nov 1, 2023
1 parent 75425d7 commit 362bf43
Showing 1 changed file: onnxruntime/core/optimizer/pad_fusion.cc (30 additions, 8 deletions)
@@ -32,14 +32,15 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
    return false;
  }

  // constant_value should be an initializer because we have to verify the constant_value should be zero.
  // It is because Conv and MaxPool allow only 0 as padding value.
  if (node.InputDefs().size() > 2 && !graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[2])) {
    return false;
  }
  // Since opset 11, <pads> and <constant_value> moved to inputs.
  // Both of these should be initializers because we have to verify the values.
  if (node.SinceVersion() >= 11) {
    if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) ||
        (node.InputDefs().size() > 2 && !graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[2]))) {
      return false;
    }

    // Since opset 11, pad constant value becomes part of input instead of attribute.
    if (node.InputDefs().size() > 2) {
      // constant_value should be zero because Conv and MaxPool allow only 0 as padding value.
      const auto* pad_constant_value_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
      Initializer pad_constant_value{*pad_constant_value_proto, graph.ModelPath()};
      if (std::any_of(pad_constant_value.DataAsByteSpan().begin(), pad_constant_value.DataAsByteSpan().end(),
                      [](const uint8_t byte) { return byte != 0; })) {

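An aside on the zero test in the check above (the hunk is truncated at that condition): it compares raw bytes rather than typed values, which makes it dtype-agnostic. A minimal standalone sketch of the same idea follows; this is illustrative code with made-up names, not ORT code. Note that a byte-level test treats IEEE-754 -0.0f as nonzero, since its sign bit is set.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Returns true iff every byte of the buffer is zero, mirroring the
// DataAsByteSpan() check in the fusion condition above.
bool AllBytesZero(const uint8_t* bytes, size_t size_bytes) {
  return std::all_of(bytes, bytes + size_bytes,
                     [](uint8_t b) { return b == 0; });
}

int main() {
  float plus_zero = 0.0f;    // all-zero bytes -> passes
  float minus_zero = -0.0f;  // sign bit set -> fails the byte-level test
  uint8_t buf[sizeof(float)];
  std::memcpy(buf, &plus_zero, sizeof(float));
  bool a = AllBytesZero(buf, sizeof(float));  // true
  std::memcpy(buf, &minus_zero, sizeof(float));
  bool b = AllBytesZero(buf, sizeof(float));  // false
  return (a && !b) ? 0 : 1;
}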
@@ -57,6 +58,22 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
      !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) {
    return false;
  }

  // Don't fuse if MaxPool has the optional output indices tensor, because the indices tensor
  // does not incorporate the pad values. If we allowed the fusion, the dimension values of the
  // fused MaxPool's input tensor would be smaller than the dimension values of its input tensor
  // without fusion. This would cause the range of values in the output indices tensor to be
  // smaller than it should have been.
  if (child_node.OutputDefs().size() > 1) {
    return false;
  }

  // The Conv or MaxPool node must use explicit padding to perform this fusion.
  if (child_node.GetAttributes().find("auto_pad") != child_node.GetAttributes().end() &&
      child_node.GetAttributes().at("auto_pad").s() != "NOTSET") {
    return false;
  }
  return true;
}
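To make the indices concern concrete, here is a small worked illustration (the shapes are assumptions for illustration, not from the commit). MaxPool's optional second output holds flattened offsets into the operator's input tensor, so shrinking that input by folding the Pad away changes which offsets are valid:

#include <cstdint>
#include <iostream>

int main() {
  // Unfused graph: Pad grows [1,1,4,4] to [1,1,6,6]; MaxPool's indices then
  // address a 36-element tensor, i.e. they lie in [0, 35].
  const int64_t padded_elems = 1 * 1 * 6 * 6;
  // Fused graph: MaxPool reads the original [1,1,4,4] tensor, so indices lie
  // in [0, 15] and no longer agree with the unfused graph.
  const int64_t unpadded_elems = 1 * 1 * 4 * 4;
  std::cout << "unfused indices range: [0, " << padded_elems - 1 << "]\n"
            << "fused indices range:   [0, " << unpadded_elems - 1 << "]\n";
  return 0;
}

Rejecting the fusion whenever the second output is present sidesteps this mismatch, and the auto_pad check that follows similarly restricts the rule to nodes whose padding is spelled out explicitly.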

@@ -74,6 +91,9 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef
    pads_values.assign(pad_node.GetAttributes().at("pads").ints().begin(),
                       pad_node.GetAttributes().at("pads").ints().end());

  }

  uint32_t input_rank = static_cast<uint32_t>(pad_node.InputDefs()[0]->Shape()->dim_size());
  assert(static_cast<uint32_t>(pads_values.size()) == (2 * input_rank));

  uint32_t pads_size = static_cast<uint32_t>(pads_values.size());
  // Check if padding is applied only on feature dims.
  if (pads_values[0] != 0 || pads_values[1] != 0 || pads_values[pads_size / 2] != 0 ||
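The condition above continues in the collapsed part of the hunk; it relies on the ONNX Pad layout where, for a rank-r input, pads holds 2*r entries ordered as all begin-values followed by all end-values. A hedged sketch of that check follows; the helper name and the NCHW example are assumptions, not from the commit:

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical mirror of the fusion condition: padding is foldable only if
// the batch and channel dims (the first two begin- and end-entries) are unpadded.
bool PadsOnlyOnFeatureDims(const std::vector<int64_t>& pads) {
  const size_t half = pads.size() / 2;
  return pads[0] == 0 && pads[1] == 0 && pads[half] == 0 && pads[half + 1] == 0;
}

int main() {
  // NCHW input: pad H by 1 on each side and W by 2 on each side.
  // Layout: {N_begin, C_begin, H_begin, W_begin, N_end, C_end, H_end, W_end}.
  std::vector<int64_t> pads = {0, 0, 1, 2, 0, 0, 1, 2};
  return PadsOnlyOnFeatureDims(pads) ? 0 : 1;
}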
@@ -92,7 +112,9 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef

  for (uint32_t pads_index = 2, child_index = 0; pads_index < pads_size / 2; pads_index++, child_index++) {
    child_pads->Set(child_index, child_pads->Get(child_index) + pads_values[pads_index]);
    child_pads->Set(child_index + (child_pads_size / 2), child_pads->Get(child_index + (child_pads_size / 2)) + pads_values[pads_index + (pads_size / 2)]);
    uint32_t mirrored_child_index = child_index + (child_pads_size / 2);
    uint32_t mirrored_pad_index = pads_index + (pads_size / 2);
    child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]);
  }

  graph_utils::RemoveNodeOutputEdges(graph, pad_node);
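A worked sketch of the folding arithmetic in that loop (the shapes and starting pads are assumptions for illustration): Pad's pads cover every dim (2*rank entries), while Conv/MaxPool's pads attribute covers only the spatial dims (2*(rank-2) entries), so the begin and end halves are offset by different amounts.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Rank-4 NCHW input; Pad adds 1 row top/bottom and 2 columns left/right.
  std::vector<int64_t> pads_values = {0, 0, 1, 2, 0, 0, 1, 2};  // 2 * rank entries
  std::vector<int64_t> child_pads = {3, 3, 3, 3};  // existing MaxPool pads, 2 * (rank - 2)
  const size_t pads_size = pads_values.size();
  const size_t child_pads_size = child_pads.size();
  for (size_t pads_index = 2, child_index = 0; pads_index < pads_size / 2;
       pads_index++, child_index++) {
    // Fold the begin-side pad, then the mirrored end-side pad.
    child_pads[child_index] += pads_values[pads_index];
    const size_t mirrored_child_index = child_index + (child_pads_size / 2);
    const size_t mirrored_pad_index = pads_index + (pads_size / 2);
    child_pads[mirrored_child_index] += pads_values[mirrored_pad_index];
  }
  // child_pads is now {4, 5, 4, 5}: {H_begin, W_begin, H_end, W_end}.
  for (int64_t p : child_pads) std::cout << p << ' ';
  std::cout << '\n';
  return 0;
}

The commit's change at this spot is purely cosmetic: it names the mirrored indices so the second Set call fits within the 120-character limit flagged by cpplint, without changing behavior.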
