From 89723c8612d26d09e0e5995de6f200249035423d Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 23 Nov 2023 01:05:30 +0800 Subject: [PATCH] [WebNN EP] Mark and fallback unsupported op for WebNN CPU backend (#18472) Current WebNN CPU (XNNPack) backend supports limit op list, fallbacks unsupported ops for WebNN "cpu" deviceType directly. This is a workaround because the op may be included in MLGraphBuilder for DirectML backend but without XNNPack implementation in Chromium. --- .../core/providers/webnn/builders/helper.cc | 2 +- .../core/providers/webnn/builders/helper.h | 186 ++++++++++-------- 2 files changed, 105 insertions(+), 83 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 38266f566e6e1..d34cb7e362446 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -85,7 +85,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v const auto* node(graph_viewer.GetNode(node_idx)); bool supported = false; // Firstly check if platform supports the WebNN op. - if (CheckSingleOp(node->OpType(), wnn_builder_)) { + if (CheckSingleOp(node->OpType(), wnn_builder_, device_type)) { LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; supported = IsNodeSupported(*node, graph_viewer, device_type, logger); } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 8ae16f0dd21fc..28b54b9c9cf8d 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -30,6 +30,11 @@ enum class WebnnDeviceType { GPU, }; +typedef struct { + std::string opName; + bool isCpuSupported; // The WebNN CPU backend XNNPack supports it (not about the CPU EP). +} WebnnOpInfo; + bool GetShape(const NodeArg& node_arg, std::vector& shape, const logging::Logger& logger); template @@ -128,90 +133,107 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v const emscripten::val& wnn_builder_, const WebnnDeviceType device_type, const logging::Logger& logger); -static const InlinedHashMap op_map = { - {"Abs", "abs"}, - {"Add", "add"}, - {"ArgMax", "argMax"}, - {"ArgMin", "argMin"}, - {"AveragePool", "averagePool2d"}, - {"BatchNormalization", "meanVarianceNormalization"}, - {"Cast", "cast"}, - {"Ceil", "ceil"}, - {"Clip", "clamp"}, - {"Concat", "concat"}, - {"Conv", "conv2d"}, - {"ConvTranspose", "convTranspose2d"}, - {"Cos", "cos"}, - {"Div", "div"}, - {"Elu", "elu"}, - {"Equal", "equal"}, - {"Erf", "erf"}, - {"Exp", "exp"}, - {"Expand", "expand"}, - {"Flatten", "flattenTo2d"}, - {"Floor", "floor"}, - {"Gather", "gather"}, - {"Gemm", "gemm"}, - {"GlobalAveragePool", "averagePool2d"}, - {"GlobalMaxPool", "maxPool2d"}, - {"GlobalLpPool", "l2Pool2d"}, - {"Greater", "greater"}, - {"GreaterOrEqual", "greaterOrEqual"}, - {"GroupNormalization", "meanVarianceNormalization"}, - {"HardSigmoid", "hardSigmoid"}, - {"HardSwish", "hardSwish"}, - {"Identity", "identity"}, - {"InstanceNormalization", "meanVarianceNormalization"}, - {"LayerNormalization", "meanVarianceNormalization"}, - {"LeakyRelu", "leakyRelu"}, - {"Less", "lesser"}, - {"LessOrEqual", "lesserOrEqual"}, - {"Log", "log"}, - {"LpPool", "l2Pool2d"}, - {"MatMul", "matmul"}, - {"Max", "max"}, - {"MaxPool", "maxPool2d"}, - {"Min", "min"}, - {"Mul", "mul"}, - {"Neg", "neg"}, - {"Not", "logicalNot"}, - {"Pad", "pad"}, - {"Pow", "pow"}, - {"PRelu", "prelu"}, - {"Reciprocal", "reciprocal"}, - {"ReduceL1", "reduceL1"}, - {"ReduceL2", "reduceL2"}, - {"ReduceLogSum", "reduceLogSum"}, - {"ReduceLogSumExp", "reduceLogSumExp"}, - {"ReduceMax", "reduceMax"}, - {"ReduceMean", "reduceMean"}, - {"ReduceMin", "reduceMin"}, - {"ReduceProd", "reduceProduct"}, - {"ReduceSum", "reduceSum"}, - {"ReduceSumSquare", "reduceSumSquare"}, - {"Relu", "relu"}, - {"Reshape", "reshape"}, - {"Resize", "resample2d"}, - {"Shape", "slice"}, - {"Sigmoid", "sigmoid"}, - {"Softplus", "softplus"}, - {"Softsign", "softsign"}, - {"Sin", "sin"}, - {"Slice", "slice"}, - {"Softmax", "softmax"}, - {"Split", "split"}, - {"Sqrt", "sqrt"}, - {"Squeeze", "squeeze"}, - {"Sub", "sub"}, - {"Tan", "tan"}, - {"Tanh", "tanh"}, - {"Transpose", "transpose"}, - {"Unsqueeze", "unsqueeze"}, - {"Where", "elementwiseIf"}, +static const InlinedHashMap op_map = { + {"Abs", {"abs", true}}, + {"Add", {"add", true}}, + {"ArgMax", {"argMax", false}}, + {"ArgMin", {"argMin", false}}, + {"AveragePool", {"averagePool2d", true}}, + {"BatchNormalization", {"meanVarianceNormalization", false}}, + {"Cast", {"cast", false}}, + {"Ceil", {"ceil", true}}, + {"Clip", {"clamp", true}}, + {"Concat", {"concat", true}}, + {"Conv", {"conv2d", true}}, + {"ConvTranspose", {"convTranspose2d", true}}, + {"Cos", {"cos", false}}, + {"Div", {"div", true}}, + {"Elu", {"elu", true}}, + {"Equal", {"equal", false}}, + {"Erf", {"erf", false}}, + {"Exp", {"exp", false}}, + {"Expand", {"expand", false}}, + {"Flatten", {"flattenTo2d", false}}, + {"Floor", {"floor", true}}, + {"Gather", {"gather", false}}, + {"Gemm", {"gemm", true}}, + {"GlobalAveragePool", {"averagePool2d", true}}, + {"GlobalMaxPool", {"maxPool2d", true}}, + {"GlobalLpPool", {"l2Pool2d", false}}, + {"Greater", {"greater", false}}, + {"GreaterOrEqual", {"greaterOrEqual", false}}, + {"GroupNormalization", {"meanVarianceNormalization", false}}, + {"HardSigmoid", {"hardSigmoid", false}}, + {"HardSwish", {"hardSwish", true}}, + {"Identity", {"identity", false}}, + {"InstanceNormalization", {"meanVarianceNormalization", false}}, + {"LayerNormalization", {"meanVarianceNormalization", false}}, + {"LeakyRelu", {"leakyRelu", true}}, + {"Less", {"lesser", false}}, + {"LessOrEqual", {"lesserOrEqual", false}}, + {"Log", {"log", false}}, + {"LpPool", {"l2Pool2d", false}}, + {"MatMul", {"matmul", false}}, + {"Max", {"max", true}}, + {"MaxPool", {"maxPool2d", true}}, + {"Min", {"min", true}}, + {"Mul", {"mul", true}}, + {"Neg", {"neg", true}}, + {"Not", {"logicalNot", false}}, + {"Pad", {"pad", true}}, + {"Pow", {"pow", true}}, + {"PRelu", {"prelu", true}}, + {"Reciprocal", {"reciprocal", false}}, + {"ReduceL1", {"reduceL1", false}}, + {"ReduceL2", {"reduceL2", false}}, + {"ReduceLogSum", {"reduceLogSum", false}}, + {"ReduceLogSumExp", {"reduceLogSumExp", false}}, + {"ReduceMax", {"reduceMax", false}}, + {"ReduceMean", {"reduceMean", true}}, + {"ReduceMin", {"reduceMin", false}}, + {"ReduceProd", {"reduceProduct", false}}, + {"ReduceSum", {"reduceSum", true}}, + {"ReduceSumSquare", {"reduceSumSquare", false}}, + {"Relu", {"relu", true}}, + {"Reshape", {"reshape", true}}, + {"Resize", {"resample2d", true}}, + {"Shape", {"slice", true}}, + {"Sigmoid", {"sigmoid", true}}, + {"Softplus", {"softplus", false}}, + {"Softsign", {"softsign", false}}, + {"Sin", {"sin", false}}, + {"Slice", {"slice", true}}, + {"Softmax", {"softmax", true}}, + {"Split", {"split", true}}, + {"Sqrt", {"sqrt", false}}, + {"Squeeze", {"squeeze", false}}, + {"Sub", {"sub", true}}, + {"Tan", {"tan", false}}, + {"Tanh", {"tanh", true}}, + {"Transpose", {"transpose", true}}, + {"Unsqueeze", {"unsqueeze", false}}, + {"Where", {"elementwiseIf", false}}, }; -inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_) { - return op_map.find(op_type) != op_map.end() && wnn_builder_[op_map.find(op_type)->second].as(); +inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_, + const WebnnDeviceType device_type) { + // Returns false if the op_type is not listed in the op_map. + if (op_map.find(op_type) == op_map.end()) { + return false; + } + // Returns false if the WebNN op has not been implemented in MLGraphBuilder in current browser. + if (!wnn_builder_[op_map.find(op_type)->second.opName].as()) { + return false; + } + // The current WebNN CPU (XNNPack) backend supports a limited op list, and we'd rather + // fall back early to the ORT CPU EP rather than fail in the WebNN "cpu" deviceType. + // This is a workaround because the op may be included in MLGraphBuilder for DirectML + // backend but without XNNPack implementation in Chromium. + if (!op_map.find(op_type)->second.isCpuSupported) { + return false; + } + + return true; } constexpr std::array supported_cpu_data_types = {