Skip to content

Commit

Permalink
[WebNN EP] Add support for Op MatMul of WebNN CPU backend (#19413)
Browse files Browse the repository at this point in the history
Enable MatMul support for WebNN CPU backend to support more models.
  • Loading branch information
zesongw authored Feb 27, 2024
1 parent 1c468a0 commit 2e4d1b8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"LessOrEqual", {"lesserOrEqual", false}},
{"Log", {"log", false}},
{"LpPool", {"l2Pool2d", false}},
{"MatMul", {"matmul", false}},
{"MatMul", {"matmul", true}},
{"MatMulInteger", {"matmulInteger", false}},
{"Max", {"max", true}},
{"MaxPool", {"maxPool2d", true}},
Expand Down
14 changes: 12 additions & 2 deletions onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class GemmOpBuilder : public BaseOpBuilder {

// Add operator related.
Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& /* logger */) const {
const logging::Logger& logger) const {
const auto& op_type = node.OpType();
const auto& input_defs = node.InputDefs();
const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C
Expand All @@ -38,7 +38,17 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name());
emscripten::val output = emscripten::val::object();
if (op_type == "MatMul") {
output = model_builder.GetBuilder().call<emscripten::val>("matmul", a, b);
std::vector<int64_t> a_shape;
if (!GetShape(*input_defs[a_idx], a_shape, logger)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Can not get shape of A.");
}
// The inputs of MatMul must be at least 3D for WebNN CPU backend. Use GEMM for 2D case.
// TODO: Remove this workaround when it is fixed in Chromium.
if (model_builder.GetWebnnDeviceType() == WebnnDeviceType::CPU && a_shape.size() == 2) {
output = model_builder.GetBuilder().call<emscripten::val>("gemm", a, b);
} else {
output = model_builder.GetBuilder().call<emscripten::val>("matmul", a, b);
}
} else if (op_type == "MatMulInteger") {
emscripten::val a_zero_point = emscripten::val::null();
emscripten::val b_zero_point = emscripten::val::null();
Expand Down

0 comments on commit 2e4d1b8

Please sign in to comment.