Skip to content

Commit

Permalink
[WebNN EP] Use opSupportLimits to dynamically check data type support (
Browse files Browse the repository at this point in the history
…microsoft#22025)

- Remove hard code data type checks and use WebNN's opSupportLimits
instead
- Add HasSupportedOutputsImpl for output data type validation
- Get preferred layout info from opSupportLimits
- Move Not op to logical_op_builder.cc because it should be there. This
avoids inconsistent input names in `unary_op_builder.cc`.
  • Loading branch information
Honry authored Sep 14, 2024
1 parent a89bddd commit c63dd02
Show file tree
Hide file tree
Showing 32 changed files with 281 additions and 635 deletions.
61 changes: 53 additions & 8 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ bool GetShape(const NodeArg& node_arg, std::vector<int64_t>& shape, const loggin
return true;
}

bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer,
const WebnnDeviceType device_type, const logging::Logger& logger) {
bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const WebnnDeviceType device_type,
const emscripten::val& wnn_limits, const logging::Logger& logger) {
const auto& op_builders = GetOpBuilders();
if (Contains(op_builders, node.OpType())) {
const auto* op_builder = op_builders.at(node.OpType());
return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, logger);
return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, logger);
} else {
return false;
}
Expand Down Expand Up @@ -86,6 +86,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger) {
std::vector<std::vector<size_t>> supported_node_groups;

Expand All @@ -105,7 +106,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
// Firstly check if platform supports the WebNN op.
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser";
supported = IsNodeSupported(*node, graph_viewer, device_type, logger);
supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger);
}

LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType()
Expand All @@ -130,10 +131,54 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
return supported_node_groups;
}

bool IsSupportedDataType(const int32_t data_type,
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types) {
return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) !=
supported_data_types.end();
// Returns true if all entries of |input_types| are identical.
// Used to validate ops whose WebNN counterpart requires uniform input data types.
// Logs (VERBOSE) the first mismatching pair and returns false on mismatch.
bool AreInputDataTypesSame(const std::string& op_type,
                           gsl::span<const int32_t> input_types,
                           const logging::Logger& logger) {
  for (size_t i = 1; i < input_types.size(); i++) {
    if (input_types[0] != input_types[i]) {
      // Fixed: opening '[' was missing before the second type in the message.
      LOGS(logger, VERBOSE) << "[" << op_type
                            << "] Input data types should be the same, but ["
                            << input_types[0] << "] does not match ["
                            << input_types[i] << "].";
      return false;
    }
  }
  return true;
}

// Returns true if |onnx_data_type| maps to a WebNN data type that appears in
// |webnn_supported_data_types| (a JS array of WebNN data type strings, e.g.
// the "dataTypes" array from opSupportLimits).
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) {
  auto it = onnx_to_webnn_data_type_map.find(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_data_type));
  if (it == onnx_to_webnn_data_type_map.end())
    return false;

  // Bind by const reference — no need to copy the mapped string.
  const std::string& webnn_data_type = it->second;

  // Check if WebNN supports the data type via Array.prototype.includes.
  emscripten::val is_supported = webnn_supported_data_types.call<emscripten::val>("includes",
                                                                                  emscripten::val(webnn_data_type));
  return is_supported.as<bool>();
}

// Check if the input or output data type of ONNX node is supported by the WebNN operator.
// |webnn_input_output_name| selects the entry inside wnn_limits (e.g. "input"/"output");
// |onnx_input_output_name| is only used for the log message.
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
                             const int32_t onnx_data_type,
                             const emscripten::val& wnn_limits,
                             const std::string& webnn_input_output_name,
                             const std::string& onnx_input_output_name,
                             const logging::Logger& logger) {
  std::string webnn_op_type;
  // Unknown ONNX op -> no WebNN counterpart -> unsupported.
  if (!GetWebNNOpType(onnx_op_type, webnn_op_type)) {
    return false;
  }

  const emscripten::val supported_data_types = wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"];
  if (IsSupportedDataType(onnx_data_type, supported_data_types)) {
    return true;
  }

  LOGS(logger, VERBOSE) << "[" << onnx_op_type
                        << "] " << onnx_input_output_name
                        << " type: [" << onnx_data_type
                        << "] is not supported for now";
  return false;
}

bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
Expand Down
43 changes: 31 additions & 12 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger);
static const InlinedHashMap<std::string, std::string> op_map = {
{"Abs", "abs"},
Expand Down Expand Up @@ -250,20 +251,38 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn
return true;
}

