[WebNN EP] Add quantize Ops (#18011)
### Description
Add four quantize ops: MatMulInteger, ConvInteger, DynamicQuantizeLinear,
and DequantizeLinear.
Add the data types TensorProto_DataType_INT8 and TensorProto_DataType_UINT8.

### Motivation and Context
Enable the WebNN EP to run quantized models.
zesongw authored Jan 12, 2024
1 parent acba63c commit e1db44b
Showing 13 changed files with 232 additions and 4 deletions.
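
For context on the semantics these ops bring in: DequantizeLinear maps a quantized value back to real numbers as y = (x - zero_point) * scale, and the other three ops build on the same affine scheme. A minimal scalar sketch of that math in C++ (illustrative only, not code from this commit):

```cpp
#include <cstdint>

// Reference semantics of ONNX DequantizeLinear for one uint8 element.
// The WebNN "dequantizeLinear" op wired up below is expected to match this.
inline float DequantizeLinearScalar(uint8_t x, float scale, uint8_t zero_point) {
  return static_cast<float>(static_cast<int32_t>(x) -
                            static_cast<int32_t>(zero_point)) * scale;
}
```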
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.cc
@@ -166,6 +166,8 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) {
// TODO: Remove legacy "type" once all browsers implement the new "dataType".
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
desc.set("type", emscripten::val("uint8"));
desc.set("dataType", emscripten::val("uint8"));
return true;
10 changes: 9 additions & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
@@ -101,6 +101,8 @@ inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::va
}
switch (tensor.data_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
scalar = emscripten::val{*reinterpret_cast<uint8_t*>(unpacked_tensor.data())};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
@@ -148,9 +150,12 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"Clip", {"clamp", true}},
{"Concat", {"concat", true}},
{"Conv", {"conv2d", true}},
{"ConvInteger", {"conv2dInteger", false}},
{"ConvTranspose", {"convTranspose2d", true}},
{"Cos", {"cos", false}},
{"Div", {"div", true}},
{"DequantizeLinear", {"dequantizeLinear", false}},
{"DynamicQuantizeLinear", {"dynamicQuantizeLinear", false}},
{"Elu", {"elu", true}},
{"Equal", {"equal", false}},
{"Erf", {"erf", false}},
@@ -176,6 +181,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"Log", {"log", false}},
{"LpPool", {"l2Pool2d", false}},
{"MatMul", {"matmul", false}},
{"MatMulInteger", {"matmulInteger", false}},
{"Max", {"max", true}},
{"MaxPool", {"maxPool2d", true}},
{"Min", {"min", true}},
@@ -242,8 +248,10 @@ constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 1> supported_cpu_data
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
};

- constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 7> supported_gpu_data_types = {
+ constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 9> supported_gpu_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
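
A note on the op_map entries above: each value pairs a WebNN builder method name with a flag for WebNN CPU-backend support. The new quantized ops are registered with false, matching the fact that INT8/UINT8 are added only to supported_gpu_data_types. A sketch of the presumed value type (the actual WebnnOpInfo definition lives elsewhere in helper.h):

```cpp
#include <string>

// Presumed shape of the op_map value type, inferred from its usage above.
struct WebnnOpInfo {
  std::string opName;   // WebNN builder method, e.g. "matmulInteger"
  bool isCpuSupported;  // false: only the WebNN "gpu" device type runs this op
};
```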
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
@@ -39,6 +39,8 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
std::string operand_type;
switch (to_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
operand_type = "uint8";
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
29 changes: 26 additions & 3 deletions onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc
@@ -183,6 +183,11 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder,

size_t element_size{0};
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
element_size = sizeof(uint8_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
element_size = sizeof(uint16_t);
break;
@@ -257,7 +262,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
const auto& weight_name = input_defs[1]->Name();
emscripten::val options = emscripten::val::object();
ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger));
- if (op_type == "Conv") {
+ if (op_type == "Conv" || op_type == "ConvInteger") {
int groups = options["groups"].as<int>();
std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
@@ -271,9 +276,26 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
options.set("filterLayout", emscripten::val("ihwo"));
}
}
- emscripten::val filter = model_builder.GetOperand(input_defs[1]->Name());
+ emscripten::val filter = model_builder.GetOperand(weight_name);
if (op_type == "Conv") {
output = model_builder.GetBuilder().call<emscripten::val>("conv2d", input, filter, options);
} else {
emscripten::val x_zero_point = emscripten::val::null();
emscripten::val w_zero_point = emscripten::val::null();
if (input_defs.size() >= 3) {
x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
} else {
x_zero_point = model_builder.GetZeroConstant("uint8");
}
if (input_defs.size() >= 4) {
w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name());
} else {
w_zero_point = model_builder.GetZeroConstant("uint8");
}
output = model_builder.GetBuilder().call<emscripten::val>("conv2dInteger",
input, x_zero_point, filter, w_zero_point, options);
}

output = model_builder.GetBuilder().call<emscripten::val>("conv2d", input, filter, options);
} else {
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("inputLayout", emscripten::val("nhwc"));
@@ -341,6 +363,7 @@ void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_
static std::vector<std::string> op_types =
{
"Conv",
"ConvInteger",
"ConvTranspose",
};

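
ConvInteger (and MatMulInteger below) accumulate products of zero-point-adjusted uint8 values into int32. A self-contained reference for that accumulation over equal-length inputs (illustrative, not part of the commit):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Core accumulation shared by ConvInteger and MatMulInteger:
// acc += (x - x_zero_point) * (w - w_zero_point), computed in int32.
int32_t IntegerDotProduct(const std::vector<uint8_t>& x, uint8_t x_zero_point,
                          const std::vector<uint8_t>& w, uint8_t w_zero_point) {
  int32_t acc = 0;
  for (size_t i = 0; i < x.size(); ++i) {
    acc += (static_cast<int32_t>(x[i]) - x_zero_point) *
           (static_cast<int32_t>(w[i]) - w_zero_point);
  }
  return acc;
}
```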
70 changes: 70 additions & 0 deletions (new file: DequantizeLinear op builder)
@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/safeint.h"
#include "core/optimizer/initializer.h"
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "core/providers/webnn/builders/impl/base_op_builder.h"

namespace onnxruntime {
namespace webnn {

class DequantizeLinearOpBuilder : public BaseOpBuilder {
// Add operator related.
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
};

Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name());
emscripten::val zero_point = emscripten::val::null();
if (input_defs.size() == 3) {
zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
} else {
zero_point = model_builder.GetZeroConstant("uint8");
}
emscripten::val output;
std::vector<int64_t> input_shape;
std::vector<int64_t> scale_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape");
NodeAttrHelper helper(node);
int32_t axis = helper.Get("axis", 1);
// The axis attribute is only meaningful when the input has more than one dimension.
if (input_shape.size() > 1) {
axis = static_cast<int32_t>(HandleNegativeAxis(axis, input_shape.size()));
}
// Insert ones before and after the axis dimension so the 1-D scale tensor broadcasts against the input.
if (1 == scale_shape.size() && 1 < input_shape.size()) {
std::vector<int32_t> target_shape{static_cast<int>(input_shape[axis])};
target_shape.insert(target_shape.begin(), axis, 1);
target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1);
scale = model_builder.GetBuilder().call<emscripten::val>("reshape", scale, emscripten::val::array(target_shape));
zero_point = model_builder.GetBuilder().call<emscripten::val>("reshape",
zero_point, emscripten::val::array(target_shape));
}
output = model_builder.GetBuilder().call<emscripten::val>("dequantizeLinear", input, scale, zero_point);

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));

return Status::OK();
}

void CreateDequantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<DequantizeLinearOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace webnn
} // namespace onnxruntime
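
To illustrate the reshape above: with a rank-4 NCHW input, axis = 1, and a 1-D scale of length C, the computed target_shape is {1, C, 1, 1}, which lets the scale and zero point broadcast across the input. The shape computation in isolation (a hypothetical standalone helper mirroring the code above):

```cpp
#include <cstdint>
#include <vector>

// Mirrors DequantizeLinearOpBuilder's broadcast shape: `axis` leading ones,
// the axis dimension itself, then trailing ones up to the input rank.
std::vector<int32_t> BuildBroadcastShape(const std::vector<int64_t>& input_shape,
                                         int32_t axis) {
  std::vector<int32_t> target_shape{static_cast<int32_t>(input_shape[axis])};
  target_shape.insert(target_shape.begin(), axis, 1);
  target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1);
  return target_shape;  // e.g. rank 4, axis 1 -> {1, C, 1, 1}
}
```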
49 changes: 49 additions & 0 deletions (new file: DynamicQuantizeLinear op builder)
@@ -0,0 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/safeint.h"
#include "core/optimizer/initializer.h"
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "core/providers/webnn/builders/impl/base_op_builder.h"

namespace onnxruntime {
namespace webnn {

class DynamicQuantizeLinearOpBuilder : public BaseOpBuilder {
// Add operator related.
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
};

Status DynamicQuantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val output_array;
std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
emscripten::val options = emscripten::val::object();

output_array = model_builder.GetBuilder().call<emscripten::val>("dynamicQuantizeLinear", input);

for (size_t i = 0, count = output_array["length"].as<size_t>(); i < count; i++) {
model_builder.AddOperand(node.OutputDefs()[i]->Name(), std::move(output_array[i]));
}
return Status::OK();
}

void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<DynamicQuantizeLinearOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace webnn
} // namespace onnxruntime
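
DynamicQuantizeLinear computes the scale and zero point from the input's observed range (extended to include zero) and produces three outputs: the quantized tensor, the scale, and the zero point; the loop above maps each element of the WebNN output array to the matching ONNX output. A reference for the parameter computation, assuming a non-degenerate range (per the ONNX spec, with the uint8 target range [0, 255]):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// ONNX DynamicQuantizeLinear parameter computation for uint8:
// the [min, max] range is first extended to contain 0.
void ComputeDynamicQuantParams(float x_min, float x_max,
                               float& scale, uint8_t& zero_point) {
  x_min = std::min(x_min, 0.0f);
  x_max = std::max(x_max, 0.0f);
  scale = (x_max - x_min) / 255.0f;  // assumes x_max > x_min, so scale > 0
  zero_point = static_cast<uint8_t>(
      std::clamp(std::round(-x_min / scale), 0.0f, 255.0f));
}
```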
15 changes: 15 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc
@@ -39,6 +39,20 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
emscripten::val output = emscripten::val::object();
if (op_type == "MatMul") {
output = model_builder.GetBuilder().call<emscripten::val>("matmul", a, b);
} else if (op_type == "MatMulInteger") {
emscripten::val a_zero_point = emscripten::val::null();
emscripten::val b_zero_point = emscripten::val::null();
if (input_defs.size() >= 3) {
a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
} else {
a_zero_point = model_builder.GetZeroConstant("uint8");
}
if (input_defs.size() >= 4) {
b_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name());
} else {
b_zero_point = model_builder.GetZeroConstant("uint8");
}
output = model_builder.GetBuilder().call<emscripten::val>("matmulInteger", a, a_zero_point, b, b_zero_point);
} else { // Gemm
emscripten::val options = emscripten::val::object();
NodeAttrHelper helper(node);
@@ -149,6 +163,7 @@ void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_
{
"Gemm",
"MatMul",
"MatMulInteger",
};

op_registrations.builders.push_back(std::make_unique<GemmOpBuilder>());
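
The optional-zero-point handling here duplicates the ConvInteger logic in conv_op_builder.cc: ONNX treats trailing zero-point inputs as optional with a default of 0, so a scalar zero constant is substituted when they are absent. A possible shared helper (hypothetical; not part of this change):

```cpp
// Hypothetical helper that could replace the duplicated blocks in
// ConvOpBuilder and GemmOpBuilder: fetch the zero-point operand at
// `input_index` if the node provides it, else fall back to a cached zero.
emscripten::val GetZeroPointOrDefault(ModelBuilder& model_builder,
                                      const Node& node, size_t input_index) {
  const auto& input_defs = node.InputDefs();
  if (input_defs.size() > input_index) {
    return model_builder.GetOperand(input_defs[input_index]->Name());
  }
  return model_builder.GetZeroConstant("uint8");
}
```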
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/webnn/builders/model.cc
@@ -33,6 +33,8 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
emscripten::val view = emscripten::val::undefined();
switch (tensor.tensor_info.data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint8_t*>(tensor.buffer))};
break;
@@ -88,6 +90,8 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
emscripten::val view = emscripten::val::undefined();
switch (tensor.tensor_info.data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint8_t*>(tensor.buffer))};
break;
@@ -164,6 +168,8 @@ void Model::AllocateInputOutputBuffers() {
const auto data_type = input_info.data_type;
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
wnn_inputs_.set(input, emscripten::val::global("Uint8Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
@@ -195,6 +201,8 @@
const auto data_type = output_info.data_type;
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
wnn_outputs_.set(output, emscripten::val::global("Uint8Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
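
Both INT8 and UINT8 tensors are exposed to WebNN through uint8 views and Uint8Array buffers; this is lossless because only the byte layout matters when the buffer crosses the Wasm/JS boundary. A small standalone check of that reinterpretation (illustrative):

```cpp
#include <cstdint>
#include <cstring>

// Two's-complement int8 bytes survive a uint8 view unchanged: -1 <-> 0xFF.
int main() {
  int8_t s = -1;
  uint8_t u = 0;
  std::memcpy(&u, &s, sizeof(u));
  return u == 0xFF ? 0 : 1;  // exits 0: the byte pattern is identical
}
```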
39 changes: 39 additions & 0 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -14,6 +14,8 @@
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"

#include <utility>

namespace onnxruntime {
namespace webnn {

@@ -158,6 +160,9 @@ Status ModelBuilder::RegisterInitializers() {
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
desc.set("type", emscripten::val("uint8"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint8_t*>(tensor_ptr))};
break;
@@ -313,6 +318,8 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer(
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t),
reinterpret_cast<const uint8_t*>(dest))};
break;
@@ -439,6 +446,38 @@ void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& op
wnn_operands_.insert(std::make_pair(name, operand));
}

// Get the zero scalar constant.
// Workaround for the builder.constant(value, type) method, which has not been implemented yet.
// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-value-type
// Note: the spec is discussing whether builder.constant(value, type) should be dropped at
// https://github.com/webmachinelearning/webnn/issues/475. Update this once the spec decision lands.
const emscripten::val& ModelBuilder::GetZeroConstant(const std::string& data_type) {
std::string name = "webnn_zero_constant_" + data_type;
// If the operand does not exist, create it.
if (wnn_operands_.find(name) == wnn_operands_.end()) {
emscripten::val desc = emscripten::val::object();
emscripten::val dims = emscripten::val::array();
desc.set("dimensions", dims);
emscripten::val zero_buffer = emscripten::val::undefined();
if (data_type == "uint8") {
if (!SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_UINT8)) {
ORT_THROW("Unsupported data type: " + data_type);
}
zero_buffer = emscripten::val::global("Uint8Array").new_(1);
} else if (data_type == "float32") {
if (!SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
ORT_THROW("Unsupported data type: " + data_type);
}
zero_buffer = emscripten::val::global("Float32Array").new_(1);
} else {
ORT_THROW("Unsupported data type: " + data_type);
}
emscripten::val zero_constant = wnn_builder_.call<emscripten::val>("constant", desc, zero_buffer);
wnn_operands_.insert(std::make_pair(name, zero_constant));
}
return wnn_operands_.at(name);
}

void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) {
skipped_initializers_.insert(tensor_name);
}
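
GetZeroConstant caches one scalar constant per data type under the key "webnn_zero_constant_<type>", so a graph with many integer ops allocates at most one zero operand per type. Usage as seen from the op builders (illustrative fragment):

```cpp
// Both calls resolve to the same cached operand; only the first call
// actually builds a WebNN constant.
const emscripten::val& zp_a = model_builder.GetZeroConstant("uint8");
const emscripten::val& zp_b = model_builder.GetZeroConstant("uint8");  // cache hit
```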
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webnn/builders/model_builder.h
@@ -36,6 +36,7 @@ class ModelBuilder {
const emscripten::val& GetContext() const { return wnn_context_; }
const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); }
void AddOperand(const std::string& name, const emscripten::val& operand);
const emscripten::val& GetZeroConstant(const std::string& data_type);
// Use the buffers to persist WebNN allocated data like transposed weight.
// It ensures the validity during inference session.
std::vector<std::unique_ptr<uint8_t[]>> mem_persist_buffers_;
(Diff truncated: the remaining changed files are not shown.)
