diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 774df067fe347..38266f566e6e1 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -142,5 +142,43 @@ bool IsValidMultidirectionalBroadcast(std::vector& shape_a, return true; } +bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) { + // WebNN changed the name of the MLOperandDescriptor's data type from "type" to "dataType", + // use a duplicate entry temporarily to workaround this API breaking issue. + // TODO: Remove legacy "type" once all browsers implement the new "dataType". + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + desc.set("type", emscripten::val("uint8")); + desc.set("dataType", emscripten::val("uint8")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + desc.set("type", emscripten::val("float16")); + desc.set("dataType", emscripten::val("float16")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + desc.set("type", emscripten::val("float32")); + desc.set("dataType", emscripten::val("float32")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + desc.set("type", emscripten::val("int32")); + desc.set("dataType", emscripten::val("int32")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + desc.set("type", emscripten::val("int64")); + desc.set("dataType", emscripten::val("int64")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + desc.set("type", emscripten::val("uint32")); + desc.set("dataType", emscripten::val("uint32")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + desc.set("type", emscripten::val("uint64")); + desc.set("dataType", emscripten::val("uint64")); + return true; + default: + return false; + } +} + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index cdad9b22a8ab8..46c456556e016 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -231,5 +231,8 @@ bool IsSupportedDataType(const int32_t data_type, const WebnnDeviceType device_t bool IsValidMultidirectionalBroadcast(std::vector& shape_a, std::vector& shape_b, const logging::Logger& logger); + +bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type); + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc index 04e6d2b548aba..12c2cf6dd0a62 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc @@ -34,7 +34,7 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto rank = static_cast(input_shape.size()); emscripten::val desc = emscripten::val::object(); - desc.set("type", emscripten::val("int64")); + ORT_RETURN_IF_NOT(SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_INT64), "Unsupported data type"); emscripten::val dims = emscripten::val::array(); dims.call("push", rank); desc.set("dimensions", dims); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 2eae8cebbbd66..0ac9fb7ff380d 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -122,6 +122,7 @@ Status ModelBuilder::RegisterInitializers() { auto data_type = tensor.data_type(); emscripten::val operand = emscripten::val::object(); if (IsSupportedDataType(data_type, wnn_device_type_)) { + ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); unpacked_tensors_.push_back({}); std::vector& unpacked_tensor = unpacked_tensors_.back(); ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); @@ -129,37 +130,30 @@ Status ModelBuilder::RegisterInitializers() { emscripten::val view = emscripten::val::undefined(); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - desc.set("type", emscripten::val("uint8")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - desc.set("type", emscripten::val("float16")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - desc.set("type", emscripten::val("float32")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: - desc.set("type", emscripten::val("int32")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; case ONNX_NAMESPACE::TensorProto_DataType_INT64: - desc.set("type", emscripten::val("int64")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - desc.set("type", emscripten::val("uint32")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - desc.set("type", emscripten::val("uint64")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; @@ -238,35 +232,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i } data_type = type_proto->tensor_type().elem_type(); - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - desc.set("type", emscripten::val("uint8")); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - desc.set("type", emscripten::val("float16")); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - desc.set("type", emscripten::val("float32")); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - desc.set("type", emscripten::val("int32")); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - desc.set("type", emscripten::val("int64")); - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - desc.set("type", emscripten::val("uint32")); - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - desc.set("type", emscripten::val("uint64")); - break; - default: { - // TODO: support other type. - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "The ", input_output_type, " of graph doesn't have valid type, name: ", name, - " type: ", type_proto->tensor_type().elem_type()); - } - } + ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); } if (is_input) { @@ -316,41 +282,35 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer( memcpy(dest, buffer, size); emscripten::val view = emscripten::val::undefined(); emscripten::val desc = emscripten::val::object(); + ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("uint8")); break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint16_t), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("float16")); break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(float), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("float32")); break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int32_t), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("int32")); break; case ONNX_NAMESPACE::TensorProto_DataType_INT64: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int64_t), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("int64")); break; case ONNX_NAMESPACE::TensorProto_DataType_UINT32: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint32_t), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("uint32")); break; case ONNX_NAMESPACE::TensorProto_DataType_UINT64: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint64_t), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("uint64")); break; default: break;