static const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> webnn_supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_INT64,
ONNX_NAMESPACE::TensorProto_DataType_UINT32,
ONNX_NAMESPACE::TensorProto_DataType_UINT64,
// Looks up the WebNN op name for an ONNX |op_type| via op_map.
// On success writes it to |webnn_op_type| and returns true; returns false
// (leaving |webnn_op_type| untouched) when the op is not mapped.
inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) {
  const auto entry = op_map.find(op_type);
  if (entry != op_map.end()) {
    webnn_op_type = entry->second;
    return true;
  }
  return false;
}

// Maps ONNX tensor element types to the WebNN data type strings used by
// opSupportLimits' "dataTypes" arrays. Note that ONNX BOOL has no direct
// WebNN equivalent here and is represented as "uint8" (same as UINT8).
static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> onnx_to_webnn_data_type_map = {
    {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
    {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"},
    {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"},
    {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"},
    {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"},
    {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"},
    {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"},
    {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"},
    {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
};

bool IsSupportedDataType(const int32_t data_type,
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types);
// Returns true if all entries of |input_types| are identical; logs the first mismatch.
bool AreInputDataTypesSame(const std::string& op_type,
                           gsl::span<const int32_t> input_types,
                           const logging::Logger& logger);
// Returns true if |onnx_data_type| maps to a WebNN type listed in the given JS array.
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types);
// Returns true if the ONNX op's input/output data type is allowed by wnn_limits
// (the object returned by WebNN's opSupportLimits()).
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
                             const int32_t onnx_data_type,
                             const emscripten::val& wnn_limits,
                             const std::string& webnn_input_output_name,
                             const std::string& onnx_input_output_name,
                             const logging::Logger& logger);

bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
std::vector<int64_t>& shape_b,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ class ActivationOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -94,44 +92,6 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi
return true;
}

bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;

std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
// WebNN relu op supports float32, float16, int32, int8 input data types.
if (op_type == "Relu") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
};
// WebNN CPU backend does not support int32 data type for relu.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
}
} else { // Others only support float32 and float16.
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
}

if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}

return true;
}

void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -77,31 +75,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia
return true;
}

bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;

std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}

if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}

return true;
}

void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;
Expand Down
52 changes: 25 additions & 27 deletions onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node
Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const {
ORT_RETURN_IF_NOT(
IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), logger),
"Unsupported operator ",
node.OpType());
IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(),
model_builder.GetOpSupportLimits(), logger),
"Unsupported operator ", node.OpType());
ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger));
LOGS(logger, VERBOSE) << "Operator name: [" << node.Name()
<< "] type: [" << node.OpType() << "] was added";
Expand All @@ -50,8 +50,12 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node&
// Operator support related.

bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const {
if (!HasSupportedInputs(node, device_type, logger))
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
if (!HasSupportedInputs(node, wnn_limits, logger))
return false;

if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
return false;

// We do not support external initializers for now.
Expand All @@ -64,7 +68,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
return IsOpSupportedImpl(initializers, node, device_type, logger);
}

bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType device_type,
bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
Expand All @@ -73,39 +77,33 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d
}
}

// WebNN CPU backend (TFLite) will enable float16 input data type soon,
// temporarily fallback float16 input data type for WebNN CPU.
if (device_type == WebnnDeviceType::CPU) {
const auto& input = *node.InputDefs()[0];

int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
return false;
}

return HasSupportedInputsImpl(node, device_type, logger);
return HasSupportedInputsImpl(node, wnn_limits, logger);
}

bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
const WebnnDeviceType /* device_type */,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
// We only check the type of input 0 by default, specific op builder can override this.
const auto& input = *node.InputDefs()[0];

const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;

if (!IsSupportedDataType(input_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << node.OpType()
<< "] Input type: [" << input_type
<< "] is not supported for now";
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
}

bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
// We only check the type of output 0 by default, specific op builder can override this.
const auto& output = *node.OutputDefs()[0];
const auto& op_type = node.OpType();
int32_t output_type;
if (!GetType(output, output_type, logger))
return false;
}

return true;
return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger);
}

bool BaseOpBuilder::HasSupportedOpSet(const Node& node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,19 @@ class BaseOpBuilder : public IOpBuilder {
// Operator support related.
public:
bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;

protected:
virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */,
const WebnnDeviceType /* device_type */, const logging::Logger& /* logger */) const {
return true;
}

virtual bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;
virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;

// ONNX Runtime only *guarantees* support for models stamped
// with opset version 7 or above for opset domain 'ai.onnx'.
Expand All @@ -50,7 +53,7 @@ class BaseOpBuilder : public IOpBuilder {

private:
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
};

} // namespace webnn
Expand Down
Loading

0 comments on commit c63dd02

Please sign in to comment.