Skip to content

Commit

Permalink
Enable QLinearMatMul for opset21 (#22488)
Browse files Browse the repository at this point in the history
### Description
Enable QLinearMatMul for opset21
  • Loading branch information
HectorSVC authored Oct 22, 2024
1 parent 62f99d8 commit fc2be09
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 31 deletions.
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ Do not modify directly.*
|||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[7, 11]|**T** = tensor(double), tensor(float)|
|QLinearConv|*in* x:**T1**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T1**<br> *in* w:**T2**<br> *in* w_scale:**tensor(float)**<br> *in* w_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *in* B:**T4**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)<br/> **T4** = tensor(int32)|
|QLinearMatMul|*in* a:**T1**<br> *in* a_scale:**TS**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**TS**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**TS**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**<br><br>or<br><br>*in* a:**T1**<br> *in* a_scale:**tensor(float)**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**tensor(float)**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)|
|QLinearMatMul|*in* a:**T1**<br> *in* a_scale:**TS**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**TS**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**TS**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**<br><br>or<br><br>*in* a:**T1**<br> *in* a_scale:**tensor(float)**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**tensor(float)**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**|21+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)<br/> **TS** = tensor(float)|
|||[10, 20]|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**<br><br>or<br><br>*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|||[19, 20]|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)|
|||[13, 18]|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
Expand Down
20 changes: 14 additions & 6 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
QuantizeLinear);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t,
QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QLinearMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QLinearMatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20, uint8_t,
QLinearMatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20, int8_t,
QLinearMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, MatMulInteger);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger);
Expand Down Expand Up @@ -1103,6 +1105,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int16_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Int4x2, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t, QLinearMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, QLinearMatMul);
#if !defined(DISABLE_FLOAT8_TYPES)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FNUZ, DequantizeLinear);
Expand Down Expand Up @@ -1686,10 +1690,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
uint8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12,
int8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t,
QLinearMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t,
QLinearMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20,
uint8_t, QLinearMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20,
int8_t, QLinearMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t,
MatMulInteger)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t,
Expand Down Expand Up @@ -2764,6 +2768,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2,
DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t,
QLinearMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t,
QLinearMatMul)>,
#if !defined(DISABLE_FLOAT8_TYPES)
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN,
DequantizeLinear)>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

namespace onnxruntime {
// uint8_t kernel supports weight being either uint8_t or int8_t
ONNX_OPERATOR_TYPED_KERNEL_EX(
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(
QLinearMatMul,
kOnnxDomain,
10,
20,
uint8_t,
kCpuExecutionProvider,
KernelDefBuilder()
Expand All @@ -26,21 +27,45 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
.TypeConstraint("T3", DataTypeImpl::GetTensorType<uint8_t>()),
QLinearMatMul);

ONNX_OPERATOR_TYPED_KERNEL_EX(
QLinearMatMul,
kOnnxDomain,
21,
uint8_t,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("TS", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(), DataTypeImpl::GetTensorType<int8_t>()})
.TypeConstraint("T3", DataTypeImpl::GetTensorType<uint8_t>()),
QLinearMatMul);

// int8_t kernel only supports weight being int8_t
#define REGISTER_QLINEARMATMUL_INT8_KERNEL() \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
QLinearMatMul, \
kOnnxDomain, \
10, \
int8_t, \
kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>()) \
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int8_t>()), \
QLinearMatMul);

REGISTER_QLINEARMATMUL_INT8_KERNEL();
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(
QLinearMatMul,
kOnnxDomain,
10,
20,
int8_t,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>())
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int8_t>()),
QLinearMatMul);

ONNX_OPERATOR_TYPED_KERNEL_EX(
QLinearMatMul,
kOnnxDomain,
21,
int8_t,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("TS", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>())
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int8_t>()),
QLinearMatMul);

Status QLinearMatMul::Compute(OpKernelContext* ctx) const {
const auto* a = ctx->Input<Tensor>(IN_A);
Expand Down
8 changes: 7 additions & 1 deletion onnxruntime/test/onnx/TestCase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,13 @@ std::unique_ptr<std::set<BrokenTest>> GetBrokenTests(const std::string& provider
{"dequantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}},
{"dequantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}},
{"quantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}},
{"quantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}});
{"quantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}},
{"qlinearmatmul_2D_int8_float16", "fp16 type ont supported by CPU EP", {}},
{"qlinearmatmul_2D_int8_float32", "result diff", {}},
{"qlinearmatmul_2D_uint8_float16", "fp16 type ont supported by CPU EP", {}},
{"qlinearmatmul_3D_int8_float16", "fp16 type ont supported by CPU EP", {}},
{"qlinearmatmul_3D_int8_float32", "result diff", {}},
{"qlinearmatmul_3D_uint8_float16", "fp16 type ont supported by CPU EP", {}}});

// Some EPs may fail to pass some specific testcases.
// For example TenosrRT EP may fail on FLOAT16 related testcases if GPU doesn't support float16.
Expand Down
20 changes: 12 additions & 8 deletions onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul3D_S8S8) {
}

TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8U8) {
auto run_test = [](bool only_t1_not_initializer) {
OpTester test("QLinearMatMul", 10);
auto run_test = [](bool only_t1_not_initializer, int opset_version) {
OpTester test("QLinearMatMul", opset_version);
test.AddInput<uint8_t>("T1", {2, 4},
{208, 236, 0, 238,
3, 214, 255, 29});
Expand Down Expand Up @@ -155,10 +155,12 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8U8) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
};

run_test(false);
run_test(false, 10);
run_test(false, 21);

// NNAPI will require all inputs except T1 to be initializers
run_test(true);
run_test(true, 10);
run_test(true, 21);
}

TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8S8) {
Expand Down Expand Up @@ -197,8 +199,8 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8S8) {
}

TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_S8S8) {
auto run_test = [](bool only_t1_not_initializer) {
OpTester test("QLinearMatMul", 10);
auto run_test = [](bool only_t1_not_initializer, int opset_version) {
OpTester test("QLinearMatMul", opset_version);
test.AddInput<int8_t>("T1", {2, 4},
{80, -2, -128, 110,
-125, 86, 127, -99});
Expand All @@ -225,10 +227,12 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_S8S8) {
test.Run();
};

run_test(false);
run_test(false, 10);
run_test(false, 21);

// NNAPI will require all inputs except T1 to be initializers
run_test(true);
run_test(true, 10);
run_test(true, 21);
}

static void QLinearMatMul2DTest(bool only_t1_not_initializer) {
Expand Down

0 comments on commit fc2be09

Please sign in to comment.