Skip to content

Commit

Permalink
[WebNN EP] Add a duplicate entry to support new "dataType"
Browse files Browse the repository at this point in the history
WebNN spec renames "type" as "dataType" at
webmachinelearning/webnn#464,
add a duplicate entry for "dataType" in order to workaround
the compatibility issue.
  • Loading branch information
Honry committed Oct 9, 2023
1 parent faf9a0f commit d566f96
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 44 deletions.
38 changes: 38 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,5 +142,43 @@ bool IsValidMultidirectionalBroadcast(std::vector<int64_t>& shape_a,
return true;
}

bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) {
// WebNN changed the name of the MLOperandDescriptor's data type from "type" to "dataType",
// use a duplicate entry temporarily to workaround this API breaking issue.
// TODO: Remove lagecy "type" once all browsers implement the new "dataType".

Check warning on line 148 in onnxruntime/core/providers/webnn/builders/helper.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/webnn/builders/helper.cc:148: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
desc.set("type", emscripten::val("uint8"));
desc.set("dataType", emscripten::val("uint8"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
desc.set("type", emscripten::val("float16"));
desc.set("dataType", emscripten::val("float16"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
desc.set("type", emscripten::val("float32"));
desc.set("dataType", emscripten::val("float32"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
desc.set("type", emscripten::val("int32"));
desc.set("dataType", emscripten::val("int32"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
desc.set("type", emscripten::val("int64"));
desc.set("dataType", emscripten::val("int64"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
desc.set("type", emscripten::val("uint32"));
desc.set("dataType", emscripten::val("uint32"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
desc.set("type", emscripten::val("uint64"));
desc.set("dataType", emscripten::val("uint64"));
return true;
default:
return false;
}
}

} // namespace webnn
} // namespace onnxruntime
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,5 +231,8 @@ bool IsSupportedDataType(const int32_t data_type, const WebnnDeviceType device_t
bool IsValidMultidirectionalBroadcast(std::vector<int64_t>& shape_a,
std::vector<int64_t>& shape_b,
const logging::Logger& logger);

bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type);

} // namespace webnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const auto rank = static_cast<int32_t>(input_shape.size());

emscripten::val desc = emscripten::val::object();
desc.set("type", emscripten::val("int64"));
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_INT64), "Unsupported data type");
emscripten::val dims = emscripten::val::array();
dims.call<void>("push", rank);
desc.set("dimensions", dims);
Expand Down
46 changes: 3 additions & 43 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,44 +122,38 @@ Status ModelBuilder::RegisterInitializers() {
auto data_type = tensor.data_type();
emscripten::val operand = emscripten::val::object();
if (IsSupportedDataType(data_type, wnn_device_type_)) {
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
unpacked_tensors_.push_back({});
std::vector<uint8_t>& unpacked_tensor = unpacked_tensors_.back();
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
auto num_elements = SafeInt<size_t>(Product(tensor.dims()));
emscripten::val view = emscripten::val::undefined();
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
desc.set("type", emscripten::val("uint8"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint8_t*>(unpacked_tensor.data()))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
desc.set("type", emscripten::val("float16"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint16_t*>(unpacked_tensor.data()))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
desc.set("type", emscripten::val("float32"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<float*>(unpacked_tensor.data()))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
desc.set("type", emscripten::val("int32"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int32_t*>(unpacked_tensor.data()))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
desc.set("type", emscripten::val("int64"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int64_t*>(unpacked_tensor.data()))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
desc.set("type", emscripten::val("uint32"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint32_t*>(unpacked_tensor.data()))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
desc.set("type", emscripten::val("uint64"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint64_t*>(unpacked_tensor.data()))};
break;
Expand Down Expand Up @@ -238,35 +232,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
}

data_type = type_proto->tensor_type().elem_type();
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
desc.set("type", emscripten::val("uint8"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
desc.set("type", emscripten::val("float16"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
desc.set("type", emscripten::val("float32"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
desc.set("type", emscripten::val("int32"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
desc.set("type", emscripten::val("int64"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
desc.set("type", emscripten::val("uint32"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
desc.set("type", emscripten::val("uint64"));
break;
default: {
// TODO: support other type.
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The ", input_output_type, " of graph doesn't have valid type, name: ", name,
" type: ", type_proto->tensor_type().elem_type());
}
}
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
}

if (is_input) {
Expand Down Expand Up @@ -316,41 +282,35 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer(
memcpy(dest, buffer, size);
emscripten::val view = emscripten::val::undefined();
emscripten::val desc = emscripten::val::object();
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t),
reinterpret_cast<const uint8_t*>(dest))};
desc.set("type", emscripten::val("uint8"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint16_t),
reinterpret_cast<const uint16_t*>(dest))};
desc.set("type", emscripten::val("float16"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(float),
reinterpret_cast<const float*>(dest))};
desc.set("type", emscripten::val("float32"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int32_t),
reinterpret_cast<const int32_t*>(dest))};
desc.set("type", emscripten::val("int32"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int64_t),
reinterpret_cast<const int64_t*>(dest))};
desc.set("type", emscripten::val("int64"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint32_t),
reinterpret_cast<const uint32_t*>(dest))};
desc.set("type", emscripten::val("uint32"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint64_t),
reinterpret_cast<const uint64_t*>(dest))};
desc.set("type", emscripten::val("uint64"));
break;
default:
break;
Expand Down

0 comments on commit d566f96

Please sign in to comment.