diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 5bd1a89c0dea1..95dc8c3cde46c 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1351,8 +1351,8 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T1 : tensor(int8), tensor(uint8), tensor(int32)
-- Constrain 'x' and 'x_zero_point' to 8-bit integer tensors or 32-bit signed integer tensors.
+- T1 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32)
+- Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, 16-bit integer tensors, or 32-bit signed integer tensors.
- T2 : tensor(float16), tensor(float)
- Constrain 'y', 'x_scale' to float tensors.
@@ -4194,8 +4194,9 @@ This version of the operator has been available since version 1 of the 'com.micr
### **com.microsoft.QuantizeLinear**
The linear quantization operator. It consumes a full precision data, a scale, a zero point to compute the low precision / quantized tensor.
- The quantization formula is y = saturate ((x / y_scale) + y_zero_point).For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8.
- For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
+ The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, [-128, 127] if it's int8,
+ [0, 65,535] if it's uint16, and [-32,768, 32,767] if it's int16. For (x / y_scale), it's rounding to nearest ties to even.
+ Refer to https://en.wikipedia.org/wiki/Rounding for details.
Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').
#### Version
@@ -4232,8 +4233,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- T1 : tensor(float16), tensor(float)
- Constrain 'x', 'y_scale' to float tensors.
-- T2 : tensor(int8), tensor(uint8)
-- Constrain 'y_zero_point' and 'y' to 8-bit integer tensors.
+- T2 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16)
+- Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.
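For readers, a minimal sketch of the documented mapping, assuming the default round-to-nearest-even floating-point environment; `QuantizeLinearRef` is a hypothetical helper for illustration, not the ORT/MLAS code path:

```cpp
#include <algorithm>
#include <cfenv>
#include <cmath>
#include <cstdint>
#include <limits>

// y = saturate(round_half_to_even(x / y_scale) + y_zero_point)
template <typename QuantType>
QuantType QuantizeLinearRef(float x, float y_scale, QuantType y_zero_point) {
  std::fesetround(FE_TONEAREST);                    // rounding to nearest, ties to even
  const float rounded = std::nearbyintf(x / y_scale);
  const float shifted = rounded + static_cast<float>(y_zero_point);
  const float lo = static_cast<float>(std::numeric_limits<QuantType>::min());  // 0 / -128 / -32768
  const float hi = static_cast<float>(std::numeric_limits<QuantType>::max());  // 255 / 127 / 65535 / 32767
  return static_cast<QuantType>(std::clamp(shifted, lo, hi));                  // saturate
}

// Example: QuantizeLinearRef<uint16_t>(1.5f, 0.001f, uint16_t{0}) == 1500;
//          QuantizeLinearRef<int16_t>(-100.0f, 0.001f, int16_t{0}) saturates to -32768.
```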
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index d46f3ed9bd262..33c187a28b62e 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -439,7 +439,7 @@ Do not modify directly.*
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|CropAndResize|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *in* crop_size:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(int32)|
-|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int32), tensor(int8), tensor(uint8)<br> **T2** = tensor(float)|
+|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)<br> **T2** = tensor(float)|
|DynamicQuantizeLSTM|*in* X:**T**<br> *in* W:**T2**<br> *in* R:**T2**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* W_scale:**T**<br> *in* W_zero_point:**T2**<br> *in* R_scale:**T**<br> *in* R_zero_point:**T2**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(float)<br> **T1** = tensor(int32)<br> **T2** = tensor(int8), tensor(uint8)|
|DynamicQuantizeMatMul|*in* A:**T1**<br> *in* B:**T2**<br> *in* b_scale:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(int8), tensor(uint8)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float)|
@@ -472,7 +472,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearWhere|*in* condition:**B**<br> *in* X:**T**<br> *in* x_scale:**TF**<br> *in* x_zero_point:**T**<br> *in* Y:**T**<br> *in* y_scale:**TF**<br> *in* y_zero_point:**T**<br> *in* z_scale:**TF**<br> *in* z_zero_point:**T**<br> *out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
-|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(int8), tensor(uint8)|
+|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 660c8bd9e0624..0ec5088808656 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -56,9 +56,13 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLine
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid);
@@ -191,9 +195,13 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid)>,
diff --git a/onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc b/onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc
deleted file mode 100644
index 28a304bfc7f0e..0000000000000
--- a/onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc
+++ /dev/null
@@ -1,56 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "core/providers/cpu/quantization/quantize_linear.h"
-#include "core/providers/common.h"
-
-namespace onnxruntime {
-namespace contrib {
-
-ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
- DequantizeLinear,
- 1,
- uint8_t,
- KernelDefBuilder()
-        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
-        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
-    DequantizeLinear<uint8_t>);
-
-ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
- DequantizeLinear,
- 1,
- int8_t,
- KernelDefBuilder()
-        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>())
-        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
-    DequantizeLinear<int8_t>);
-
-ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
- DequantizeLinear,
- 1,
- int32_t,
- KernelDefBuilder()
-        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>())
-        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
-    DequantizeLinear<int32_t>);
-
-ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
- QuantizeLinear,
- 1,
- uint8_t,
- KernelDefBuilder()
-        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
-        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
-    QuantizeLinear<uint8_t>);
-
-ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
- QuantizeLinear,
- 1,
- int8_t,
- KernelDefBuilder()
-        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
-        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>()),
-    QuantizeLinear<int8_t>);
-
-} // namespace contrib
-} // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
index aa2ad9f1ff6b1..4313fae767fe5 100644
--- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -136,8 +136,9 @@ Performs element-wise binary {name} on 8 bit data types (with Numpy-style broadc
static const char* QuantizeLinear_ver1_doc = R"DOC(
The linear quantization operator. It consumes a full precision data, a scale, a zero point to compute the low precision / quantized tensor.
-The quantization formula is y = saturate ((x / y_scale) + y_zero_point).For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8.
-For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
+The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, [-128, 127] if it's int8,
+[0, 65,535] if it's uint16, and [-32,768, 32,767] if it's int16. For (x / y_scale), it's rounding to nearest ties to even.
+Refer to https://en.wikipedia.org/wiki/Rounding for details.
Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC";
ONNX_MS_OPERATOR_SET_SCHEMA(
@@ -161,8 +162,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"T2", OpSchema::Optional)
.Output(0, "y", "N-D quantized output tensor. It has same shape as input 'x'.", "T2")
.TypeConstraint("T1", {"tensor(float16)", "tensor(float)"}, "Constrain 'x', 'y_scale' to float tensors.")
- .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"},
- "Constrain 'y_zero_point' and 'y' to 8-bit integer tensors.")
+ .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", "tensor(uint16)"},
+ "Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.")
.SetDoc(QuantizeLinear_ver1_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
if (ctx.getNumInputs() == 3 && ctx.getInputType(2) != nullptr) {
@@ -202,9 +203,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(DequantizeLinear, 1,
"T1", OpSchema::Optional)
.Output(0, "y", "N-D full precision output tensor. It has same shape as input 'x'.",
"T2")
- .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)", "tensor(int32)"},
- "Constrain 'x' and 'x_zero_point' to 8-bit integer tensors or 32-bit "
- "signed integer tensors.")
+ .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)", "tensor(int16)",
+ "tensor(uint16)", "tensor(int32)"},
+ "Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, "
+ "16-bit integer tensors, or 32-bit signed integer tensors.")
.TypeConstraint("T2", {"tensor(float16)", "tensor(float)"},
"Constrain 'y', 'x_scale' to float tensors.")
.SetDoc(DequantizeLinear_ver1_doc)
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index f517be185b3fa..b6ac4a1ca1d6c 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -633,6 +633,24 @@ void
int8_t ZeroPoint
);
+typedef
+void
+(MLASCALL MLAS_QUANTIZE_LINEAR_U16_KERNEL)(
+ const float* Input,
+ uint16_t* Output,
+ size_t N,
+ float Scale,
+ uint16_t ZeroPoint);
+
+typedef
+void
+(MLASCALL MLAS_QUANTIZE_LINEAR_S16_KERNEL)(
+ const float* Input,
+ int16_t* Output,
+ size_t N,
+ float Scale,
+ int16_t ZeroPoint);
+
template<typename OutputType>
struct MLAS_QUANT_KERNEL
{
@@ -749,6 +767,8 @@ extern "C" {
MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8Kernel;
MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8Kernel;
MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8Kernel;
+ MLAS_QUANTIZE_LINEAR_S16_KERNEL MlasQuantizeLinearS16Kernel;
+ MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel;
#if defined(MLAS_TARGET_AMD64)
MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3;
MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3;
@@ -959,6 +979,8 @@ struct MLAS_PLATFORM {
const MLAS_GEMM_QUANT_DISPATCH* GemmU8X8Dispatch;
MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel;
MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel;
+ MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel;
+ MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel;
#endif
#if defined(MLAS_TARGET_AMD64)
MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine;
@@ -986,6 +1008,8 @@ struct MLAS_PLATFORM {
MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL* ReduceMinimumMaximumF32Kernel;
MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel;
MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel;
+ MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel;
+ MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel;
uint32_t NchwcBlockSize;
uint32_t PreferredBufferAlignment;
int32_t MaximumThreadCount;
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index 86b7450a7c4e5..7e2b117d6f249 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -230,6 +230,8 @@ Return Value:
this->QLinearAddU8Kernel = MlasQLinearAddU8Kernel;
this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel;
this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel;
+        this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel;
+        this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel;
this->NchwcBlockSize = 8;
this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT;
@@ -475,6 +477,8 @@ Return Value:
this->GemmDoubleKernel = MlasDgemmKernel;
this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel;
this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel;
+ this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel;
+ this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel;
#if defined(__linux__)
unsigned long hwcap2 = getauxval(AT_HWCAP2);
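The typedefs in mlasi.h and the assignments above follow MLAS's usual dispatch pattern: the platform struct stores one function pointer per kernel, filled at startup, and call sites go through the pointer. A simplified sketch of that pattern with hypothetical names; the truncating reference kernel stands in for the real rounding/saturating implementation:

```cpp
#include <cstddef>
#include <cstdint>

// Function type mirroring the new MLAS_QUANTIZE_LINEAR_U16_KERNEL typedef.
typedef void(QuantizeLinearU16KernelFn)(const float* Input, uint16_t* Output,
                                        size_t N, float Scale, uint16_t ZeroPoint);

// Portable stand-in kernel (truncating; the real kernels round to nearest-even).
static void QuantizeLinearU16Reference(const float* Input, uint16_t* Output,
                                       size_t N, float Scale, uint16_t ZeroPoint) {
  for (size_t i = 0; i < N; ++i) {
    float v = Input[i] / Scale + static_cast<float>(ZeroPoint);
    if (v < 0.0f) v = 0.0f;          // saturate to the uint16 range
    if (v > 65535.0f) v = 65535.0f;
    Output[i] = static_cast<uint16_t>(v);
  }
}

// Per-CPU dispatch table: filled once at startup with the best kernel, then invoked
// through the pointer, analogous to GetMlasPlatform().QuantizeLinearU16Kernel(...).
struct PlatformTable {
  QuantizeLinearU16KernelFn* QuantizeLinearU16Kernel = &QuantizeLinearU16Reference;
};
```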
diff --git a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp
index 0d38288c6d42c..830a3a6a492db 100644
--- a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp
+++ b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp
@@ -1,3 +1,4 @@
+#include <type_traits>
#include "mlasi.h"
 #include <altivec.h>
@@ -82,8 +83,15 @@ Return Value:
auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1);
auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3);
- auto CharVector = vec_pack(ShortVector0, ShortVector1);
- vec_xst(CharVector, 0, (int8_t *) Output);
+
+    if constexpr (std::is_same_v<OutputType, uint8_t> || std::is_same_v<OutputType, int8_t>) {
+ auto CharVector = vec_pack(ShortVector0, ShortVector1);
+ vec_xst(CharVector, 0, Output);
+ } else {
+        static_assert(std::is_same_v<OutputType, uint16_t> || std::is_same_v<OutputType, int16_t>);
+ vec_xst(ShortVector0, 0, Output);
+ vec_xst(ShortVector1, 0, &Output[8]);
+ }
Output += 16;
Input += 16;
@@ -124,3 +132,30 @@ MlasQuantizeLinearS8Kernel(
{
    MlasQuantizeLinearKernel<int8_t>(Input, Output, N, Scale, ZeroPoint);
}
+
+void
+MLASCALL
+MlasQuantizeLinearU16Kernel(
+ const float* Input,
+ uint16_t* Output,
+ size_t N,
+ float Scale,
+ uint16_t ZeroPoint
+ )
+{
+    MlasQuantizeLinearKernel<uint16_t>(Input, Output, N, Scale, ZeroPoint);
+}
+
+void
+MLASCALL
+MlasQuantizeLinearS16Kernel(
+ const float* Input,
+ int16_t* Output,
+ size_t N,
+ float Scale,
+ int16_t ZeroPoint
+ )
+{
+    MlasQuantizeLinearKernel<int16_t>(Input, Output, N, Scale, ZeroPoint);
+}
+
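For 16-bit outputs the POWER path stops after the first `vec_pack` stage and stores both eight-element halves. A scalar sketch of the intended effect, assuming the values were already clamped by the surrounding kernel (hypothetical helper, not the VSX code):

```cpp
#include <cstddef>
#include <cstdint>

// Scalar model of one 16-element iteration: vec_pack keeps the low 16 bits of each
// int32 lane, and for 16-bit outputs the two 8-element halves are stored back to back.
static void StorePacked16(const int32_t Clamped[16], uint16_t* Output) {
  for (size_t i = 0; i < 16; ++i) {
    Output[i] = static_cast<uint16_t>(Clamped[i] & 0xFFFF);  // modulo narrowing, safe after clamping
  }
}
```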
diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp
index c6e8af38c0020..133ad79594c55 100644
--- a/onnxruntime/core/mlas/lib/quantize.cpp
+++ b/onnxruntime/core/mlas/lib/quantize.cpp
@@ -21,6 +21,7 @@ Module Name:
#include "mlasi.h"
#if defined(MLAS_NEON64_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS)
+#include <type_traits>
//
// QuantizeLinear implementation using NEON or SSE2 intrinsics.
@@ -79,6 +80,20 @@ MlasQuantizeLinearPackBytes(
MLAS_INT32X4 IntegerVector
);
+template<typename OutputType>
+void
+MlasQuantizeLinearStore4PackedValues(
+ MLAS_INT32X4 IntegerVector,
+ OutputType* Output
+ );
+
+template<typename OutputType>
+void
+MlasQuantizeLinearStoreSingleValue(
+ MLAS_INT32X4 IntegerVector,
+ OutputType* Output
+ );
+
#if defined(MLAS_NEON64_INTRINSICS)
template<>
@@ -100,6 +115,104 @@ MlasQuantizeLinearPackBytes(
return vreinterpretq_s32_u8(ByteVector);
}
+template<>
+MLAS_INT32X4
+MlasQuantizeLinearPackBytes<uint16_t>(
+ MLAS_INT32X4 IntegerVector
+ )
+{
+ //
+ // Swizzle the least significant u16 from each int32_t element to the
+ // bottom eight bytes of the vector register.
+ //
+
+ uint16x8_t WordVector = vreinterpretq_u16_s32(IntegerVector);
+ WordVector = vuzp1q_u16(WordVector, WordVector);
+ return vreinterpretq_s32_u16(WordVector);
+}
+
+template<>
+MLAS_INT32X4
+MlasQuantizeLinearPackBytes<int16_t>(
+ MLAS_INT32X4 IntegerVector
+ )
+{
+ //
+ // Swizzle the least significant u16 from each int32_t element to the
+ // bottom eight bytes of the vector register.
+ //
+
+ int16x8_t WordVector = vreinterpretq_s16_s32(IntegerVector);
+ WordVector = vuzp1q_s16(WordVector, WordVector);
+ return vreinterpretq_s32_s16(WordVector);
+}
+
+template<typename OutputType>
+MLAS_FORCEINLINE
+void
+MlasQuantizeLinearStore4PackedValues(
+ MLAS_INT32X4 IntegerVector,
+ OutputType* Output
+ )
+{
+ // Copies the lower 4 packed elements of the vector into memory (Output).
+
+    if constexpr (std::is_same_v<OutputType, uint8_t> || std::is_same_v<OutputType, int8_t>) {
+        vst1q_lane_s32(reinterpret_cast<int32_t*>(Output), IntegerVector, 0);
+    } else {
+        static_assert(std::is_same_v<OutputType, uint16_t> || std::is_same_v<OutputType, int16_t>);
+        vst1q_lane_s64(reinterpret_cast<int64_t*>(Output), vreinterpretq_s64_s32(IntegerVector), 0);
+ }
+}
+
+template <>
+MLAS_FORCEINLINE
+void
+MlasQuantizeLinearStoreSingleValue(
+ MLAS_INT32X4 IntegerVector,
+ uint8_t* Output
+ )
+{
+ // Copies the lower 8-bit element of the vector into memory (Output).
+ vst1q_lane_u8(Output, vreinterpretq_u8_s32(IntegerVector), 0);
+}
+
+template <>
+MLAS_FORCEINLINE
+void
+MlasQuantizeLinearStoreSingleValue(
+ MLAS_INT32X4 IntegerVector,
+ int8_t* Output
+ )
+{
+ // Copies the lower 8-bit element of the vector into memory (Output).
+ vst1q_lane_s8(Output, vreinterpretq_s8_s32(IntegerVector), 0);
+}
+
+template <>
+MLAS_FORCEINLINE
+void
+MlasQuantizeLinearStoreSingleValue(
+ MLAS_INT32X4 IntegerVector,
+ uint16_t* Output
+ )
+{
+ // Copies the lower 16-bit element of the vector into memory (Output).
+ vst1q_lane_u16(Output, vreinterpretq_u16_s32(IntegerVector), 0);
+}
+
+template <>
+MLAS_FORCEINLINE
+void
+MlasQuantizeLinearStoreSingleValue(
+ MLAS_INT32X4 IntegerVector,
+ int16_t* Output
+ )
+{
+ // Copies the lower 16-bit element of the vector into memory (Output).
+ vst1q_lane_s16(Output, vreinterpretq_s16_s32(IntegerVector), 0);
+}
+
#else
template<>
@@ -128,6 +241,86 @@ MlasQuantizeLinearPackBytes(
return IntegerVector;
}
+template<>
+MLAS_FORCEINLINE
+MLAS_INT32X4
+MlasQuantizeLinearPackBytes<uint16_t>(
+ MLAS_INT32X4 IntegerVector
+ )
+{
+#if defined(MLAS_SSE41_INTRINSICS)
+ IntegerVector = _mm_packus_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes.
+#else
+ // Cannot use _mm_packus_epi32 because that was not available until SSE4.1.
+ // Instead, emulate by sign-extending the first 16-bits of each packed 32-bit element.
+ // Afterwards, can use _mm_packs_epi32, which is available on SSE2.
+ // See: https://stackoverflow.com/a/11028244
+
+ IntegerVector = _mm_slli_epi32(IntegerVector, 16);
+ IntegerVector = _mm_srai_epi32(IntegerVector, 16); // Sign-extend: undo left shift with right arithmetic shift
+ IntegerVector = _mm_packs_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes.
+#endif // defined(MLAS_SSE41_INTRINSICS)
+
+ return IntegerVector;
+}
+
+template<>
+MLAS_FORCEINLINE
+MLAS_INT32X4
+MlasQuantizeLinearPackBytes<int16_t>(
+ MLAS_INT32X4 IntegerVector
+ )
+{
+ IntegerVector = _mm_packs_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes.
+
+ return IntegerVector;
+}
+
+template<typename OutputType>
+MLAS_FORCEINLINE
+void
+MlasQuantizeLinearStore4PackedValues(
+ MLAS_INT32X4 IntegerVector,
+ OutputType* Output
+ )
+{
+ // Copies the lower 4 packed elements of the vector into memory (Output).
+
+    if constexpr (std::is_same_v<OutputType, uint8_t> || std::is_same_v<OutputType, int8_t>) {
+        *(reinterpret_cast<int32_t*>(Output)) = _mm_cvtsi128_si32(IntegerVector);
+    } else {
+        static_assert(std::is_same_v<OutputType, uint16_t> || std::is_same_v<OutputType, int16_t>);
+
+#if defined(MLAS_TARGET_IX86)
+        // x86 does not support _mm_cvtsi128_si64, so use _mm_maskmoveu_si128 instead.
+        constexpr uint32_t bytes_high_bit = 0x80808080;
+        const __m128i first_8_bytes_mask = _mm_set_epi32(0, 0, bytes_high_bit, bytes_high_bit);
+        _mm_maskmoveu_si128(IntegerVector, first_8_bytes_mask, reinterpret_cast<char*>(Output));
+#else
+        *(reinterpret_cast<int64_t*>(Output)) = _mm_cvtsi128_si64(IntegerVector);
+#endif // defined(MLAS_TARGET_IX86)
+ }
+}
+
+template<typename OutputType>
+MLAS_FORCEINLINE
+void
+MlasQuantizeLinearStoreSingleValue(
+ MLAS_INT32X4 IntegerVector,
+ OutputType* Output
+ )
+{
+    static_assert(std::is_same_v<OutputType, uint8_t> ||
+                  std::is_same_v<OutputType, int8_t> ||
+                  std::is_same_v<OutputType, uint16_t> ||
+                  std::is_same_v<OutputType, int16_t>);
+
+ // Copies the lower element of the vector into memory (Output).
+ // Expects that the 32-bit element in lane 0 is already within the valid numerical
+ // range of the OutputType.
+    *Output = static_cast<OutputType>(_mm_cvtsi128_si32(IntegerVector));
+}
+
#endif
template<typename OutputType>
@@ -180,12 +373,7 @@ Return Value:
MinimumValueVector, MaximumValueVector, ZeroPointVector);
IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector);
-
-#if defined(MLAS_NEON64_INTRINSICS)
- vst1q_lane_s32((int32_t*)Output, IntegerVector, 0);
-#else
- *((int32_t*)Output) = _mm_cvtsi128_si32(IntegerVector);
-#endif
+ MlasQuantizeLinearStore4PackedValues(IntegerVector, Output);
Input += 4;
Output += 4;
@@ -202,11 +390,7 @@ Return Value:
auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector,
MinimumValueVector, MaximumValueVector, ZeroPointVector);
-#if defined(MLAS_NEON64_INTRINSICS)
- vst1q_lane_u8((uint8_t*)Output + n, vreinterpretq_u8_s32(IntegerVector), 0);
-#else
- *((uint8_t*)Output + n) = (uint8_t)_mm_cvtsi128_si32(IntegerVector);
-#endif
+ MlasQuantizeLinearStoreSingleValue(IntegerVector, &Output[n]);
}
}
@@ -236,6 +420,32 @@ MlasQuantizeLinearU8Kernel(
MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint);
}
+void
+MLASCALL
+MlasQuantizeLinearU16Kernel(
+ const float* Input,
+ uint16_t* Output,
+ size_t N,
+ float Scale,
+ uint16_t ZeroPoint
+)
+{
+    MlasQuantizeLinearKernel<uint16_t>(Input, Output, N, Scale, ZeroPoint);
+}
+
+void
+MLASCALL
+MlasQuantizeLinearS16Kernel(
+ const float* Input,
+ int16_t* Output,
+ size_t N,
+ float Scale,
+ int16_t ZeroPoint
+)
+{
+    MlasQuantizeLinearKernel<int16_t>(Input, Output, N, Scale, ZeroPoint);
+}
+
template<>
void
MLASCALL
@@ -274,6 +484,44 @@ MlasQuantizeLinear(
Input, Output, N, Scale, ZeroPoint);
}
+template<>
+void
+MLASCALL
+MlasQuantizeLinear<uint16_t>(
+ const float* Input,
+ uint16_t* Output,
+ size_t N,
+ float Scale,
+ uint16_t ZeroPoint
+ )
+{
+#if defined(MLAS_TARGET_AMD64)
+ GetMlasPlatform().QuantizeLinearU16Kernel(
+#else
+ MlasQuantizeLinearU16Kernel(
+#endif
+ Input, Output, N, Scale, ZeroPoint);
+}
+
+template<>
+void
+MLASCALL
+MlasQuantizeLinear<int16_t>(
+ const float* Input,
+ int16_t* Output,
+ size_t N,
+ float Scale,
+ int16_t ZeroPoint
+ )
+{
+#if defined(MLAS_TARGET_AMD64)
+ GetMlasPlatform().QuantizeLinearS16Kernel(
+#else
+ MlasQuantizeLinearS16Kernel(
+#endif
+ Input, Output, N, Scale, ZeroPoint);
+}
+
#else
#if defined(MLAS_TARGET_POWER)
@@ -306,6 +554,34 @@ MlasQuantizeLinear(
GetMlasPlatform().QuantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint);
}
+template<>
+void
+MLASCALL
+MlasQuantizeLinear<int16_t>(
+ const float* Input,
+ int16_t* Output,
+ size_t N,
+ float Scale,
+ int16_t ZeroPoint
+ )
+{
+ GetMlasPlatform().QuantizeLinearS16Kernel(Input, Output, N, Scale, ZeroPoint);
+}
+
+template<>
+void
+MLASCALL
+MlasQuantizeLinear<uint16_t>(
+ const float* Input,
+ uint16_t* Output,
+ size_t N,
+ float Scale,
+ uint16_t ZeroPoint
+ )
+{
+ GetMlasPlatform().QuantizeLinearU16Kernel(Input, Output, N, Scale, ZeroPoint);
+}
+
#endif
//
@@ -381,6 +657,29 @@ MlasQuantizeLinear(
float Scale,
uint8_t ZeroPoint
);
+
+template
+void
+MLASCALL
+MlasQuantizeLinear<int16_t>(
+ const float* Input,
+ int16_t* Output,
+ size_t N,
+ float Scale,
+ int16_t ZeroPoint
+ );
+
+template
+void
+MLASCALL
+MlasQuantizeLinear<uint16_t>(
+ const float* Input,
+ uint16_t* Output,
+ size_t N,
+ float Scale,
+ uint16_t ZeroPoint
+ );
+
#endif
#endif
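The SSE2 fallback in the new `uint16_t` pack specialization leans on a non-obvious identity: for values already clamped to [0, 65535], sign-extending the low 16 bits and then using the signed pack stores the same bit pattern that `_mm_packus_epi32` would have produced. A scalar sketch of that reasoning (hypothetical helper, per-lane equivalent only):

```cpp
#include <cstdint>

static uint16_t PackU16ViaSignedPath(int32_t clamped /* already clamped to [0, 65535] */) {
  // The vector code shifts each lane left by 16 and arithmetic-shifts it back, i.e. it
  // sign-extends the low 16 bits: 40000 becomes -25536 (bit pattern 0xFFFF9C40).
  const int16_t sign_extended = static_cast<int16_t>(clamped & 0xFFFF);
  // _mm_packs_epi32 saturates to [-32768, 32767]; a value that already fits in 16 bits is
  // never clipped, so the 16 bits written out equal the original unsigned value (0x9C40 == 40000).
  return static_cast<uint16_t>(sign_extended);
}
```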
diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
index b67f6d6ec0794..624679e7b1b4b 100644
--- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
+++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
@@ -1,131 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/double_qdq_pairs_remover.h"
+#include <cassert>
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
+#include "core/optimizer/qdq_transformer/qdq_util.h"
namespace onnxruntime {
-Status DoubleQDQPairsRemover::ApplyImpl(
- Graph& graph,
- bool& modified,
- int /*graph_level*/,
- const logging::Logger& /*logger*/) const {
- const GraphViewer graph_viewer(graph);
- const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
-
- for (const auto& self_index : node_topology_list) {
- NodeIndex parent_index = 0;
- NodeIndex child_index = 0;
- NodeIndex grandchild_index = 0;
- if (IsNodeRemovable(graph, self_index, parent_index, child_index, grandchild_index)) {
- graph.RemoveEdge(parent_index, self_index, 0, 0);
- graph.RemoveEdge(self_index, child_index, 0, 0);
- graph.RemoveEdge(child_index, grandchild_index, 0, 0);
- graph_utils::ReplaceNodeInput(*graph.GetNode(grandchild_index), 0, *graph.GetNode(self_index)->MutableInputDefs()[0]);
- graph.AddEdge(parent_index, grandchild_index, 0, 0);
- graph.RemoveNode(child_index);
- graph.RemoveNode(self_index);
- modified = true;
- }
- }
- return Status::OK();
-}
-
-bool DoubleQDQPairsRemover::IsNodeRemovable(
- Graph& graph,
- const NodeIndex& self_index,
- NodeIndex& parent_index,
- NodeIndex& child_index,
- NodeIndex& grandchild_index) {
- // Check if the self is a DQ, and have one parent and one child, and cannot be a graph output
- Node* self = graph.GetNode(self_index);
- if (self == nullptr ||
- self->OpType() != "DequantizeLinear" ||
- self->GetInputEdgesCount() != 1 ||
- self->GetOutputEdgesCount() != 1 ||
- self->InputDefs().size() != InputIndex::TOTAL_COUNT ||
- graph.NodeProducesGraphOutput(*self)) {
- return false;
- }
-
- // Type is either "tensor(uint8)" or "tensor(int8)"
- const auto& self_zp_type = *self->InputDefs()[InputIndex::ZERO_POINT_ID]->Type();
- // child should be a Q, and have only one child, have the same type as self, and cannot be a graph output
- child_index = self->OutputEdgesBegin()->GetNode().Index();
- const Node* child = graph.GetNode(child_index);
- if (child == nullptr ||
- child->OpType() != "QuantizeLinear" ||
- child->GetOutputEdgesCount() != 1 ||
- child->InputDefs().size() != InputIndex::TOTAL_COUNT ||
- *child->InputDefs()[InputIndex::ZERO_POINT_ID]->Type() != self_zp_type ||
- graph.NodeProducesGraphOutput(*child)) {
- return false;
- }
-
- // parent should be a Q, and have only one output, and cannot be a graph output
- parent_index = self->InputEdgesBegin()->GetNode().Index();
- Node* parent = graph.GetNode(parent_index);
- if (parent == nullptr ||
- parent->GetOutputEdgesCount() != 1 ||
- parent->OpType() != "QuantizeLinear" ||
- graph.NodeProducesGraphOutput(*parent)) {
- return false;
- }
-
- // grandchild should be a DQ
- grandchild_index = child->OutputEdgesBegin()->GetNode().Index();
- Node* grandchild = graph.GetNode(grandchild_index);
- if (grandchild == nullptr ||
- grandchild->OpType() != "DequantizeLinear") {
- return false;
- }
- const auto get_constant_initializer = [&graph](const std::string& initializer_name) {
- return graph.GetConstantInitializer(initializer_name, true);
- };
- if (!QDQ::IsQDQPairSupported(*parent, *self, get_constant_initializer, graph.ModelPath()) ||
- !QDQ::IsQDQPairSupported(*child, *grandchild, get_constant_initializer, graph.ModelPath())) {
- return false;
- }
- bool skip_reset = false;
- float new_scale = 0.0f;
- if (self_zp_type == "tensor(uint8)") {
- uint8_t new_zero_point = 0;
- if (!FindNewZeroPointAndScale(graph, *self, *child, new_scale, new_zero_point, skip_reset)) {
- return false;
- }
- if (skip_reset) {
- return true;
- }
- ApplyNewInputValue(graph, *grandchild, InputIndex::SCALE_ID, new_scale);
- ApplyNewInputValue(graph, *parent, InputIndex::SCALE_ID, new_scale);
- ApplyNewInputValue(graph, *grandchild, InputIndex::ZERO_POINT_ID, new_zero_point);
- ApplyNewInputValue(graph, *parent, InputIndex::ZERO_POINT_ID, new_zero_point);
- } else {
- int8_t new_zero_point = 0;
- if (!FindNewZeroPointAndScale(graph, *self, *child, new_scale, new_zero_point, skip_reset)) {
- return false;
- }
- if (skip_reset) {
- return true;
- }
- ApplyNewInputValue(graph, *grandchild, InputIndex::SCALE_ID, new_scale);
- ApplyNewInputValue(graph, *parent, InputIndex::SCALE_ID, new_scale);
- ApplyNewInputValue(graph, *grandchild, InputIndex::ZERO_POINT_ID, new_zero_point);
- ApplyNewInputValue(graph, *parent, InputIndex::ZERO_POINT_ID, new_zero_point);
- }
- return true;
+// Applies a new zero point or scale as the input for a Q/DQ node.
+template <typename T>
+static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index, T value) {
+ const auto* input_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[index]->Name());
+ Initializer input_init{*input_tensor, graph.ModelPath()};
+ ONNX_NAMESPACE::TensorProto new_input_tensor(*input_tensor);
+  input_init.data<T>()[0] = value;
+ input_init.ToProto(new_input_tensor);
+ auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name());
+ new_input_tensor.set_name(new_name);
+ NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor);
+ graph_utils::ReplaceNodeInput(node, index, new_input);
}
+// Returns a new zero point and scale value for the given Q/DQ nodes.
template <typename T>
-bool DoubleQDQPairsRemover::FindNewZeroPointAndScale(const Graph& graph, const Node& node1, const Node& node2,
- float& new_scale, T& new_zero_point, bool& skip_reset) {
+static bool FindNewZeroPointAndScale(const Graph& graph, const Node& node1, const Node& node2,
+ float& new_scale, T& new_zero_point, bool& skip_reset) {
// scale & zero point share same initializer, no need to reset the value
- const std::string& node1_scale_name = node1.InputDefs()[InputIndex::SCALE_ID]->Name();
- const std::string& node2_scale_name = node2.InputDefs()[InputIndex::SCALE_ID]->Name();
- const std::string& node1_zp_name = node1.InputDefs()[InputIndex::ZERO_POINT_ID]->Name();
- const std::string& node2_zp_name = node2.InputDefs()[InputIndex::ZERO_POINT_ID]->Name();
+ const std::string& node1_scale_name = node1.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
+ const std::string& node2_scale_name = node2.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
+ const std::string& node1_zp_name = node1.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
+ const std::string& node2_zp_name = node2.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
skip_reset = false;
if (node1_scale_name == node2_scale_name && node1_zp_name == node2_zp_name) {
skip_reset = true;
@@ -175,16 +81,141 @@ bool DoubleQDQPairsRemover::FindNewZeroPointAndScale(const Graph& graph, const N
return true;
}
-template <typename T>
-void DoubleQDQPairsRemover::ApplyNewInputValue(Graph& graph, Node& node, const InputIndex& index, T value) {
- const auto* input_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[index]->Name());
- Initializer input_init{*input_tensor, graph.ModelPath()};
- TensorProto new_input_tensor(*input_tensor);
-  input_init.data<T>()[0] = value;
- input_init.ToProto(new_input_tensor);
- auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name());
- new_input_tensor.set_name(new_name);
- NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor);
- graph_utils::ReplaceNodeInput(node, index, new_input);
+// Recomputes the zero point and scale of the outer Q/DQ nodes (i.e., Q1 and DQ2). This is necessary because
+// the original two QDQ pairs may have different zero-points and scales. Ex: Q1 -> DQ1 -> Q2 -> DQ2, where
+// the first pair has (zp1, scale1) and the second pair has (zp2, scale2).
+// After removing the middle two nodes, the zero point and scale of the final (outer) ops must be recomputed
+// for correctness.
+template <typename ZeroPointType>
+static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Node& dq1, const Node& q2, Node& dq2) {
+ bool skip_reset = false;
+ float new_scale = 0.0f;
+ ZeroPointType new_zero_point = 0;
+ if (!FindNewZeroPointAndScale(graph, dq1, q2, new_scale, new_zero_point, skip_reset)) {
+ return false;
+ }
+ if (skip_reset) {
+ return true;
+ }
+ ApplyNewInputValue(graph, dq2, QDQ::InputIndex::SCALE_ID, new_scale);
+ ApplyNewInputValue(graph, q1, QDQ::InputIndex::SCALE_ID, new_scale);
+ ApplyNewInputValue(graph, dq2, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point);
+ ApplyNewInputValue(graph, q1, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point);
+
+ return true;
+}
+
+// Checks if the provided node index (dq1_index) is a part of a valid double QDQ pair sequence
+// (i.e., Q1 -> DQ1 -> Q2 -> DQ2) that can be reduced to the outer Q/DQ nodes (i.e., Q1 -> DQ2).
+// If so, the zero point and scale of the outer Q/DQ nodes are recomputed and the node indices of the other nodes
+// in the sequence (i.e., Q1, Q2, and DQ2) are returned via output parameters.
+static bool IsReducibleDoubleQDQSequence(Graph& graph, NodeIndex& q1_index, NodeIndex dq1_index,
+ NodeIndex& q2_index, NodeIndex& dq2_index) {
+ // Ensure that dq1 is a DQ operator, has one parent and one child, and is not a graph output
+ Node* dq1 = graph.GetNode(dq1_index);
+ if (dq1 == nullptr ||
+ dq1->OpType() != "DequantizeLinear" ||
+ dq1->GetInputEdgesCount() != 1 ||
+ dq1->GetOutputEdgesCount() != 1 ||
+ graph.NodeProducesGraphOutput(*dq1)) {
+ return false;
+ }
+
+ // Ensure that q2 is a Q operator, has only one child, and is not a graph output
+ q2_index = dq1->OutputEdgesBegin()->GetNode().Index();
+ const Node* q2 = graph.GetNode(q2_index);
+ if (q2 == nullptr ||
+ q2->OpType() != "QuantizeLinear" ||
+ q2->GetOutputEdgesCount() != 1 ||
+ graph.NodeProducesGraphOutput(*q2)) {
+ return false;
+ }
+
+ // Ensure that q1 is a Q operator, has only one output, and is not a graph output
+ q1_index = dq1->InputEdgesBegin()->GetNode().Index();
+ Node* q1 = graph.GetNode(q1_index);
+ if (q1 == nullptr ||
+ q1->GetOutputEdgesCount() != 1 ||
+ q1->OpType() != "QuantizeLinear" ||
+ graph.NodeProducesGraphOutput(*q1)) {
+ return false;
+ }
+
+ // Ensure the dq2 is a DQ operator.
+ dq2_index = q2->OutputEdgesBegin()->GetNode().Index();
+ Node* dq2 = graph.GetNode(dq2_index);
+ if (dq2 == nullptr ||
+ dq2->OpType() != "DequantizeLinear") {
+ return false;
+ }
+
+ const auto get_constant_initializer = [&graph](const std::string& initializer_name) {
+ return graph.GetConstantInitializer(initializer_name, true);
+ };
+
+ // Each QDQ pair (i.e., q1 -> dq1, q2 -> dq2) has to meet the following additional requirements:
+ // - Scalar/constant zero-point and scale.
+ // - The DQ and Q ops within a pair must have the same scale and zero-point.
+ // However, each pair is allowed to have different scales and zero-points.
+ //
+ // TODO: IsQDQPairSupported() requires an explicit zero-point input, but technically a default
+ // value of 0 could be fine.
+ if (!QDQ::IsQDQPairSupported(*q1, *dq1, get_constant_initializer, graph.ModelPath()) ||
+ !QDQ::IsQDQPairSupported(*q2, *dq2, get_constant_initializer, graph.ModelPath())) {
+ return false;
+ }
+
+ const auto& dq1_input_defs = dq1->InputDefs();
+ const ONNX_NAMESPACE::TensorProto* dq1_zp_tensor_proto = graph.GetConstantInitializer(
+ dq1_input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Name(), true);
+
+ assert(dq1_zp_tensor_proto != nullptr); // IsQDQPairSupported should have checked that this exists.
+
+ auto dq1_zp_type = dq1_zp_tensor_proto->data_type();
+
+ if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
+    return RecomputeOuterQDQZeroPointAndScale<uint8_t>(graph, *q1, *dq1, *q2, *dq2);
+ }
+
+ if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
+    return RecomputeOuterQDQZeroPointAndScale<int8_t>(graph, *q1, *dq1, *q2, *dq2);
+ }
+
+ if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
+    return RecomputeOuterQDQZeroPointAndScale<uint16_t>(graph, *q1, *dq1, *q2, *dq2);
+ }
+
+ if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) {
+    return RecomputeOuterQDQZeroPointAndScale<int16_t>(graph, *q1, *dq1, *q2, *dq2);
+ }
+
+ return false; // Unsupported zero-point type
+}
+
+Status DoubleQDQPairsRemover::ApplyImpl(
+ Graph& graph,
+ bool& modified,
+ int /*graph_level*/,
+ const logging::Logger& /*logger*/) const {
+ const GraphViewer graph_viewer(graph);
+ const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+
+ for (const auto& dq1_index : node_topology_list) {
+ NodeIndex q1_index = 0;
+ NodeIndex q2_index = 0;
+ NodeIndex dq2_index = 0;
+ if (IsReducibleDoubleQDQSequence(graph, q1_index, dq1_index, q2_index, dq2_index)) {
+ graph.RemoveEdge(q1_index, dq1_index, 0, 0);
+ graph.RemoveEdge(dq1_index, q2_index, 0, 0);
+ graph.RemoveEdge(q2_index, dq2_index, 0, 0);
+ graph_utils::ReplaceNodeInput(*graph.GetNode(dq2_index), 0, *graph.GetNode(dq1_index)->MutableInputDefs()[0]);
+ graph.AddEdge(q1_index, dq2_index, 0, 0);
+ graph.RemoveNode(q2_index);
+ graph.RemoveNode(dq1_index);
+ modified = true;
+ }
+ }
+ return Status::OK();
}
+
} // namespace onnxruntime
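A toy numeric example (hypothetical uint8 scales and zero points, not the transformer's recomputation logic) of why the surviving Q1/DQ2 pair cannot simply keep its old parameters once the middle DQ1/Q2 nodes are removed:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

static uint8_t Quantize(float x, float scale, uint8_t zp) {
  return static_cast<uint8_t>(std::clamp(std::lround(x / scale) + zp, 0L, 255L));
}
static float Dequantize(uint8_t q, float scale, uint8_t zp) {
  return (static_cast<int>(q) - zp) * scale;
}

int main() {
  const float x = 1.3f;
  const float s1 = 0.02f, s2 = 0.01f;   // pair 1: (s1, zp1), pair 2: (s2, zp2)
  const uint8_t zp1 = 0, zp2 = 10;

  // Original sequence Q1 -> DQ1 -> Q2 -> DQ2.
  const float full = Dequantize(Quantize(Dequantize(Quantize(x, s1, zp1), s1, zp1), s2, zp2), s2, zp2);
  // Keeping Q1 with (s1, zp1) but feeding DQ2 with (s2, zp2) does NOT reproduce it,
  // which is why the outer pair's scale/zero-point must be recomputed.
  const float naive = Dequantize(Quantize(x, s1, zp1), s2, zp2);
  std::printf("full=%f naive=%f\n", full, naive);  // 1.30 vs 0.55 for these values
  return 0;
}
```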
diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.h b/onnxruntime/core/optimizer/double_qdq_pairs_remover.h
index c016f7181b7fe..1833b007674fd 100644
--- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.h
+++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.h
@@ -3,19 +3,16 @@
#pragma once
-#include "core/common/common.h"
#include "core/optimizer/graph_transformer.h"
-#include "core/optimizer/qdq_transformer/qdq_util.h"
namespace onnxruntime {
-using ONNX_NAMESPACE::TensorProto;
-using ONNX_NAMESPACE::TensorProto_DataType;
-using QDQ::InputIndex;
-
/**
* @Class DoubleQDQPairsRemover
* @brief Remove one pair of Q-DQ from Double Q-DQ pairs.
+ * Specifically, this transformer converts the sequence Q1 -> DQ1 -> Q2 -> DQ2, where the first pair has (zp1, scale1)
+ * and the second pair has (zp2, scale2), into the sequence Q1 -> DQ2 by removing the middle two nodes. The zero-point
+ * and scale of the final QDQ pair is recomputed to preserve equality to the original sequence.
*/
class DoubleQDQPairsRemover : public GraphTransformer {
public:
@@ -27,28 +24,5 @@ class DoubleQDQPairsRemover : public GraphTransformer {
bool& modified,
int graph_level,
const logging::Logger& logger) const override;
-
- static bool IsNodeRemovable(
- Graph& graph,
- const NodeIndex& self_index,
- NodeIndex& parent_index,
- NodeIndex& child_index,
- NodeIndex& grandchild_index);
-
-  template <typename T>
- static bool FindNewZeroPointAndScale(
- const Graph& graph,
- const Node& node1,
- const Node& node2,
- float& new_scale,
- T& new_zero_point,
- bool& skip_reset);
-
-  template <typename T>
- static void ApplyNewInputValue(
- Graph& graph,
- Node& node,
- const InputIndex& index,
- T value);
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
index d7039cb4b7cfc..0e383c3031ca6 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
+#include <vector>
#include "core/mlas/inc/mlas.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h"
@@ -32,7 +33,8 @@ void SplitQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
// create rules for ops that don't change the data
void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
// 3 nodes. DQ, target, Q. Merge into target and remove DQ and Q.
- const std::string action_name{"drop"};
+ const std::string drop_action_name{"drop"};
+ const std::string drop_action_no_int16_name{"drop_no_int16_support"};
NTO::NodeLocation dq{NTO::NodeType::kInput, 0};
NTO::NodeLocation q{NTO::NodeType::kOutput, 0};
@@ -42,22 +44,33 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
MoveToSlot(dq, ArgType::kInput, 0, ArgType::kInput, 0),
MoveToSlot(q, ArgType::kOutput, 0, ArgType::kOutput, 0)};
-  std::unique_ptr<Action> action = std::make_unique<MergeIntoTarget>(std::move(moves));
+  std::unique_ptr<Action> drop_action_no_int16 = std::make_unique<MergeIntoTarget>(
+      std::vector<NodeAndMoveInfo>(moves));  // Copy before std::move(moves)
+  std::unique_ptr<Action> drop_action = std::make_unique<MergeIntoTarget>(std::move(moves));
#if !defined(ORT_MINIMAL_BUILD)
-  std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::DropQDQNodesSelector>();
- qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
+ // Use a separate selector + action that disallows 16-bit types for MaxPool and Resize.
+ // int16 MaxPool is not supported by the ONNX specification.
+ // int16 Resize is not supported by the ORT implementation (although allowed by ONNX).
+  std::unique_ptr<NodeSelector> selector_disallow_16bit = std::make_unique<QDQ::DropQDQNodesSelector>(false);
+ qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name,
+ {{"MaxPool", {12}},
+ {"Resize", {}}},
+ std::move(selector_disallow_16bit),
+ std::move(drop_action_no_int16));
+
+  std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::DropQDQNodesSelector>(true);
+ qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_name,
{{"Gather", {}},
{"Reshape", {}},
{"Transpose", {}},
- {"MaxPool", {12}},
- {"Resize", {}},
{"Squeeze", {}},
{"Unsqueeze", {}}},
std::move(selector),
- std::move(action));
+ std::move(drop_action));
#else
- qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
+ qdq_selector_action_registry.RegisterAction(drop_action_no_int16_name, std::move(drop_action_no_int16));
+ qdq_selector_action_registry.RegisterAction(drop_action_name, std::move(drop_action));
#endif
}
@@ -74,6 +87,7 @@ void DropDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::unique_ptr action = std::make_unique(std::move(moves));
#if !defined(ORT_MINIMAL_BUILD)
+ // TODO: Enable 16-bit types in selector when ArgMax supports 16-bit integer input tensors.
std::unique_ptr selector = std::make_unique();
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"ArgMax", {}}},
@@ -91,6 +105,7 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::unique_ptr action = std::make_unique(kMSDomain);
#if !defined(ORT_MINIMAL_BUILD)
+ // TODO: Enable 16-bit types in selector when unary QLinear* ops support 16-bit.
std::unique_ptr selector = std::make_unique();
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"AveragePool", {}},
@@ -112,6 +127,7 @@ void BinaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::unique_ptr action = std::make_unique(kMSDomain);
#if !defined(ORT_MINIMAL_BUILD)
+ // TODO: Enable 16-bit types in selector when binary QLinear* ops support 16-bit.
std::unique_ptr selector = std::make_unique();
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Add", {}},
@@ -131,6 +147,7 @@ void VariadicOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::unique_ptr action = std::make_unique(kMSDomain);
#if !defined(ORT_MINIMAL_BUILD)
+ // TODO: Enable 16-bit types in selector when QLinearConcat supports 16-bit.
std::unique_ptr selector = std::make_unique();
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
@@ -152,6 +169,7 @@ void ConvQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool is_
std::unique_ptr action = std::make_unique();
#if !defined(ORT_MINIMAL_BUILD)
+ // TODO: Enable 16-bit types in selector when QLinearConv supports 16-bit.
std::unique_ptr selector = std::make_unique(is_int8_allowed);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
@@ -174,6 +192,7 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i
std::unique_ptr action = std::make_unique();
#if !defined(ORT_MINIMAL_BUILD)
+ // TODO: Enable 16-bit types in selector when QLinearMatMul and MatMulInteger support 16-bit.
std::unique_ptr selector = std::make_unique(is_int8_allowed);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"MatMul", {}}},
@@ -195,6 +214,7 @@ void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::unique_ptr action = std::make_unique();
#if !defined(ORT_MINIMAL_BUILD)
+ // TODO: Enable 16-bit types in selector when QGemm supports 16-bit.
std::unique_ptr selector = std::make_unique();
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Gemm", {}}},
@@ -215,6 +235,7 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::unique_ptr action = std::make_unique();
#if !defined(ORT_MINIMAL_BUILD)
+ // TODO: Enable 16-bit types in selector when QLinearWhere supports 16-bit.
std::unique_ptr selector = std::make_unique();
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Where", {}}},
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 02a7fb733813c..16c7bd5fce960 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -14,6 +14,12 @@
namespace onnxruntime {
namespace QDQ {
namespace {
+
+constexpr bool Is16BitIntType(int32_t data_type) {
+ return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16) ||
+ (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16);
+}
+
// adjust for an optional input/output that has an entry but does not exist
int NumActualValues(const Node& node, bool input) {
const auto& defs = input ? node.InputDefs() : node.OutputDefs();
@@ -110,6 +116,17 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}
+ int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+ int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+
+ if (dt_input != dt_output) {
+ return false;
+ }
+
+ if (!allow_16bit_ && Is16BitIntType(dt_input)) {
+ return false;
+ }
+
const Node& dq_node = *dq_nodes.front();
const Node& q_node = *q_nodes.front();
@@ -124,7 +141,7 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector& dq_nodes,
const std::vector& q_nodes) const {
- int num_dq_inputs = NumActualValues(node, true);
+ constexpr int num_dq_inputs = 1;
  if (num_dq_inputs != gsl::narrow_cast<int>(dq_nodes.size())) {
return false;
}
@@ -136,6 +153,12 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
(void)q_nodes;
const Node& dq_node = *dq_nodes.front();
+ const int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+
+ // 16-bit int types must be explicitly allowed.
+ if (!allow_16bit_ && Is16BitIntType(dt_input)) {
+ return false;
+ }
auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
return graph_viewer.GetConstantInitializer(initializer_name, true);
@@ -154,7 +177,16 @@ bool UnaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node&
int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
- return dt_input == dt_output;
+ if (dt_input != dt_output) {
+ return false;
+ }
+
+ // 16-bit int types must be explicitly allowed.
+ if (!allow_16bit_ && Is16BitIntType(dt_input)) {
+ return false;
+ }
+
+ return true;
}
bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer,
@@ -168,8 +200,18 @@ bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer,
int32_t dt_input_1 = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_input_2 = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
- return dt_input_1 == dt_input_2 &&
- dt_input_1 == dt_output;
+
+ // All input and output types must match.
+ if (dt_input_1 != dt_input_2 || dt_input_1 != dt_output) {
+ return false;
+ }
+
+ // 16-bit int types must be explicitly allowed.
+ if (!allow_16bit_ && Is16BitIntType(dt_input_1)) {
+ return false;
+ }
+
+ return true;
}
bool VariadicNodeGroupSelector::Check(const GraphViewer& graph_viewer,
@@ -194,7 +236,17 @@ bool VariadicNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}
}
- return dt_input == dt_output;
+
+ if (dt_input != dt_output) {
+ return false;
+ }
+
+ // 16-bit int types must be explicitly allowed.
+ if (!allow_16bit_ && Is16BitIntType(dt_input)) {
+ return false;
+ }
+
+ return true;
}
void InputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
@@ -227,12 +279,19 @@ bool ConvNodeGroupSelector::Check(const GraphViewer& graph_viewer,
}
}
- if (dq_nodes.size() < 3) { // no bias
- return true;
+ if (dq_nodes.size() == 3) { // has bias
+ int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+ if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) {
+ return false;
+ }
}
- int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
- return dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
+ // 16-bit int types must be explicitly allowed.
+ if (!allow_16bit_ && (Is16BitIntType(dt_input) || Is16BitIntType(dt_weight))) {
+ return false;
+ }
+
+ return true;
}
void ConvSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
@@ -256,6 +315,11 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
}
}
+ // 16-bit int types must be explicitly allowed.
+ if (!allow_16bit_ && (Is16BitIntType(dt_input) || Is16BitIntType(dt_weight))) {
+ return false;
+ }
+
// potential match for QLinearMatMul or MatMulIntegerToFloat
bool qlinear = !q_nodes.empty();
@@ -299,6 +363,11 @@ bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer,
}
}
+ // 16-bit int types must be explicitly allowed.
+ if (!allow_16bit_ && (Is16BitIntType(dt_A) || Is16BitIntType(dt_B))) {
+ return false;
+ }
+
if (dq_nodes.size() < 3) { // no bias
return true;
}
@@ -326,8 +395,18 @@ bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node&
const int32_t dt_input_1 = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
const int32_t dt_input_2 = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
const int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
- return dt_input_1 == dt_input_2 &&
- dt_input_1 == dt_output;
+
+ // All input and output types must match.
+ if (dt_input_1 != dt_input_2 || dt_input_1 != dt_output) {
+ return false;
+ }
+
+ // 16-bit int types must be explicitly allowed.
+ if (!allow_16bit_ && Is16BitIntType(dt_input_1)) {
+ return false;
+ }
+
+ return true;
}
bool PadNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node,
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
index 58ebf81508962..d8fefdd8dc3d9 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
@@ -52,45 +52,75 @@ class NodeGroupSelector {
// Single DQ -> node that does not change data -> Q.
// Zero point and scale are constant scalars and must match
class DropQDQNodeGroupSelector : public NodeGroupSelector {
+ public:
+ explicit DropQDQNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+
+ private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector& dq_nodes,
const std::vector& q_nodes) const override;
+
+ bool allow_16bit_;
};
// Single DQ -> node.
class DropDQNodeGroupSelector : public NodeGroupSelector {
+ public:
+ explicit DropDQNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+
+ private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector& dq_nodes,
const std::vector& q_nodes) const override;
+
+ bool allow_16bit_;
};
// single input. default is to only support uint8.
class UnaryNodeGroupSelector : public NodeGroupSelector {
+ public:
+ explicit UnaryNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+
+ private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector& dq_nodes,
const std::vector& q_nodes) const override;
+
+ bool allow_16bit_;
};
// 2 DQ nodes providing input -> node -> Q
class BinaryNodeGroupSelector : public NodeGroupSelector {
+ public:
+ explicit BinaryNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+
+ private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector& dq_nodes,
const std::vector& q_nodes) const override;
+
+ bool allow_16bit_;
};
// Variadic DQ nodes -> node -> Q
class VariadicNodeGroupSelector : public NodeGroupSelector {
+ public:
+ explicit VariadicNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector& dq_nodes,
const std::vector& q_nodes) const override;
+
+ bool allow_16bit_;
};
// DQ nodes for X, W and optionally B -> node -> Q
class ConvNodeGroupSelector : public NodeGroupSelector {
public:
// default to 'true'
- ConvNodeGroupSelector(bool int8_allowed = true) : int8_allowed_(int8_allowed) {}
+ ConvNodeGroupSelector(bool int8_allowed = true, bool allow_16bit = true)
+ : int8_allowed_(int8_allowed), allow_16bit_(allow_16bit) {}
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
@@ -98,16 +128,20 @@ class ConvNodeGroupSelector : public NodeGroupSelector {
const std::vector& q_nodes) const override;
bool int8_allowed_;
+ bool allow_16bit_;
};
class WhereNodeGroupSelector : public NodeGroupSelector {
public:
- WhereNodeGroupSelector() = default;
+ explicit WhereNodeGroupSelector(bool allow_16bit = true)
+ : allow_16bit_(allow_16bit) {}
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector& dq_nodes,
const std::vector& q_nodes) const override;
+
+ bool allow_16bit_;
};
class PadNodeGroupSelector : public NodeGroupSelector {
@@ -125,9 +159,11 @@ class PadNodeGroupSelector : public NodeGroupSelector {
class MatMulNodeGroupSelector : public NodeGroupSelector {
public:
MatMulNodeGroupSelector(bool int8_allowed = true,
- bool matmulintegertofloat_allowed = false)
+ bool matmulintegertofloat_allowed = false,
+ bool allow_16bit = true)
: int8_allowed_(int8_allowed),
- matmulintegertofloat_allowed_(matmulintegertofloat_allowed) {
+ matmulintegertofloat_allowed_(matmulintegertofloat_allowed),
+ allow_16bit_(allow_16bit) {
}
private:
@@ -136,15 +172,21 @@ class MatMulNodeGroupSelector : public NodeGroupSelector {
const std::vector& q_nodes) const override;
bool int8_allowed_;
bool matmulintegertofloat_allowed_;
+ bool allow_16bit_;
};
// Input: DQ nodes for A, B and optional C
// Output: optional Q node for Y
class GemmNodeGroupSelector : public NodeGroupSelector {
+ public:
+ explicit GemmNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
             const std::vector<const Node*>& dq_nodes,
             const std::vector<const Node*>& q_nodes) const override;
+
+ bool allow_16bit_;
};
// Input: DQ nodes for input, scale, and B
@@ -207,28 +249,33 @@ class BaseSelector : public NodeSelector {
class DropQDQNodesSelector : public BaseSelector {
public:
-  DropQDQNodesSelector() : BaseSelector(std::make_unique<DropQDQNodeGroupSelector>()) {}
+  explicit DropQDQNodesSelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<DropQDQNodeGroupSelector>(allow_16bit)) {}
};
class DropDQNodesSelector : public BaseSelector {
public:
-  DropDQNodesSelector() : BaseSelector(std::make_unique<DropDQNodeGroupSelector>()) {}
+  explicit DropDQNodesSelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<DropDQNodeGroupSelector>(allow_16bit)) {}
};
class UnarySelector : public BaseSelector {
public:
-  UnarySelector() : BaseSelector(std::make_unique<UnaryNodeGroupSelector>()) {}
+  explicit UnarySelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<UnaryNodeGroupSelector>(allow_16bit)) {}
};
class BinarySelector : public BaseSelector {
public:
-  BinarySelector() : BaseSelector(std::make_unique<BinaryNodeGroupSelector>()) {}
+  explicit BinarySelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<BinaryNodeGroupSelector>(allow_16bit)) {}
};
// Variadic DQ nodes -> node -> Q
class InputVariadicSelector : public BaseSelector {
public:
-  InputVariadicSelector() : BaseSelector(std::make_unique<VariadicNodeGroupSelector>()) {}
+  explicit InputVariadicSelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<VariadicNodeGroupSelector>(allow_16bit)) {}
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};
@@ -244,46 +291,36 @@ class OutputVariadicSelector : public BaseSelector {
// DQ nodes for X, W and optionally B -> node -> Q
class ConvSelector : public BaseSelector {
public:
-  ConvSelector(bool int8_allowed = false) : BaseSelector(std::make_unique<ConvNodeGroupSelector>(int8_allowed)) {}
+  ConvSelector(bool int8_allowed = false, bool allow_16bit = false)
+      : BaseSelector(std::make_unique<ConvNodeGroupSelector>(int8_allowed, allow_16bit)) {}
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};
+
class WhereSelector : public BaseSelector {
public:
-  WhereSelector() : BaseSelector(std::make_unique<WhereNodeGroupSelector>()) {}
+  explicit WhereSelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<WhereNodeGroupSelector>(allow_16bit)) {}
};
+
// 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not
class MatMulSelector : public BaseSelector {
public:
-  MatMulSelector(bool int8_allowed)
-      : BaseSelector(std::make_unique<MatMulNodeGroupSelector>(int8_allowed, /*matmulintegertofloat_allowed*/ true)) {}
+  MatMulSelector(bool int8_allowed, bool allow_16bit = false)
+      : BaseSelector(std::make_unique<MatMulNodeGroupSelector>(int8_allowed, /*matmulintegertofloat_allowed*/ true,
+                                                               allow_16bit)) {}
};
// Input: DQ nodes for A, B and optional C
// Output: optional Q node for Y
class GemmSelector : public BaseSelector {
public:
-  GemmSelector()
-      : BaseSelector(std::make_unique<GemmNodeGroupSelector>()) {}
+  explicit GemmSelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<GemmNodeGroupSelector>(allow_16bit)) {}
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};
-// Input: DQ nodes for input, scale, and B (bias)
-// Output: Q node for output
-class InstanceNormalizationSelector : public BaseSelector {
- public:
- InstanceNormalizationSelector()
-      : BaseSelector(std::make_unique<InstanceAndLayerNormalizationNodeGroupSelector>()) {}
-};
-
-// DQ nodes for X, W and optionally B, (mean, var not required) -> node -> Q
-class BatchNormalizationSelector : public BaseSelector {
- public:
- BatchNormalizationSelector(bool int8_allowed = false)
-      : BaseSelector(std::make_unique<BatchNormalizationNodeGroupSelector>(int8_allowed)) {}
-};
-
} // namespace QDQ
} // namespace onnxruntime
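
Note (illustration, not part of the patch): the node groups these selectors match are ordinary DQ -> op -> Q chains; the new `allow_16bit` flag only controls whether a group whose quantized tensors are 16-bit integers is accepted. A minimal Python sketch of such a group, built with `onnx.helper` using illustrative tensor names, and placed in the `com.microsoft` domain since 16-bit Q/DQ is only available as a contrib op:

```python
from onnx import TensorProto, helper

# A uint16 DequantizeLinear -> Relu -> QuantizeLinear group. With allow_16bit=true the
# selectors above may accept this group; with allow_16bit=false the 16-bit zero points
# disqualify it. All names below are illustrative only.
dq = helper.make_node("DequantizeLinear", ["x_q", "x_scale", "x_zp"], ["x_f"], domain="com.microsoft")
relu = helper.make_node("Relu", ["x_f"], ["y_f"])
q = helper.make_node("QuantizeLinear", ["y_f", "y_scale", "y_zp"], ["y_q"], domain="com.microsoft")

graph = helper.make_graph(
    [dq, relu, q],
    "qdq_relu_uint16",
    [helper.make_tensor_value_info("x_q", TensorProto.UINT16, [4])],
    [helper.make_tensor_value_info("y_q", TensorProto.UINT16, [4])],
    initializer=[
        helper.make_tensor("x_scale", TensorProto.FLOAT, [], [0.05]),
        helper.make_tensor("x_zp", TensorProto.UINT16, [], [32767]),
        helper.make_tensor("y_scale", TensorProto.FLOAT, [], [0.05]),
        helper.make_tensor("y_zp", TensorProto.UINT16, [], [32767]),
    ],
)
model = helper.make_model(
    graph, opset_imports=[helper.make_opsetid("", 19), helper.make_opsetid("com.microsoft", 1)]
)
```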
diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
index 3723ee6032582..2c11bf144999e 100644
--- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
+++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
@@ -1195,7 +1195,7 @@ bool TransposeQuantizeDequantizeAxis(const api::GraphRef& graph, const std::vect
static bool HandleQuantizeDequantizeAxis(const api::GraphRef& graph, const std::vector& perm,
api::NodeRef& node, int64_t opset) {
if (opset < 13) {
- // no `axis` value until opset 13
+ // no `axis` attribute until opset 13
return true;
}
diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
index 67a9a5991939a..a0d75e8cc0e69 100644
--- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
+++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
@@ -5,13 +5,47 @@
#include "core/framework/element_type_lists.h"
#include "core/framework/float8.h"
#include "core/framework/float16.h"
-#include "core/providers/cpu/quantization/quantize_linear.h"
+#include "core/framework/op_kernel.h"
#include "core/providers/common.h"
#include "core/mlas/inc/mlas.h"
#include "core/util/qmath.h"
namespace onnxruntime {
+template <typename T>
+class DequantizeLinear final : public OpKernel {
+ public:
+ explicit DequantizeLinear(const OpKernelInfo& info) : OpKernel(info) {
+ if (!info.GetAttr("axis", &axis_).IsOK()) {
+ axis_ = 1;
+ }
+ }
+
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ int64_t axis_;
+};
+
+template <typename T>
+class QuantizeLinear final : public OpKernel {
+ public:
+ explicit QuantizeLinear(const OpKernelInfo& info) : OpKernel(info) {
+ if (!info.GetAttr("axis", &axis_).IsOK()) {
+ axis_ = 1;
+ }
+ if (!info.GetAttr("saturate", &saturate_).IsOK()) {
+ saturate_ = 1;
+ }
+ }
+
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ int64_t axis_;
+ int64_t saturate_;
+};
+
static void PrepareForQDQ(const TensorShape& input_shape,
const Tensor& scale,
const Tensor* zero_point_ptr,
@@ -86,6 +120,59 @@ REGISTER_DEQUANTIZELINEAR_VERSIONED(int8_t)
REGISTER_DEQUANTIZELINEAR_VERSIONED(uint8_t)
REGISTER_DEQUANTIZELINEAR_VERSIONED(int32_t)
+#if !defined(DISABLE_CONTRIB_OPS)
+namespace contrib {
+
+// Register alternate MS domain versions of the DequantizeLinear kernel.
+// The MS domain versions additionally support 16-bit integer quantization types.
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    DequantizeLinear,
+    1,
+    uint8_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
+    DequantizeLinear<uint8_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    DequantizeLinear,
+    1,
+    int8_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
+    DequantizeLinear<int8_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    DequantizeLinear,
+    1,
+    uint16_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint16_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
+    DequantizeLinear<uint16_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    DequantizeLinear,
+    1,
+    int16_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int16_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
+    DequantizeLinear<int16_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    DequantizeLinear,
+    1,
+    int32_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
+    DequantizeLinear<int32_t>);
+
+} // namespace contrib
+#endif // !defined(DISABLE_CONTRIB_OPS)
+
template <typename T, typename OutT>
struct DequantizeLinearApply {
void op(int64_t N, int64_t broadcast_dim, int64_t block_size, const T* input, const OutT* scale, OutT* output, const T* zero_point) {
@@ -220,6 +307,49 @@ REGISTER_QUANTIZELINEAR(Float8E5M2FNUZ)
REGISTER_QUANTIZELINEAR_VERSIONED(int8_t)
REGISTER_QUANTIZELINEAR_VERSIONED(uint8_t)
+#if !defined(DISABLE_CONTRIB_OPS)
+namespace contrib {
+
+// Register alternate MS domain versions of the QuantizeLinear kernel.
+// The MS domain versions additionally support 16-bit integer quantization types.
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    QuantizeLinear,
+    1,
+    uint8_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
+    QuantizeLinear<uint8_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    QuantizeLinear,
+    1,
+    int8_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>()),
+    QuantizeLinear<int8_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    QuantizeLinear,
+    1,
+    uint16_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint16_t>()),
+    QuantizeLinear<uint16_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    QuantizeLinear,
+    1,
+    int16_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int16_t>()),
+    QuantizeLinear<int16_t>);
+} // namespace contrib
+#endif // !defined(DISABLE_CONTRIB_OPS)
+
template <typename InputType, typename OutputType>
void ParQuantizeLinear(const InputType* Input,
OutputType* Output,
@@ -279,5 +409,4 @@ Status QuantizeLinear::Compute(OpKernelContext* ctx) const {
return Status::OK();
}
-
} // namespace onnxruntime
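
As a quick numeric check (not part of the patch), the kernels registered above implement y = (x - x_zero_point) * x_scale; the following reproduces the int16 values used in the contrib-op unit test added later in this diff:

```python
import numpy as np

# Dequantize y = (x - zero_point) * scale with the int16 test vector
# (scale = 2.0, zero_point = -1024) from quantize_ops_test.cc below.
x = np.array([-300, -30, -1025, 1270], dtype=np.int16)
scale = np.float32(2.0)
zero_point = np.int32(-1024)

y = (x.astype(np.int32) - zero_point).astype(np.float32) * scale
print(y)  # [ 1448.  1988.    -2.  4588.]
```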
diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.h b/onnxruntime/core/providers/cpu/quantization/quantize_linear.h
deleted file mode 100644
index 60e9d09665ab2..0000000000000
--- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.h
+++ /dev/null
@@ -1,45 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include "core/common/common.h"
-#include "core/framework/op_kernel.h"
-#include "core/util/math_cpuonly.h"
-
-namespace onnxruntime {
-
-template <typename T>
-class DequantizeLinear final : public OpKernel {
- public:
- DequantizeLinear(const OpKernelInfo& info) : OpKernel(info) {
- if (!info.GetAttr("axis", &axis_).IsOK()) {
- axis_ = 1;
- }
- }
-
- Status Compute(OpKernelContext* context) const override;
-
- private:
- int64_t axis_;
-};
-
-template <typename T>
-class QuantizeLinear final : public OpKernel {
- public:
- QuantizeLinear(const OpKernelInfo& info) : OpKernel(info) {
- if (!info.GetAttr("axis", &axis_).IsOK()) {
- axis_ = 1;
- }
- if (!info.GetAttr("saturate", &saturate_).IsOK()) {
- saturate_ = 1;
- }
- }
-
- Status Compute(OpKernelContext* context) const override;
-
- private:
- int64_t axis_;
- int64_t saturate_;
-};
-} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
index 556a86bb1519b..8081033c35618 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
@@ -30,6 +30,12 @@ class SimpleOpBuilder : public BaseOpBuilder {
private:
Status ExplicitOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const;
+ Status ProcessSigmoidOrTanhOutput(QnnModelWrapper& qnn_model_wrapper,
+ const NodeUnit& node_unit,
+                                    std::vector<std::string>&& input_names,
+                                    std::vector<std::string>&& param_tensor_names,
+ const logging::Logger& logger,
+ bool do_op_validation) const ORT_MUST_USE_RESULT;
static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest"};
static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"};
@@ -279,10 +285,120 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
ORT_RETURN_IF_ERROR(ProcessGridSampleAttributes(qnn_model_wrapper, node_unit, param_tensor_names));
}
- ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
- std::move(input_names),
- std::move(param_tensor_names),
- logger, do_op_validation, GetQnnOpType(op_type)));
+ if (op_type == "Sigmoid" || op_type == "Tanh") {
+ // QNN requires 16-bit QDQ Sigmoid and Tanh to use specific output scale and zero-point values
+ // regardless of floating-point range.
+ return ProcessSigmoidOrTanhOutput(qnn_model_wrapper,
+ node_unit,
+ std::move(input_names),
+ std::move(param_tensor_names),
+ logger, do_op_validation);
+ }
+
+ return ProcessOutputs(qnn_model_wrapper, node_unit,
+ std::move(input_names),
+ std::move(param_tensor_names),
+ logger, do_op_validation, GetQnnOpType(op_type));
+}
+
+/**
+ * Overrides offset and scale quantization parameters for operators (e.g., Sigmoid or Tanh) that require
+ * specific values. Returns true if the quantization parameters were overridden.
+ *
+ * \param op_type The ONNX operator type.
+ * \param qnn_data_type The QNN tensor data type.
+ * \param quant_params Output scale/offset parameter that may be overridden.
+ * \return True if the offset and scale were overridden.
+ */
+static bool OverrideQuantParams(const std::string& op_type, Qnn_DataType_t qnn_data_type,
+ Qnn_ScaleOffset_t& quant_params) {
+ const int32_t orig_offset = quant_params.offset;
+ const float orig_scale = quant_params.scale;
+
+ if (op_type == "Sigmoid") {
+ switch (qnn_data_type) {
+ case QNN_DATATYPE_UFIXED_POINT_16:
+ quant_params.offset = 0;
+ quant_params.scale = 1.0f / 65536.0f;
+ break;
+ case QNN_DATATYPE_SFIXED_POINT_16:
+ quant_params.offset = 0;
+ quant_params.scale = 1.0f / 32768.0f;
+ break;
+ default:
+ break; // Do nothing.
+ }
+ }
+
+ if (op_type == "Tanh") {
+ switch (qnn_data_type) {
+ case QNN_DATATYPE_UFIXED_POINT_16:
+ quant_params.offset = -32768;
+ quant_params.scale = 1.0f / 32768.0f;
+ break;
+ case QNN_DATATYPE_SFIXED_POINT_16:
+ quant_params.offset = 0;
+ quant_params.scale = 1.0f / 32768.0f;
+ break;
+ default:
+ break; // Do nothing.
+ }
+ }
+
+ return quant_params.offset != orig_offset || quant_params.scale != orig_scale;
+}
+
+/**
+ * Processes the output for Sigmoid or Tanh operators and creates the corresponding QNN operator.
+ * These operator types are handled separately because QNN requires 16-bit QDQ Sigmoid and Tanh operators to use
+ * specific scale and zero-point values regardless of floating-point range.
+ *
+ * \param qnn_model_wrapper The QNN model wrapper object.
+ * \param node_unit The QDQ node unit for the Sigmoid or Tanh node.
+ * \param input_names List of input names.
+ * \param param_tensor_names List of param tensor names.
+ * \param logger Logger used to report information.
+ * \param do_op_validation True if the new QNN node should be validated.
+ */
+Status SimpleOpBuilder::ProcessSigmoidOrTanhOutput(QnnModelWrapper& qnn_model_wrapper,
+ const NodeUnit& node_unit,
+                                                    std::vector<std::string>&& input_names,
+                                                    std::vector<std::string>&& param_tensor_names,
+ const logging::Logger& logger,
+ bool do_op_validation) const {
+ const std::string& op_type = node_unit.OpType();
+ const auto& output = node_unit.Outputs()[0];
+ const std::string& output_name = output.node_arg.Name();
+
+ OnnxInputInfo output_info = {};
+
+ // TODO(adrianlizarraga): Rename GetOnnxInputInfo() since it can be used for outputs as well.
+ ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(output, output_info));
+
+ if (output_info.quant_param.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
+ if (OverrideQuantParams(op_type, output_info.qnn_data_type, output_info.quant_param.scaleOffsetEncoding)) {
+ const int32_t offset = output_info.quant_param.scaleOffsetEncoding.offset;
+ const float scale = output_info.quant_param.scaleOffsetEncoding.scale;
+
+ LOGS(logger, VERBOSE) << "QNN requires that 16-bit quantized " << op_type << " operators use offset/scale values "
+ << "of <" << offset << ", " << scale << ">. QNN EP will override the original values.";
+ }
+ }
+
+ Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? QNN_TENSOR_TYPE_APP_READ
+ : QNN_TENSOR_TYPE_NATIVE;
+ QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, output_info.qnn_data_type, output_info.quant_param,
+ std::move(output_info.shape));
+ ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
+ ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit),
+ QNN_OP_PACKAGE_NAME_QTI_AISW,
+ GetQnnOpType(op_type),
+ std::move(input_names),
+ {output_name},
+ std::move(param_tensor_names),
+ do_op_validation),
+ "Failed to add node.");
+
return Status::OK();
}
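
Why these particular constants: under QNN's scale/offset convention (real = scale * (quantized + offset), which is also why ProcessOffset below negates the ONNX zero point), the overridden values make the 16-bit grid line up exactly with the operator's output range. A quick check, not part of the patch:

```python
# QNN scale/offset dequantization: real = scale * (quantized + offset).
def qnn_real_range(scale, offset, qmin, qmax):
    return scale * (qmin + offset), scale * (qmax + offset)

# Sigmoid outputs lie in (0, 1):
print(qnn_real_range(1.0 / 65536.0, 0, 0, 65535))       # uint16 -> (0.0, ~0.99998)
print(qnn_real_range(1.0 / 32768.0, 0, -32768, 32767))  # int16  -> (-1.0, ~0.99997); negative half unused

# Tanh outputs lie in (-1, 1):
print(qnn_real_range(1.0 / 32768.0, -32768, 0, 65535))  # uint16 -> (-1.0, ~0.99997)
print(qnn_real_range(1.0 / 32768.0, 0, -32768, 32767))  # int16  -> (-1.0, ~0.99997)
```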
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
index eebe75d839b12..9d339387b0a43 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
@@ -301,6 +301,16 @@ bool QnnModelWrapper::ProcessOffset(const std::string& offset_name,
offset_value = 0 - (uint8_span.data()[0]);
break;
}
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
+      auto uint16_span = ReinterpretAsSpan<const uint16_t>(gsl::make_span(unpacked_tensor));
+      offset_value = -static_cast<int32_t>(uint16_span.data()[0]);
+ break;
+ }
+ case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
+      auto int16_span = ReinterpretAsSpan<const int16_t>(gsl::make_span(unpacked_tensor));
+      offset_value = -static_cast<int32_t>(int16_span.data()[0]);
+ break;
+ }
case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
      auto int32_span = ReinterpretAsSpan<const int32_t>(gsl::make_span(unpacked_tensor));
offset_value = -(int32_span.data()[0]);
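
The sign flip above comes from the two conventions differing only in the sign of the zero point: ONNX dequantizes as real = scale * (q - zero_point), while QNN uses real = scale * (q + offset). A one-line sketch (not part of the patch):

```python
# ONNX: real = scale * (q - zero_point);  QNN: real = scale * (q + offset)  =>  offset = -zero_point.
def to_qnn_offset(onnx_zero_point: int) -> int:
    return -int(onnx_zero_point)

print(to_qnn_offset(32767))  # -32767, e.g. for a uint16 zero point of 32767
```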
diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py
index 924d4c72b6390..2d1e418f9d2b4 100644
--- a/onnxruntime/python/tools/quantization/onnx_quantizer.py
+++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -104,7 +104,7 @@ def __init__(
)
self.q_matmul_const_b_only = "MatMulConstBOnly" in self.extra_options and self.extra_options["MatMulConstBOnly"]
self.is_weight_symmetric = (
- weight_qType in (QuantType.QInt8, QuantType.QFLOAT8E4M3FN)
+ weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN)
if "WeightSymmetric" not in self.extra_options
else self.extra_options["WeightSymmetric"]
)
diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py
index f87a9d8228bac..e595b580b20df 100644
--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py
+++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -25,6 +25,7 @@
add_quant_output_suffix,
add_quant_suffix,
find_by_name,
+ ms_domain,
)
from .registry import CreateQDQQuantizer
@@ -119,6 +120,20 @@ def __init__(
else extra_options["QDQOpTypePerChannelSupportToAxis"]
)
+ self.qdq_op_domain = ms_domain if extra_options.get("UseQDQContribOps", False) else None
+
+ # The ONNX spec does not yet support 16-bit Q/DQ ops. So, must override the Q/DQ op domain to 'com.microsoft'
+ # if the activation or weight types are 16-bit integers.
+ # TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support.
+ int16_types = (TensorProto.UINT16, TensorProto.INT16)
+ if not self.qdq_op_domain and (self.activation_qType in int16_types or self.weight_qType in int16_types):
+ logging.warning(
+ "ONNX QuantizeLinear and DequantizeLinear operators do not support 16-bit integer quantization types. "
+ f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to "
+ "enable support."
+ )
+ self.qdq_op_domain = ms_domain
+
def _is_tensor_quantizable(self, tensor_name):
"""
Check if tensor can be quantized
@@ -249,6 +264,7 @@ def _create_qdq_nodes(
[q_output],
quant_node_name,
axis=axis,
+ domain=self.qdq_op_domain,
)
dequant_node = onnx.helper.make_node(
DEQUANT_OP_NAME,
@@ -256,6 +272,7 @@ def _create_qdq_nodes(
[dq_output],
dequant_node_name,
axis=axis,
+ domain=self.qdq_op_domain,
)
self.model.add_nodes([qlinear_node, dequant_node])
@@ -300,6 +317,7 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None):
[weight_dequant_output],
add_dequant_suffix(weight_name),
axis=axis,
+ domain=self.qdq_op_domain,
)
self.model.add_node(dequant_node)
@@ -443,6 +461,7 @@ def _quantize_bias_tensors(self):
[bias_name],
node_name,
axis=quant_value.axis,
+ domain=self.qdq_op_domain,
)
else:
dequant_node = onnx.helper.make_node(
@@ -450,6 +469,7 @@ def _quantize_bias_tensors(self):
inputs,
[bias_name],
node_name,
+ domain=self.qdq_op_domain,
)
else:
raise RuntimeError(f"Unexpected operator type {quant_value.node_type!r}.")
diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py
index 4d5bcca29618f..74e54c3f1fa37 100644
--- a/onnxruntime/python/tools/quantization/quant_utils.py
+++ b/onnxruntime/python/tools/quantization/quant_utils.py
@@ -72,6 +72,8 @@ class QuantType(Enum):
QInt8 = 0
QUInt8 = 1
QFLOAT8E4M3FN = 2
+ QInt16 = 3
+ QUInt16 = 4
def __str__(self):
return self.name
@@ -89,6 +91,10 @@ def tensor_type(self):
return TensorProto.INT8
if self == QuantType.QUInt8:
return TensorProto.UINT8
+ if self == QuantType.QUInt16:
+ return TensorProto.UINT16
+ if self == QuantType.QInt16:
+ return TensorProto.INT16
if self == QuantType.QFLOAT8E4M3FN:
return TensorProto.FLOAT8E4M3FN
raise ValueError(f"Unexpected value qtype={self!r}.")
@@ -112,12 +118,35 @@ def from_string(format):
ONNX_TYPE_TO_NP_TYPE = {
onnx_proto.TensorProto.INT8: numpy.dtype("int8"),
onnx_proto.TensorProto.UINT8: numpy.dtype("uint8"),
+ onnx_proto.TensorProto.INT16: numpy.dtype("int16"),
+ onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"),
onnx_proto.TensorProto.FLOAT8E4M3FN: float8e4m3fn,
}
+ONNX_INT_TYPE_RANGE = {
+ onnx_proto.TensorProto.UINT8: (0, 255),
+ onnx_proto.TensorProto.INT8: (-128, 127),
+ onnx_proto.TensorProto.UINT16: (0, 65535),
+ onnx_proto.TensorProto.INT16: (-32768, 32767),
+}
+
+ONNX_INT_TYPE_SYMMETRIC_RANGE = {
+ onnx_proto.TensorProto.INT8: (-127, 127),
+ onnx_proto.TensorProto.INT16: (-32767, 32767),
+}
+
+ONNX_INT_TYPE_REDUCED_RANGE = {
+ onnx_proto.TensorProto.UINT8: (0, 127),
+ onnx_proto.TensorProto.INT8: (-64, 64),
+ onnx_proto.TensorProto.UINT16: (0, 32767),
+ onnx_proto.TensorProto.INT16: (-16384, 16384),
+}
+
def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
- assert qType in ONNX_TYPE_TO_NP_TYPE, f"Unexpected data type {qType} requested. Only INT8 and UINT8 are supported."
+ assert (
+ qType in ONNX_TYPE_TO_NP_TYPE
+ ), f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported."
if qType in (
onnx_proto.TensorProto.FLOAT8E4M3FN,
onnx_proto.TensorProto.FLOAT8E4M3FNUZ,
@@ -146,8 +175,10 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
return ref.run(None, {"X": arr.astype(numpy.float32), "scale": scale.astype(numpy.float32)})[0]
else:
dtype = ONNX_TYPE_TO_NP_TYPE[qType]
- cliplow = max(0 if dtype == numpy.uint8 else -127, -127 if low is None else low)
- cliphigh = min(255 if dtype == numpy.uint8 else 127, 255 if high is None else high)
+ (qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True)
+
+ cliplow = max(qmin, low) if low is not None else qmin
+ cliphigh = min(qmax, high) if high is not None else qmax
arr_fp32 = numpy.asarray((arr.astype(numpy.float32) / scale).round() + zero_point)
numpy.clip(arr_fp32, cliplow, cliphigh, out=arr_fp32)
return arr_fp32.astype(dtype)
@@ -267,7 +298,7 @@ def quantize_data(data, qType, symmetric, reduce_range=False):
)
return rmin, rmax, zero_point, scale, quantized_data
- if qType in (TensorProto.INT8, TensorProto.UINT8):
+ if qType in (TensorProto.INT8, TensorProto.UINT8, TensorProto.INT16, TensorProto.UINT16):
if len(data):
qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric)
zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric)
@@ -283,18 +314,22 @@ def get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False): # noqa
    :parameter qType: onnx.onnx_pb.TensorProto.UINT8 or onnx.onnx_pb.TensorProto.INT8
:return: qmin, qmax
"""
- if qType == onnx_proto.TensorProto.UINT8:
- (qmin, qmax) = (0, 127) if reduce_range else (0, 255)
- elif qType == onnx_proto.TensorProto.INT8:
- if symmetric:
- (qmin, qmax) = (-64, 64) if reduce_range else (-127, 127)
- else:
- (qmin, qmax) = (-64, 64) if reduce_range else (-128, 127)
- elif qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
+ if qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
raise NotImplementedError("This function is not implemented for float 8 as not needed.")
+
+ qrange = None
+
+ if reduce_range:
+ qrange = ONNX_INT_TYPE_REDUCED_RANGE.get(qType)
+ elif symmetric and qType in ONNX_INT_TYPE_SYMMETRIC_RANGE:
+ qrange = ONNX_INT_TYPE_SYMMETRIC_RANGE[qType]
else:
- raise ValueError(f"Unexpected data type {qType} requested. Only INT8 and UINT8 are supported.")
- return qmin, qmax
+ qrange = ONNX_INT_TYPE_RANGE.get(qType)
+
+ if not qrange:
+ raise ValueError(f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported.")
+
+ return qrange
def get_qrange_for_qType(qType, reduce_range=False, symmetric=False): # noqa: N802
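
Putting the new range tables and the generalized rounding/clipping together, a small standalone sketch (mirroring, not importing, the helpers above) reproduces a few of the uint16 QuantizeLinear test values added later in this diff:

```python
import numpy as np

UINT16_RANGE = (0, 65535)  # ONNX_INT_TYPE_RANGE[TensorProto.UINT16] above

def quantize_uint16(values, scale, zero_point):
    qmin, qmax = UINT16_RANGE
    q = np.rint(np.asarray(values, dtype=np.float32) / scale) + zero_point  # rint rounds half to even
    return np.clip(q, qmin, qmax).astype(np.uint16)

# scale = 2.0, zero_point = 32767, as in the uint16 test case in quantize_ops_test.cc.
print(quantize_uint16([0.0, 3.0, 65536.0, -70000.0], 2.0, 32767))
# -> [32767 32769 65535     0]
```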
diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index 6b1646aec9679..706047fe32400 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -240,6 +240,11 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua
f"weight_type={weight_type}!=QuantType.QFLOAT8E4M3FN"
)
+ q16_types = [QuantType.QInt16, QuantType.QUInt16]
+
+ if (activation_type in q16_types or weight_type in q16_types) and quant_format != QuantFormat.QDQ:
+ raise ValueError("Only QuantFormat.QDQ supports 16-bit quantization types.")
+
if activation_type == QuantType.QInt8 and weight_type == QuantType.QInt8 and quant_format != QuantFormat.QDQ:
logging.warning(
"Please use QuantFormat.QDQ for activation type QInt8 and weight type QInt8. "
@@ -356,6 +361,11 @@ def quantize_static(
SmoothQuantFolding = True/False :
Default is True. It only works if SmoothQuant is True. If enabled, inserted Mul ops during
SmoothQuant will be folded into the previous op if the previous op is foldable.
+ UseQDQContribOps = True/False :
+ Default is False. If enabled, the inserted QuantizeLinear and DequantizeLinear ops will have the
+ `com.microsoft` domain, which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear
+ contrib op implementations. The contrib op implementations may support features not standardized
+ into the ONNX specification (e.g., 16-bit quantization types).
"""
if activation_type == QuantType.QFLOAT8E4M3FN or weight_type == QuantType.QFLOAT8E4M3FN:
if calibrate_method != CalibrationMethod.Distribution:
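
End to end, 16-bit static quantization can then be requested roughly as follows (a sketch, not taken from this patch; it assumes a model at `model.onnx` and a user-supplied `CalibrationDataReader` instance named `my_data_reader`):

```python
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

quantize_static(
    "model.onnx",
    "model.qdq.u16.onnx",
    calibration_data_reader=my_data_reader,  # assumed to exist
    quant_format=QuantFormat.QDQ,            # 16-bit types require the QDQ format (see check above)
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QInt16,
    # The quantizer switches Q/DQ nodes to the com.microsoft domain automatically for
    # 16-bit types; UseQDQContribOps simply makes that choice explicit.
    extra_options={"UseQDQContribOps": True},
)
```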
diff --git a/onnxruntime/test/contrib_ops/quantize_ops_test.cc b/onnxruntime/test/contrib_ops/quantize_ops_test.cc
index af29f972a64cf..64a97ed4f945b 100644
--- a/onnxruntime/test/contrib_ops/quantize_ops_test.cc
+++ b/onnxruntime/test/contrib_ops/quantize_ops_test.cc
@@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
+#include "test/util/include/default_providers.h"
namespace onnxruntime {
namespace test {
@@ -40,7 +41,31 @@ TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int8) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
-// Scalar zero & scale with int32
+// Test int16 com.microsoft.DequantizeLinear (per tensor)
+TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int16_cpu) {
+ OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain);
+  std::vector<int64_t> dims{4};
+  test.AddInput<int16_t>("x", dims, {-300, -30, -1025, 1270});
+  test.AddInput<float>("scale", {}, {2.0f}, true);
+  test.AddInput<int16_t>("zero_point", {}, {-1024}, true);
+  test.AddOutput<float>("y", dims, {1448.0f, 1988.0f, -2.0f, 4588.0f});
+ // Disable Tensorrt EP due to error: unsupported data type
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
+// Test uint16 com.microsoft.DequantizeLinear (per tensor)
+TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_uint16_cpu) {
+ OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain);
+  std::vector<int64_t> dims{4};
+  test.AddInput<uint16_t>("x", dims, {30000, 31000, 32768, 33000});
+  test.AddInput<float>("scale", {}, {2.0f}, true);
+  test.AddInput<uint16_t>("zero_point", {}, {32767}, true);
+  test.AddOutput<float>("y", dims, {-5534.0f, -3534.0f, 2.0f, 466.0f});
+ // Disable Tensorrt EP due to error: unsupported data type
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
+// Test int32 DequantizeLinear with scalar zero-point & scale.
TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int32_cpu) {
OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain);
  std::vector<int64_t> dims{4};
@@ -256,6 +281,60 @@ TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_int8) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
+// Test uint16 com.microsoft.QuantizeLinear (per tensor)
+TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_uint16) {
+ OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain);
+  std::vector<int64_t> dims{12};
+  test.AddInput<float>("x", dims, {
+ 0.f, -128.f, 3.f, -3.f, // rounding half to even
+ 2.9f, -2.9f, // round < .5
+ 3.1f, -3.1f, // round > .5
+ 65536.f, -65534.f, // critical point
+ 70000.f, -70000.f // saturate case
+ });
+  test.AddInput<float>("scale", {}, {2.0f}, true);
+  test.AddInput<uint16_t>("zero_point", {}, {32767}, true);
+  test.AddOutput<uint16_t>("y", dims,
+ {32767, 32703,
+ 32769, 32765,
+ 32768, 32766,
+ 32769, 32765,
+ 65535, 0,
+ 65535, 0});
+
+ // Disable Tensorrt EP due to error: unsupported data type
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
+// Test int16 com.microsoft.QuantizeLinear (per tensor)
+TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_int16) {
+ OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain);
+  std::vector<int64_t> dims{16};
+ test.AddInput