diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts
index c0e1582c17de5..8feb8d7205fa1 100644
--- a/js/common/lib/tensor-impl.ts
+++ b/js/common/lib/tensor-impl.ts
@@ -179,7 +179,9 @@ export class Tensor implements TensorInterface {
         type !== 'uint64' &&
         type !== 'int8' &&
         type !== 'uint8' &&
-        type !== 'bool'
+        type !== 'bool' &&
+        type !== 'uint4' &&
+        type !== 'int4'
       ) {
         throw new TypeError(`unsupported type "${type}" to create tensor from MLTensor`);
       }
diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts
index 17e2f4d37c91f..af918705b97e3 100644
--- a/js/common/lib/tensor.ts
+++ b/js/common/lib/tensor.ts
@@ -167,7 +167,9 @@ export declare namespace Tensor {
     | 'uint32'
     | 'int64'
     | 'uint64'
-    | 'bool';
+    | 'bool'
+    | 'uint4'
+    | 'int4';
 
   /**
    * represent where the tensor data is stored
diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts
index 37eb0e0edc67c..47304fdc64ae4 100644
--- a/js/web/lib/wasm/jsep/backend-webnn.ts
+++ b/js/web/lib/wasm/jsep/backend-webnn.ts
@@ -25,6 +25,8 @@ const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>([
   [DataType.uint32, 'uint32'],
   [DataType.int64, 'int64'],
   [DataType.uint64, 'uint64'],
+  [DataType.int4, 'int4'],
+  [DataType.uint4, 'uint4'],
   [DataType.int8, 'int8'],
   [DataType.uint8, 'uint8'],
   [DataType.bool, 'uint8'],
@@ -214,6 +216,8 @@ export class WebNNBackend {
       case 'int8':
         bufferView = new Int8Array(buffer);
         break;
+      case 'int4':
+      case 'uint4':
       case 'uint8':
         bufferView = new Uint8Array(buffer);
         break;
diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts
index a2d4e9af23e44..2620168738dac 100644
--- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts
+++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts
@@ -28,7 +28,7 @@ interface MLContext {
 }
 interface MLGraph {}
 type MLInputOperandLayout = 'nchw'|'nhwc';
-type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8';
+type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8'|'int4'|'uint4';
 interface MLOperandDescriptor {
   dataType: MLOperandDataType;
   shape?: readonly number[];
diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts
index ad2ff62587252..54071866be5c3 100644
--- a/js/web/lib/wasm/wasm-common.ts
+++ b/js/web/lib/wasm/wasm-common.ts
@@ -252,7 +252,9 @@ export const isMLTensorSupportedType = (type: Tensor.Type): type is Tensor.MLTen
   type === 'uint64' ||
   type === 'int8' ||
   type === 'uint8' ||
-  type === 'bool';
+  type === 'bool' ||
+  type === 'uint4' ||
+  type === 'int4';
 
 /**
  * Map string data location to integer value
diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc
index dc488f0409418..4b39e03ffc788 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.cc
+++ b/onnxruntime/core/providers/webnn/builders/helper.cc
@@ -229,6 +229,12 @@ bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
 
 bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) {
   switch (data_type) {
+    case ONNX_NAMESPACE::TensorProto_DataType_INT4:
+      desc.set("dataType", emscripten::val("int4"));
+      return true;
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
+      desc.set("dataType", emscripten::val("uint4"));
+      return true;
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       desc.set("dataType", emscripten::val("uint8"));
diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h
index 6d2e7533750be..aa3613551d8e1 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.h
+++ b/onnxruntime/core/providers/webnn/builders/helper.h
@@ -303,6 +303,8 @@ inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_typ
 }
 
 static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> onnx_to_webnn_data_type_map = {
+    {ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"},
+    {ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"},
     {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
     {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"},
     {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"},
diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
index 3c4fc822f3d01..70ebe18c85b86 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
@@ -38,6 +38,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
   std::string operand_type;
   switch (to_type) {
+    case ONNX_NAMESPACE::TensorProto_DataType_INT4:
+      operand_type = "int4";
+      break;
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
+      operand_type = "uint4";
+      break;
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       operand_type = "uint8";
diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc
index fcfdb146bff34..231b65a4d1894 100644
--- a/onnxruntime/core/providers/webnn/builders/model.cc
+++ b/onnxruntime/core/providers/webnn/builders/model.cc
@@ -42,6 +42,8 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap<std::string,
     switch (tensor.tensor_info.data_type) {
       case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
+      case ONNX_NAMESPACE::TensorProto_DataType_INT4:
+      case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
       case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
         view = emscripten::val{emscripten::typed_memory_view(num_elements,
                                                              static_cast<const uint8_t*>(tensor.buffer))};
@@ -93,6 +95,8 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap<std::string,
     switch (tensor.tensor_info.data_type) {
       case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
+      case ONNX_NAMESPACE::TensorProto_DataType_INT4:
+      case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
       case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
         view = emscripten::val{emscripten::typed_memory_view(num_elements,
                                                              static_cast<const uint8_t*>(tensor.buffer))};
@@ -210,6 +214,8 @@ void Model::AllocateInputOutputBuffers() {
     const auto data_type = input_info.data_type;
     switch (data_type) {
       case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
+      case ONNX_NAMESPACE::TensorProto_DataType_INT4:
+      case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
       case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
         wnn_inputs_.set(input, emscripten::val::global("Uint8Array").new_(num_elements));
         break;
@@ -245,6 +251,8 @@ void Model::AllocateInputOutputBuffers() {
     const auto data_type = output_info.data_type;
     switch (data_type) {
      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
+      case ONNX_NAMESPACE::TensorProto_DataType_INT4:
+      case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
       case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
         wnn_outputs_.set(output, emscripten::val::global("Uint8Array").new_(num_elements));
         break;
diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc
index ccf6c7911638b..84f8cc4b14665 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -137,8 +137,16 @@ Status ModelBuilder::RegisterInitializers() {
       ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
       tensor_ptr = reinterpret_cast<std::byte*>(unpacked_tensor.data());
     }
+    if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4 ||
+        data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
+      // For WebNN, int4 and uint4 tensors are stored in Uint8Array,
+      // so we need to adjust the number of elements.
+      num_elements = (static_cast<size_t>(num_elements) + 1) / 2;
+    }
     switch (data_type) {
       case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
+      case ONNX_NAMESPACE::TensorProto_DataType_INT4:
+      case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
       case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
         view = emscripten::val{emscripten::typed_memory_view(num_elements,
                                                              reinterpret_cast<const uint8_t*>(tensor_ptr))};
@@ -392,6 +400,8 @@ const emscripten::val& ModelBuilder::GetZeroConstant(const int32_t& data_type) {
   switch (data_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
+    case ONNX_NAMESPACE::TensorProto_DataType_INT4:
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       zero_buffer = emscripten::val::global("Uint8Array").new_(1);
       break;
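A note on the element-count adjustment in `RegisterInitializers` above: ONNX packs two 4-bit elements per byte, so a tensor with `n` logical elements occupies `(n + 1) / 2` bytes, and that byte count is what the `Uint8Array` view must cover. A minimal TypeScript sketch of the same arithmetic (the helper name is hypothetical, not part of this change):

```ts
// Byte length of a packed int4/uint4 tensor: two 4-bit elements share one byte,
// so an odd element count still rounds up to a whole byte.
const packedInt4ByteLength = (numElements: number): number => Math.ceil(numElements / 2);

// Example: a [3, 5] uint4 tensor has 15 logical elements but only 8 bytes of storage.
console.log(packedInt4ByteLength(3 * 5)); // 8
console.log(packedInt4ByteLength(16));    // 8
```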
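Relatedly, because the JS side reuses `Uint8Array` for int4/uint4 buffers (see `backend-webnn.ts` and `model.cc` above), any consumer that needs individual values has to unpack nibbles itself. A sketch of that unpacking, assuming ONNX's packing order (first element in the low nibble) and two's-complement sign extension for the signed case; this helper is illustrative only:

```ts
// Unpack one byte of packed int4 data into its two signed elements.
const unpackInt4Byte = (byte: number): [number, number] => {
  // Sign-extend a 4-bit two's-complement value into the range [-8, 7].
  const signExtend = (nibble: number): number => (nibble > 7 ? nibble - 16 : nibble);
  const first = byte & 0x0f;         // element 2k lives in the low nibble
  const second = (byte >> 4) & 0x0f; // element 2k + 1 lives in the high nibble
  return [signExtend(first), signExtend(second)];
};

console.log(unpackInt4Byte(0xf1)); // [1, -1]
```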