From e33b08ead13d45e051516927f03b86ff483d410f Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 19 Sep 2024 16:20:40 +0800 Subject: [PATCH] [WebNN EP] Use both MLOperandDescriptor.dimensions and MLOperandDescriptor.shape (#22121) The spec renames MLOperandDescriptor.dimensions to MLOperandDescriptor.shape, in order to support older Chromium versions, we will keep both in WebNN EP for a while. Fixed #22120 --- .../providers/webnn/builders/impl/dropout_op_builder.cc | 1 + .../core/providers/webnn/builders/impl/shape_op_builder.cc | 1 + onnxruntime/core/providers/webnn/builders/model_builder.cc | 7 +++++++ 3 files changed, 9 insertions(+) diff --git a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc index 469acbc7a7e18..5434194a214ac 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc @@ -63,6 +63,7 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val desc = emscripten::val::object(); desc.set("dataType", "uint8"); desc.set("dimensions", emscripten::val::array(dims)); + desc.set("shape", emscripten::val::array(dims)); const auto num_elements = narrow(Product(mask_shape)); emscripten::val ones_buffer = emscripten::val::global("Uint8Array").new_(num_elements); ones_buffer.call("fill", 1); diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc index 6b56d2c740f40..360c6588898f1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc @@ -33,6 +33,7 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val dims = emscripten::val::array(); dims.call("push", rank); desc.set("dimensions", dims); + desc.set("shape", dims); emscripten::val shape_buffer = emscripten::val::global("BigInt64Array").new_(emscripten::val::array(input_shape)); emscripten::val shape_constant = model_builder.GetBuilder().call("constant", desc, shape_buffer); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index f9f8264b234bb..f92fda8c74717 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -100,7 +100,11 @@ Status ModelBuilder::RegisterInitializers() { [](int64_t dim) -> int32_t { return SafeInt(dim); }); emscripten::val desc = emscripten::val::object(); + // TODO: @Honry, remove all MLOperandDescriptor.dimensions usage in the future. + // MLOperandDescriptor.dimensions is deprecated in WebNN API, we need to keep it + // in WebNN EP for a while to support older Chromium versions. desc.set("dimensions", emscripten::val::array(dims)); + desc.set("shape", emscripten::val::array(dims)); auto data_type = tensor.data_type(); emscripten::val operand = emscripten::val::object(); if (IsSupportedDataType(data_type, wnn_limits_["constant"]["dataTypes"])) { @@ -203,6 +207,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i emscripten::val desc = emscripten::val::object(); desc.set("dimensions", emscripten::val::array(dims)); + desc.set("shape", emscripten::val::array(dims)); int32_t data_type; { // type @@ -303,6 +308,7 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer( } desc.set("dimensions", emscripten::val::array(shape)); + desc.set("shape", emscripten::val::array(shape)); emscripten::val operand = emscripten::val::object(); // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached // buffers in JS side. Simply create a copy to fix it. @@ -361,6 +367,7 @@ const emscripten::val& ModelBuilder::GetZeroConstant(const int32_t& data_type) { emscripten::val desc = emscripten::val::object(); emscripten::val dims = emscripten::val::array(); desc.set("dimensions", dims); + desc.set("shape", dims); emscripten::val zero_buffer = emscripten::val::undefined(); if (!SetWebnnDataType(desc, data_type)) { ORT_THROW("Unsupported data type: " + std::to_string(data_type));