Skip to content

Commit

Permalink
Support old GetClipMinMax params for webnn. Mark as deprecated and ad…
Browse files Browse the repository at this point in the history
…d note to webnn clip_op_builder.cc that IsOpSupportedImpl should be changed to take GraphViewer.
  • Loading branch information
skottmckay committed Feb 4, 2024
1 parent 9997931 commit 80e8f92
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 8 deletions.
30 changes: 26 additions & 4 deletions onnxruntime/core/providers/shared/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ bool GetType(const NodeArg& node_arg, int32_t& type, const logging::Logger& logg
return true;
}

bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node,
float& min, float& max, const logging::Logger& logger) {
namespace {
bool GetClipMinMaxImpl(std::function<const ONNX_NAMESPACE::TensorProto*(const std::string&)> get_const_initializer,
const Node& node, float& min, float& max, const logging::Logger& logger) {
const auto& node_name = node.Name();
int32_t input_type;
if (!GetType(*node.InputDefs()[0], input_type, logger)) {
Expand Down Expand Up @@ -70,7 +71,7 @@ bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node,
if (node.InputDefs().size() > 1 && node.InputDefs()[1]->Exists()) {
// we have input min
const auto& min_name = node.InputDefs()[1]->Name();
const auto* min_value = graph_viewer.GetConstantInitializer(min_name);
const auto* min_value = get_const_initializer(min_name);
if (!get_value(min_value, "Min", min)) {
return false;
}
Expand All @@ -79,7 +80,7 @@ bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node,
if (node.InputDefs().size() > 2 && node.InputDefs()[2]->Exists()) {
// we have input max
const auto& max_name = node.InputDefs()[2]->Name();
const auto* max_value = graph_viewer.GetConstantInitializer(max_name);
const auto* max_value = get_const_initializer(max_name);
if (!get_value(max_value, "Max", max)) {
return false;
}
Expand All @@ -88,6 +89,27 @@ bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node,

return true;
}
} // namespace

bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node, float& min, float& max,
const logging::Logger& logger) {
return GetClipMinMaxImpl(
[&graph_viewer](const std::string& name) -> const ONNX_NAMESPACE::TensorProto* {
return graph_viewer.GetConstantInitializer(name);
},
node, min, max, logger);
}

// deprecated version that is not able to check if the initializer is constant
bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node, float& min, float& max,
const logging::Logger& logger) {
return GetClipMinMaxImpl(
[&initializers](const std::string& name) -> const ONNX_NAMESPACE::TensorProto* {
auto entry = initializers.find(name);
return entry == initializers.end() ? nullptr : entry->second;
},
node, min, max, logger);
}

NodeAttrHelper::NodeAttrHelper(const onnxruntime::Node& node)
: node_attributes_(node.GetAttributes()) {}
Expand Down
15 changes: 11 additions & 4 deletions onnxruntime/core/providers/shared/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,20 @@ class Node;
class NodeArg;
class NodeUnit;

// Get the min/max of a Clip operator.
// If min/max are not known initializer tensors, will return false
// For now we only support getting float min/max,
// since in most cases, Clip(0,6)[Relu6] will be fused by quantization tool
// Get the min/max of a Clip operator. Reads values from attributes for opset < 11 and inputs after that.
// For opset 11+, if min/max are not constant initializers, will return false.
// For now we only support getting float min/max.
bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node,
float& min, float& max, const logging::Logger& logger);

// Get the min/max of a Clip operator. Reads values from attributes for opset < 11 and inputs after that.
// For opset 11+, if min/max are not initializers, will return false.
// For now we only support getting float min/max.
// Deprecated - use the version that takes a GraphViewer to retrieve initializers so they can be checked to ensure
// they are constant.
[[deprecated]] bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node,
float& min, float& max, const logging::Logger& logger);

// Get the type of the given NodeArg
// Will return false if the given NodeArg has no type
bool GetType(const NodeArg& node_arg, int32_t& type, const logging::Logger& logger);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
const Node& node,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
// TODO: Update IsOpSupportedImpl to pass GraphViewer instead of InitializedTensorSet so the implementations

Check warning on line 73 in onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc:73: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// can ensure initializers are constant. See #19401 for details of how this update was made to the NNAPI EP.
// GetClipMinMax(graph_viewer, node, minValue, maxValue, logger)
float min, max;
return GetClipMinMax(initializers, node, min, max, logger);
}
Expand Down

0 comments on commit 80e8f92

Please sign in to comment.