Skip to content

Commit

Permalink
[WebNN] Support int4 and uint4 data types (#22575)
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry authored Oct 26, 2024
1 parent c547306 commit 008c909
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 4 deletions.
4 changes: 3 additions & 1 deletion js/common/lib/tensor-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ export class Tensor implements TensorInterface {
type !== 'uint64' &&
type !== 'int8' &&
type !== 'uint8' &&
type !== 'bool'
type !== 'bool' &&
type !== 'uint4' &&
type !== 'int4'
) {
throw new TypeError(`unsupported type "${type}" to create tensor from MLTensor`);
}
Expand Down
4 changes: 3 additions & 1 deletion js/common/lib/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ export declare namespace Tensor {
| 'uint32'
| 'int64'
| 'uint64'
| 'bool';
| 'bool'
| 'uint4'
| 'int4';

/**
* represent where the tensor data is stored
Expand Down
4 changes: 4 additions & 0 deletions js/web/lib/wasm/jsep/backend-webnn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>([
[DataType.uint32, 'uint32'],
[DataType.int64, 'int64'],
[DataType.uint64, 'uint64'],
[DataType.int4, 'int4'],
[DataType.uint4, 'uint4'],
[DataType.int8, 'int8'],
[DataType.uint8, 'uint8'],
[DataType.bool, 'uint8'],
Expand Down Expand Up @@ -214,6 +216,8 @@ export class WebNNBackend {
case 'int8':
bufferView = new Int8Array(buffer);
break;
case 'int4':
case 'uint4':
case 'uint8':
bufferView = new Uint8Array(buffer);
break;
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webnn/webnn.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ interface MLContext {
}
interface MLGraph {}
type MLInputOperandLayout = 'nchw'|'nhwc';
type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8';
type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8'|'int4'|'uint4';
interface MLOperandDescriptor {
dataType: MLOperandDataType;
shape?: readonly number[];
Expand Down
4 changes: 3 additions & 1 deletion js/web/lib/wasm/wasm-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ export const isMLTensorSupportedType = (type: Tensor.Type): type is Tensor.MLTen
type === 'uint64' ||
type === 'int8' ||
type === 'uint8' ||
type === 'bool';
type === 'bool' ||
type === 'uint4' ||
type === 'int4';

/**
* Map string data location to integer value
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,

bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) {
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
desc.set("dataType", emscripten::val("int4"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
desc.set("dataType", emscripten::val("uint4"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
desc.set("dataType", emscripten::val("uint8"));
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_typ
}

static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> onnx_to_webnn_data_type_map = {
{ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"},
{ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
{ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
std::string operand_type;
switch (to_type) {
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
operand_type = "int4";
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
operand_type = "uint4";
break;
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
operand_type = "uint8";
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/webnn/builders/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap<std::string, Onn
emscripten::val view = emscripten::val::undefined();
switch (tensor.tensor_info.data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint8_t*>(tensor.buffer))};
Expand Down Expand Up @@ -93,6 +95,8 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap<std::string, Onn
emscripten::val view = emscripten::val::undefined();
switch (tensor.tensor_info.data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint8_t*>(tensor.buffer))};
Expand Down Expand Up @@ -210,6 +214,8 @@ void Model::AllocateInputOutputBuffers() {
const auto data_type = input_info.data_type;
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
wnn_inputs_.set(input, emscripten::val::global("Uint8Array").new_(num_elements));
break;
Expand Down Expand Up @@ -245,6 +251,8 @@ void Model::AllocateInputOutputBuffers() {
const auto data_type = output_info.data_type;
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
wnn_outputs_.set(output, emscripten::val::global("Uint8Array").new_(num_elements));
break;
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,16 @@ Status ModelBuilder::RegisterInitializers() {
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
tensor_ptr = reinterpret_cast<std::byte*>(unpacked_tensor.data());
}
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4 ||
data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
// In WebNN, int4 and uint4 tensors are packed two values per byte in a Uint8Array,
// so adjust the count from number of elements to number of packed bytes.
num_elements = (static_cast<size_t>(num_elements) + 1) / 2;
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint8_t*>(tensor_ptr))};
Expand Down Expand Up @@ -392,6 +400,8 @@ const emscripten::val& ModelBuilder::GetZeroConstant(const int32_t& data_type) {

switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
zero_buffer = emscripten::val::global("Uint8Array").new_(1);
break;
Expand Down

0 comments on commit 008c909

Please sign in to comment.