From 2e4d1b8f1ba928fe5879077eced9cd5191760cfb Mon Sep 17 00:00:00 2001 From: zesongw Date: Wed, 28 Feb 2024 02:01:12 +0800 Subject: [PATCH] [WebNN EP] Add support for Op MatMul of WebNN CPU backend (#19413) Enable MatMul support for WebNN CPU backend to support more models. --- onnxruntime/core/providers/webnn/builders/helper.h | 2 +- .../webnn/builders/impl/gemm_op_builder.cc | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index d94729e60d029..d7892fe02c1ba 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -195,7 +195,7 @@ static const InlinedHashMap op_map = { {"LessOrEqual", {"lesserOrEqual", false}}, {"Log", {"log", false}}, {"LpPool", {"l2Pool2d", false}}, - {"MatMul", {"matmul", false}}, + {"MatMul", {"matmul", true}}, {"MatMulInteger", {"matmulInteger", false}}, {"Max", {"max", true}}, {"MaxPool", {"maxPool2d", true}}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 4bf991a1b0105..d5f84f853f7de 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -29,7 +29,7 @@ class GemmOpBuilder : public BaseOpBuilder { // Add operator related. Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /* logger */) const { + const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C @@ -38,7 +38,17 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name()); emscripten::val output = emscripten::val::object(); if (op_type == "MatMul") { - output = model_builder.GetBuilder().call("matmul", a, b); + std::vector a_shape; + if (!GetShape(*input_defs[a_idx], a_shape, logger)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Can not get shape of A."); + } + // The inputs of MatMul must be at least 3D for WebNN CPU backend. Use GEMM for 2D case. + // TODO: Remove this workaround when it is fixed in Chromium. + if (model_builder.GetWebnnDeviceType() == WebnnDeviceType::CPU && a_shape.size() == 2) { + output = model_builder.GetBuilder().call("gemm", a, b); + } else { + output = model_builder.GetBuilder().call("matmul", a, b); + } } else if (op_type == "MatMulInteger") { emscripten::val a_zero_point = emscripten::val::null(); emscripten::val b_zero_point = emscripten::val::null();