diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 25202f82f468d..cf71b6bcf7c7d 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -10,6 +10,9 @@ set(contrib_ops_excluded_files "bert/attention_impl.cu" "bert/attention_softmax.h" "bert/attention_softmax.cu" + "bert/attention_prepare_qkv.cu" + "bert/decoder_attention_impl.h" + "bert/decoder_attention_impl.cu" "bert/decoder_masked_multihead_attention.h" "bert/decoder_masked_multihead_attention.cc" "bert/decoder_masked_self_attention.h" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 5bd1a89c0dea1..95dc8c3cde46c 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1351,8 +1351,8 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-T1 : tensor(int8), tensor(uint8), tensor(int32)
-Constrain 'x' and 'x_zero_point' to 8-bit integer tensors or 32-bit signed integer tensors.
+T1 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32)
+Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, 16-bit integer tensors, or 32-bit signed integer tensors.
 T2 : tensor(float16), tensor(float)
 Constrain 'y', 'x_scale' to float tensors.
@@ -4194,8 +4194,9 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.QuantizeLinear** The linear quantization operator. It consumes a full precision data, a scale, a zero point to compute the low precision / quantized tensor. - The quantization formula is y = saturate ((x / y_scale) + y_zero_point).For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8. - For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details. + The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, [-128, 127] if it's int8, + [0, 65,535] if it's uint16, and [-32,768, 32,767] if it's int16. For (x / y_scale), it's rounding to nearest ties to even. + Refer to https://en.wikipedia.org/wiki/Rounding for details. Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis'). #### Version @@ -4232,8 +4233,8 @@ This version of the operator has been available since version 1 of the 'com.micr
T1 : tensor(float16), tensor(float)
 Constrain 'x', 'y_scale' to float tensors.
-T2 : tensor(int8), tensor(uint8)
-Constrain 'y_zero_point' and 'y' to 8-bit integer tensors.
+T2 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16)
+Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.
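Note (illustrative, not part of the patch): the extended 16-bit saturation described above follows the same formula as the existing 8-bit path, y = saturate(round_half_to_even(x / y_scale) + y_zero_point), only with wider clamp bounds. A minimal C++ sketch, assuming a hypothetical scalar helper `quantize_one` rather than the actual MLAS-backed ORT kernels:

```cpp
// Scalar sketch of the QuantizeLinear formula documented above, extended to 16-bit types:
//   y = saturate(round_half_to_even(x / y_scale) + y_zero_point)
// quantize_one is a hypothetical helper for illustration only; it is not ONNX Runtime code.
#include <algorithm>
#include <cfenv>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>

template <typename QuantInt>
QuantInt quantize_one(float x, float y_scale, QuantInt y_zero_point) {
  std::fesetround(FE_TONEAREST);  // rounding to nearest, ties to even
  const float rounded = std::nearbyintf(x / y_scale) + static_cast<float>(y_zero_point);
  const float lo = static_cast<float>(std::numeric_limits<QuantInt>::min());  // -32768 for int16_t
  const float hi = static_cast<float>(std::numeric_limits<QuantInt>::max());  //  65535 for uint16_t
  return static_cast<QuantInt>(std::clamp(rounded, lo, hi));                  // saturate
}

int main() {
  std::cout << quantize_one<int16_t>(3.27f, 0.01f, int16_t{0}) << "\n";    // 327
  std::cout << quantize_one<uint16_t>(-1.0f, 0.01f, uint16_t{0}) << "\n";  // 0 (saturated)
  std::cout << quantize_one<int16_t>(1e6f, 0.01f, int16_t{0}) << "\n";     // 32767 (saturated)
}
```

DequantizeLinear is the inverse, (x - x_zero_point) * x_scale, which is why the T1/T2 roles are mirrored between the two operators documented above.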
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d46f3ed9bd262..33c187a28b62e 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -439,7 +439,7 @@ Do not modify directly.* |CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)|
-|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int32), tensor(int8), tensor(uint8)<br> **T2** = tensor(float)|
+|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)<br> **T2** = tensor(float)|
 |DynamicQuantizeLSTM|*in* X:**T**
*in* W:**T2**
*in* R:**T2**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* W_scale:**T**
*in* W_zero_point:**T2**
*in* R_scale:**T**
*in* R_zero_point:**T2**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(float)
**T1** = tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float)| @@ -472,7 +472,7 @@ Do not modify directly.* |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSoftmax|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearWhere|*in* condition:**B**
*in* X:**T**
*in* x_scale:**TF**
*in* x_zero_point:**T**
*in* Y:**T**
*in* y_scale:**TF**
*in* y_zero_point:**T**
*in* z_scale:**TF**
*in* z_zero_point:**T**
*out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
-|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(int8), tensor(uint8)|
+|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
 |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| |SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 19caa69d94ccf..f153e88909b8d 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1135,6 +1135,7 @@ class Graph { /** Directly insert the nodes in the function Node provided into this Graph. + The Graph needs to be Resolve()d after this call. @param node Node with Node::Type of Node::Type::Fused @returns Status indicating success or providing an error message. */ diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 71d98f5d73671..a87a894e3b3c5 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -46,6 +46,7 @@ Do not modify directly.* | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | | GreaterOrEqual | ai.onnx(12-15,16+) | | +| If | ai.onnx(1-10,11-12,13-18,19+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(17+) | | | LeakyRelu | ai.onnx(6-15,16+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 7e52954734216..f08d7a77d1099 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; type BuiltinFunctionName = string; type ElementwiseCustomExpression = (expression: string) => string; @@ -101,6 +101,9 @@ export const parseCastAttributes = (attributes: Record): CastAt export const cast = (context: ComputeContext, attributes: CastAttributes): void => { let func: ElementwiseFunctionCall; switch (attributes.to) { + case DataType.float16: + func = 'vec4'; + break; case DataType.float: func = 'vec4'; break; @@ -126,11 +129,12 @@ export interface ClipAttributes extends AttributeWithCacheKey { } export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => { + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); context.compute( createElementwiseProgramInfoLoader( context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, ` - const clip_min_: vec4 = vec4(f32(${attributes.min})); - const clip_max_: vec4 = vec4(f32(${attributes.max})); + const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min})); + const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max})); `, attributes.cacheKey), {inputs: [0]}); @@ -180,13 +184,13 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void attributes.cacheKey)); }; -export const erfImpl = (dataType: string) => ` -const r0: f32 = 0.3275911; -const r1: f32 = 0.254829592; -const r2: f32 = -0.284496736; -const r3: f32 = 1.421413741; -const r4: f32 = -1.453152027; -const r5: f32 = 1.061405429; +export const erfImpl = (dataType: string, varType = 'f32') => ` +const r0: ${varType} = 0.3275911; +const r1: ${varType} = 0.254829592; +const r2: ${varType} = -0.284496736; +const r3: ${varType} = 1.421413741; +const r4: ${varType} = -1.453152027; +const r5: ${varType} = 
1.061405429; fn erf_vf32(v: ${dataType}) -> ${dataType} { let absv = abs(v); @@ -195,8 +199,9 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} { }`; export const erf = (context: ComputeContext): void => { - context.compute( - createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4'))); + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfoLoader( + context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType))); }; export const exp = (context: ComputeContext): void => { @@ -208,9 +213,10 @@ export const floor = (context: ComputeContext): void => { }; export const gelu = (context: ComputeContext): void => { + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfoLoader( context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, - erfImpl('vec4'))); + erfImpl(`vec4<${dataType}>`, dataType))); }; export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => { diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 7b41850948149..f90f568879146 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -382,8 +382,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs const globalEnvFlags = parseGlobalEnvFlags(args); if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) { - // Backend webnn is restricted in the dedicated worker. - globalEnvFlags.wasm!.proxy = true; + throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.'); } // Options: diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index 520ef62b2c719..a75321d45f1ef 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -84,8 +84,10 @@ async function main() { .flat(); for (const backend of DEFAULT_BACKENDS) { - nodeTests.set(backend, loadNodeTests(backend, allNodeTestsFolders)); - opTests.set(backend, loadOpTests(backend)); + if (args.backends.indexOf(backend) !== -1) { + nodeTests.set(backend, loadNodeTests(backend, allNodeTestsFolders)); + opTests.set(backend, loadOpTests(backend)); + } } } diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 94592884ccad6..6e65645ef4756 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -602,6 +602,11 @@ // // "test_hardsigmoid", // // "test_hardswish_expanded", // // "test_hardswish", + "test_if", + // TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra + // supports Sequence and Optional types + // "test_if_seq", + // "test_if_opt", "test_instancenorm_epsilon", "test_instancenorm_example", // "test_isinf_negative", diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 660c8bd9e0624..0ec5088808656 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -56,9 +56,13 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLine class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, 
kMSDomain, 1, int8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid); @@ -191,9 +195,13 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc b/onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc deleted file mode 100644 index 28a304bfc7f0e..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/providers/cpu/quantization/quantize_linear.h" -#include "core/providers/common.h" - -namespace onnxruntime { -namespace contrib { - -ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( - DequantizeLinear, - 1, - uint8_t, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - DequantizeLinear); - -ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( - DequantizeLinear, - 1, - int8_t, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - DequantizeLinear); - -ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( - DequantizeLinear, - 1, - int32_t, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - DequantizeLinear); - -ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( - QuantizeLinear, - 1, - uint8_t, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - QuantizeLinear); - -ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( - QuantizeLinear, - 1, - int8_t, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - QuantizeLinear); - -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index a79ad96b94d91..f0385ea5abdfb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -249,30 +249,28 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); - data.bias = nullptr == bias ? nullptr : reinterpret_cast(bias->Data()); - data.query = nullptr; - data.key = nullptr; - data.value = nullptr; - data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); - data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); - data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data()); - data.past_key = nullptr; - data.past_value = nullptr; - data.relative_position_bias = (nullptr == relative_position_bias) - ? nullptr - : reinterpret_cast(relative_position_bias->Data()); + if (nullptr != bias) { + data.bias = reinterpret_cast(bias->Data()); + } + if (nullptr != mask_index) { + data.mask_index = mask_index->Data(); + data.mask_index_dims = mask_index->Shape().GetDims(); + } + if (nullptr != past) { + data.past = reinterpret_cast(past->Data()); + } + if (nullptr != relative_position_bias) { + data.relative_position_bias = reinterpret_cast(relative_position_bias->Data()); + } data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); - data.present = (nullptr == present) ? 
nullptr : reinterpret_cast(present->MutableData()); - data.present_key = nullptr; - data.present_value = nullptr; + if (nullptr != present) { + data.present = reinterpret_cast(present->MutableData()); + } data.fused_runner = reinterpret_cast(fused_runner); - data.fused_cross_attention_kernel = nullptr; data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; - data.cumulated_sequence_length_q_cache = nullptr; - data.cumulated_sequence_length_kv_cache = nullptr; return QkvToContext(device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu b/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu deleted file mode 100644 index 5d9cfcc69773a..0000000000000 --- a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/cuda/cuda_common.h" -#include "contrib_ops/cuda/bert/attention_impl.h" - -using namespace onnxruntime::cuda; - -namespace onnxruntime { -namespace contrib { -namespace cuda { - -template -__global__ void ConcatTensorToTensor(const int tensor_add_sequence_length, - const T* tensor_in, - const T* tensor_add, - T* tensor_out) { - const int h = threadIdx.x; - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; - const int chunk_id = blockIdx.z; - - const int all_sequence_length = gridDim.x; - const int batch_size = gridDim.y; - const int num_heads = blockDim.y; - const int H = blockDim.x; - - // K: number of identical tensors - // tensor_in: K x BxNxPxH - // tensor_add: K x BxNxLxH - // tensor_out: K x BxNxTxH, where T = P + L - const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length; - - const int present_SH = all_sequence_length * H; - const int present_NSH = num_heads * present_SH; - int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size); - if (s < tensor_in_sequence_length) { - const int past_SH = tensor_in_sequence_length * H; - const int past_NSH = num_heads * past_SH; - const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size); - tensor_out[out_offset] = tensor_in[in_offset]; - } else if (s < all_sequence_length) { - const int SH = tensor_add_sequence_length * H; - const int NSH = num_heads * SH; - const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size); - tensor_out[out_offset] = tensor_add[in_offset]; - } -} - -template -__global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length, - const int H, - const T* tensor_in, - const T* tensor_add, - T* tensor_out) { - // Use when (H*)*num_heads > 1024 - int h = threadIdx.x; - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; - const int chunk_id = blockIdx.z; - - const int all_sequence_length = gridDim.x; - const int batch_size = gridDim.y; - const int num_heads = blockDim.y; - const int stride = blockDim.x; - - // K: number of identical tensor - // tensor_in: K x BxNxPxH - // tensor_add: K x BxNxLxH - // tensor_out: K x BxNxTxH - const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length; - - const int present_SH = all_sequence_length * H; - const int present_NSH = num_heads * present_SH; - while (h < H) { - int out_offset = b * present_NSH + n * present_SH + s 
* H + h + chunk_id * (present_NSH * batch_size); - if (s < tensor_in_sequence_length) { - const int past_SH = tensor_in_sequence_length * H; - const int past_NSH = num_heads * past_SH; - const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size); - tensor_out[out_offset] = tensor_in[in_offset]; - } else if (s < all_sequence_length) { - const int SH = tensor_add_sequence_length * H; - const int NSH = num_heads * SH; - const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size); - tensor_out[out_offset] = tensor_add[in_offset]; - } - - h += stride; - } -} - -Status LaunchConcatTensorToTensor(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const float* tensor_in, - const float* tensor_add, - float* tensor_out) { - const dim3 grid(all_sequence_length, batch_size, matrix_num); - if (0 == (head_size & 1)) { - const int H = head_size / 2; - if (H * num_heads <= max_threads_per_block) { - const dim3 block(H, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, - reinterpret_cast(tensor_in), - reinterpret_cast(tensor_add), - reinterpret_cast(tensor_out)); - } else { - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, - H, - reinterpret_cast(tensor_in), - reinterpret_cast(tensor_add), - reinterpret_cast(tensor_out)); - } - } else { - if (head_size * num_heads <= max_threads_per_block) { - const dim3 block(head_size, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, tensor_in, tensor_add, tensor_out); - } else { - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, - head_size, - tensor_in, - tensor_add, - tensor_out); - } - } - return CUDA_CALL(cudaGetLastError()); -} - -Status LaunchConcatTensorToTensor(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const half* tensor_in, - const half* tensor_add, - half* tensor_out) { - const dim3 grid(all_sequence_length, batch_size, matrix_num); - if (0 == (head_size % 4)) { - const int H = head_size / 4; - if (H * num_heads <= max_threads_per_block) { - const dim3 block(H, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, - reinterpret_cast(tensor_in), - reinterpret_cast(tensor_add), - reinterpret_cast(tensor_out)); - } else { - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, - H, - reinterpret_cast(tensor_in), - reinterpret_cast(tensor_add), - reinterpret_cast(tensor_out)); - } - } else if (0 == (head_size & 1)) { - const int H = head_size / 2; - if (H * num_heads <= max_threads_per_block) { - const dim3 block(H, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, - reinterpret_cast(tensor_in), - reinterpret_cast(tensor_add), - reinterpret_cast(tensor_out)); - } else { - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, - H, - reinterpret_cast(tensor_in), - reinterpret_cast(tensor_add), - reinterpret_cast(tensor_out)); - } - } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel. 
- if (head_size * num_heads <= max_threads_per_block) { - const dim3 block(head_size, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, tensor_in, tensor_add, tensor_out); - } else { - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, - head_size, - tensor_in, - tensor_add, - tensor_out); - } - } - return CUDA_CALL(cudaGetLastError()); -} - -Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* past, - const float* k_v, - float* present) { - return LaunchConcatTensorToTensor( - stream, - all_sequence_length, - sequence_length, - batch_size, - head_size, - num_heads, - max_threads_per_block, - 2, - past, - k_v, - present); -} - -Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* past, - const half* k_v, - half* present) { - return LaunchConcatTensorToTensor( - stream, - all_sequence_length, - sequence_length, - batch_size, - head_size, - num_heads, - max_threads_per_block, - 2, - past, - k_v, - present); -} - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index ae7696eb9fe0f..b4a4ae208ceb1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -26,16 +26,11 @@ limitations under the License. // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/attention_softmax.h" #include "contrib_ops/cuda/bert/transformer_common.h" -#include "contrib_ops/cuda/bert/add_bias_transpose.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" #include "contrib_ops/cpu/bert/attention_base.h" @@ -43,6 +38,7 @@ limitations under the License. 
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/bert/attention_impl.h" using namespace onnxruntime::cuda; using namespace onnxruntime::contrib::attention_softmax_cuda; @@ -157,918 +153,285 @@ size_t GetAttentionWorkspaceSize( } template -__global__ void AddBiasTransAppendKvToPresentSmall( - const T* qkv, const T* biases, T* present, - const int head_size, const int past_sequence_length, const int max_sequence_length) { - // Input: BxSxMxNxH (Format 1) - // Output: (2, B, N, [P..P+S) of MaxS, H), - // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; - const int N = blockDim.y; - const int S = gridDim.x; - const int B = gridDim.y; - - constexpr int M = 3; // Matrix count in qkv - const int m = blockIdx.z + 1; // k = 1, v = 2 - - const int NH = N * head_size; - const int NHS = NH * S; - - qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); - if (biases) { - biases += (m * NH + n * head_size); - } +Status FusedTrtCrossAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + assert(data.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); - const int MsH = max_sequence_length * head_size; - const int NMsH = N * MsH; - const int BNMsH = B * NMsH; - present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + // We only enable fused cross attention when there is no key padding mask. + // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. + assert(data.mask_index == nullptr); - for (int h = threadIdx.x; h < head_size; h += blockDim.x) { - T bias = (biases ? 
biases[h] : (T)0.0f); - present[h] = qkv[h] + bias; - } -} + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, + sequence_length, stream, + data.scratch); -template -__global__ void AddBiasTransAppendKvToPresent( - const T* qkv, const T* biases, T* present, - const int head_size, const int past_sequence_length, const int max_sequence_length) { - // Input: BxSxMxNxH (Format 1) - // Output: (2, B, N, [P..P+S) of MaxS, H), - // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size - const int n = blockIdx.x; - const int s = blockIdx.y; - const int b = (blockIdx.z >> 1); - const int N = gridDim.x; - const int S = gridDim.y; - const int B = (gridDim.z >> 1); - - constexpr int M = 3; // Matrix count in qkv - const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2 - - const int NH = N * head_size; - const int NHS = NH * S; - - qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); - if (biases) { - biases += (m * NH + n * head_size); - } + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - const int MsH = max_sequence_length * head_size; - const int NMsH = N * MsH; - const int BNMsH = B * NMsH; - present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); + kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, + data.mask_index, batch_size, parameters.kv_sequence_length, stream, + kv_sequence_offset); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); - for (int h = threadIdx.x; h < head_size; h += blockDim.x) { - T bias = (biases ? biases[h] : (T)0.0f); - present[h] = qkv[h] + bias; - } + DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); + + FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = + reinterpret_cast(data.fused_cross_attention_kernel); + + // When there is no bias, we can directly use q and packed kv from inputs. + void const* query = data.q; + void const* packed_kv = data.k; + if (data.value == nullptr && data.bias == nullptr) { + query = data.query; + packed_kv = data.key; + } + + run_fused_cross_attention( + query, // Q + packed_kv, // packed KV + q_sequence_offset, // cumulated sequence length of Q + kv_sequence_offset, // cumulated sequence length of KV + data.output, // output + cross_attention_kernel, // kernels + batch_size, // batch size + parameters.num_heads, // number of heads + parameters.head_size, // head size of Q/K/V + sequence_length, // sequence length of Q + parameters.kv_sequence_length, // sequence length of KV + stream); + + DUMP_TENSOR("trt cross output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + return Status::OK(); } -// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3. 
-// bias is of shape (3, NxH) or nullptr -// append to present of (2, B, N, (P..T) of M, H), -template -Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int past_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const T* biases, - const T* qkv_buffer, - T* present) { - assert(head_size <= (1 << 30)); - - int64_t nh = (int64_t)head_size * num_heads; - if (nh <= max_threads_per_block) { - const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - - AddBiasTransAppendKvToPresentSmall<<>>( - qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); - } else { - const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v - const dim3 block(std::min(head_size, max_threads_per_block), 1, 1); - AddBiasTransAppendKvToPresent<<>>( - qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); - } - - return CUDA_CALL(cudaGetLastError()); +template <> +Status FusedTrtCrossAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "Trt fused cross attention does not support float tensor"); } -template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int total_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* bias, - const float* qkv_buffer, - float* present); - -template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int total_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* bias, - const half* qkv_buffer, - half* present); - template -Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { +Status FusedTrtSelfAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - const bool past_present_share_buffer = parameters.past_present_share_buffer; - void* fused_runner = data.fused_runner; - bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention; + const bool causal = parameters.is_unidirectional; - T* qkv = data.workspace; + int* sequence_offset = reinterpret_cast(data.scratch); - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - if (data.bias == nullptr) { - assert(nullptr == fused_runner); - // For quantized attention, bias has been added so only need transpose here. - // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH - assert(qk_head_size == v_head_size); - int matrix_to_trans = (past_present_share_buffer ? 
1 : 3); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.gemm_buffer, qkv, 3)); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + DUMP_TENSOR_INIT(); + if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); + LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); } else { - // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) - // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) - // For unfused kernel, transpose to 3xBxNxSxH (format 1) - // For fused causal kernel, use format 1 since we need have K and V to update present state, - // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. - const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); - qkv_format = use_fused_kernel - ? AttentionQkvFormat::QKV_BSN3H - : (use_flash_or_efficient_attention - ? AttentionQkvFormat::Q_K_V_BSNH - : (use_fused_causal - ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH - : AttentionQkvFormat::Q_K_V_BNSH)); - - // For fused causal, we will update gemm_buffer with bias directly. - T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; - - int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3); - // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v - // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) - LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, - 3, parameters.do_rotary, parameters.past_sequence_length); + sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, sequence_length, stream, + sequence_offset); } - return Status::OK(); -} - -// For MultiHeadAttention with past state -template -Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - - DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); - if (data.bias == nullptr) { - // Below logic does not support fused attention with past without bias - // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past. 
- - // cross attention with past state - if (data.past_key != nullptr && data.present_key == nullptr) { - assert(data.past_value != nullptr); - assert(data.query != nullptr); - assert(data.key == nullptr); - assert(data.value == nullptr); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - } - // cross attention with present state or self attention with present state - else if (data.past_key == nullptr && data.present_key != nullptr) { - assert(data.past_value == nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - - // TODO: supporting packed kv for cross attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, data.present_key)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, data.present_value)); - } - // self attention with past and present state - else { - assert(data.past_key != nullptr); - assert(data.past_value != nullptr); - assert(data.present_key != nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, k)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, v)); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value - else if ((data.use_memory_efficient_attention || data.use_flash_attention) && - data.past_key != nullptr && - data.past_value != nullptr && - parameters.pass_past_in_kv) { - // Transpose past_key and past_value to use memory efficient attention - - // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_key, data.temp_k_workspace)); - // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_value, data.temp_v_workspace)); - - // query => q, temp_k_workspace => k, temp_v_workspace => v - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, 
kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - - data.past_key = nullptr; - data.past_value = nullptr; - } - // When there is no past_key/past_value and there is present_key/present_value - // (e.g. get initial kv to use as past_kv in the next iteration) - else if ((data.use_memory_efficient_attention || data.use_flash_attention) && - data.present_key != nullptr && - data.present_value != nullptr) { - // Use memory efficient attention kernel - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); - - // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.temp_k_workspace, data.present_key)); - - // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.temp_v_workspace, data.present_value)); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } -#endif - else { - // Use unfused kernel for Q, use unfused kernel for K and V if needed - constexpr int format = 0; - // Query (BxSxNxH) => Q (BxNxSxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); - - if (!parameters.pass_past_in_kv) { - T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k; - T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? 
data.present_value : v; - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, data.bias + num_heads * qk_head_size, k_dest, - true, -1); - - // Value (BxLxNxH_v) => V (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, - true, -1); - - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } - return Status::OK(); -} + FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(data.fused_runner); -// For MultiHeadAttention without past state, with packed QKV inputs -template -Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; + const int S = causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length); - T* qkv = data.workspace; + // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed. + const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size); - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + fused_fp16_runner->setup(S, B); - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); + if (!causal) { + assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H); - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); - - if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. - constexpr int format = 4; - T* qkv_add_bias = nullptr; - LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, qkv, - true, v_head_size, qkv_add_bias, 3); - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (!use_fused_kernel) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); + // When there is no bias, we can directly use packed qkv from inputs. 
+ void const* packed_qkv = data.q; + if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { + packed_qkv = data.query; } - qkv_format = AttentionQkvFormat::QKV_BSN3H; + fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); + DUMP_TENSOR("fused output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + } else { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); + DUMP_TENSOR("fused causal output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); } return Status::OK(); } -// For MultiHeadAttention without past state, with packed KV inputs +// Template Specialization for float type +template <> +Status FusedTrtSelfAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "Trt fused attention does not support float tensor"); +} + +#if USE_FLASH_ATTENTION template -Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(nullptr == data.mask_index); + assert(nullptr == data.relative_position_bias); + assert(parameters.head_size == parameters.v_head_size); - // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. - // CheckInputs verified this constraint. - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); + void* query = reinterpret_cast(data.q); + void* key = reinterpret_cast(data.k); + void* value = reinterpret_cast(data.v); + // For packed KV, we can use query input directly. + if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { + query = reinterpret_cast(const_cast(data.query)); + } DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); - - if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack kv to BSNH. Note that there is no bias so we need not output query to q. - constexpr int format = 4; - T* qkv_add_bias = nullptr; - const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); - LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, kv_bias, k, - true, v_head_size, qkv_add_bias, 2); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (data.fused_cross_attention_kernel == nullptr) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed KV format is not implemented for current GPU. 
Please disable packed kv in fusion options."); - } + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, + parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, + parameters.batch_size, parameters.total_sequence_length, + parameters.num_heads, parameters.v_head_size); + + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), + parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional)); + + DUMP_TENSOR("flash attention output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } return Status::OK(); } -// For MultiHeadAttention without past state, with Q, K and V inputs -template -Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - // gemm_buffer == nullptr and not packed - assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); - -#if DUMP_TENSOR_LEVEL > 1 - if (data.bias != nullptr) { - DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); - DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); - DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); - } +template <> +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "flash attention does not support float tensor"); +} #endif - if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { - DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, - num_heads, sequence_length, kv_sequence_length); - } - - if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { - DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); - } - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qk_head_size == v_head_size); - - // For fused cross attention, besides adding bias, K and V needed to be packed: - // K (BxSxNxH), V (BxSxNxH) 
=> BxSxNx2xH - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length); - - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - else if (data.use_memory_efficient_attention || data.use_flash_attention) { - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } -#endif - else if (use_fused_kernel) { - assert(qk_head_size == v_head_size); +#if USE_MEMORY_EFFICIENT_ATTENTION +template +Status EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + // We only enable fused cross attention when there is no key padding mask. + // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H) - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); - DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); - - qkv_format = AttentionQkvFormat::QKV_BSN3H; - } else { // unfused kernel - ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal"); - - // Query (BxSxNxH) => Q (BxNxSxH) - constexpr int format = 0; - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, - true, -1); - - // Value (BxLxNxH_v) => K (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, - true, -1); - - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + const void* query = data.q; + const void* key = data.k; + const void* value = data.v; + // For packed KV, we can use query input directly. 
+ if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { + assert(data.bias == nullptr); + query = data.query; } - return Status::OK(); -} -template -Status PrepareQkv(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - if (nullptr != data.gemm_buffer) { // Attention operator - ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, qkv_format)); - } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - } else { // multihead attention operator, no past, separated Q/K/V inputs - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - } + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, + parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, + parameters.batch_size, parameters.total_sequence_length, + parameters.num_heads, parameters.v_head_size); + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = parameters.batch_size; + p.num_heads = parameters.num_heads; + p.sequence_length = parameters.sequence_length; + p.kv_sequence_length = parameters.total_sequence_length; + p.qk_head_size = parameters.head_size; + p.v_head_size = parameters.v_head_size; + p.causal = parameters.is_unidirectional; + p.scale = scale; + p.seqlen_k_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast(data.mask_index)); + p.seqstart_q_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast( + data.mask_index + parameters.batch_size)); + p.seqstart_k_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast( + data.mask_index + 2 * parameters.batch_size + 1)); + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; + p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) + ? 
data.scratch + : nullptr; + p.stream = stream; + run_memory_efficient_attention(p); + DUMP_TENSOR("efficient attention output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); } +#endif template -Status QkvToContext( +Status UnfusedAttention( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, Stream* ort_stream, contrib::AttentionParameters& parameters, - AttentionData& data) { + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH); + auto stream = static_cast(ort_stream->GetHandle()); - constexpr size_t element_size = sizeof(T); - const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; const int total_sequence_length = parameters.total_sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - const bool past_present_share_buffer = parameters.past_present_share_buffer; - const float mask_filter_value = parameters.mask_filter_value; - void* fused_runner = data.fused_runner; - - // At most one fused kernel is enabled. - assert((int(data.use_flash_attention) + - int(data.use_memory_efficient_attention) + - int(fused_runner != nullptr) + - int(data.fused_cross_attention_kernel != nullptr)) <= 1); - const int batches = batch_size * num_heads; - T* qkv = nullptr; - T* q = nullptr; - T* k = nullptr; - T* v = nullptr; - T* scratch1 = data.workspace; - if (data.has_qkv_workspace) { - const int size_per_batch_q = sequence_length * qk_head_size; - const int size_per_batch_k = kv_sequence_length * qk_head_size; - const int size_per_batch_v = kv_sequence_length * v_head_size; - const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); - const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); - const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); - qkv = data.workspace; - q = qkv; - k = q + elements_q; - v = k + elements_k; - scratch1 = v + elements_v; - } - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - - int present_size_per_batch_k = 0; - int present_size_per_batch_v = 0; - if (!past_present_share_buffer) { - // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. - // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) - // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) - // When there is past state, the head size for Q/K/V shall be same: H == H_v. 
- present_size_per_batch_k = total_sequence_length * qk_head_size; - present_size_per_batch_v = total_sequence_length * v_head_size; - - if (nullptr != data.present) { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH || qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); - ORT_RETURN_IF_ERROR( - LaunchConcatPastToPresent( - stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, data.past, k, data.present)); - - // Update pointers to present_k and present_v. - k = data.present; - v = data.present + batches * present_size_per_batch_k; - } - - if (nullptr != data.past_key || nullptr != data.present_key) { - if (nullptr != data.past_key && nullptr == data.present_key) { - k = const_cast(data.past_key); - v = const_cast(data.past_value); - } else if (nullptr == data.past_key && nullptr != data.present_key) { - if (qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { - k = data.present_key; - v = data.present_value; - } else { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - k = data.temp_k_workspace; - v = data.temp_v_workspace; - } - } else if (parameters.pass_past_in_kv) { - // past_key and past_value are used directly as key and value in attention computations - k = const_cast(data.past_key); - v = const_cast(data.past_value); - - // This path has a memory copy from past_key and past_value to present_key and present_value - // Avoid this path since the memory copy is unnecessary because past_key == present_key and - // past_value == present_value - int64_t k_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * qk_head_size; - int64_t v_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * v_head_size; - cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } else { - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, - max_threads_per_block, 1, data.past_key, k, data.present_key)); - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, - batch_size, v_head_size, num_heads, - max_threads_per_block, 1, data.past_value, v, data.present_value)); - // Update pointers to present_k and present_v. - k = data.present_key; - v = data.present_value; - } - } - } else { // past_present_share_buffer - assert(qk_head_size == v_head_size); - assert(data.fused_cross_attention_kernel == nullptr); - assert(!use_fused_kernel); - assert(data.gemm_buffer != nullptr); - assert(!data.use_memory_efficient_attention); - assert(!data.use_flash_attention); - assert(data.has_qkv_workspace); - - if (nullptr != data.past_key || nullptr != data.present_key) { - // TODO: support this case. - ORT_THROW("buffer sharing for no bias case between past and present is not supported yet."); - } - - if (data.present != data.past) { - // For easy testing. Production should better avoid this path. 
- int64_t kv_size = 2LL * (int64_t)batch_size * num_heads * parameters.max_sequence_length * qk_head_size; - cudaMemcpyAsync(data.present, data.past, kv_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } - - // append last k v to present - ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( - stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, max_threads_per_block, - use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer - data.gemm_buffer, data.present)); - - present_size_per_batch_k = parameters.max_sequence_length * qk_head_size; - present_size_per_batch_v = present_size_per_batch_k; - k = data.present; - v = data.present + batches * present_size_per_batch_k; - } - - // Q, K and V are ready now - DUMP_TENSOR_INIT(); - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); - - // We only enable fused cross attention when there is no key padding mask. - // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(data.mask_index == nullptr); - - int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - scratch1); - - DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - - int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); - kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, - data.mask_index, batch_size, kv_sequence_length, stream, - kv_sequence_offset); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); - - DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); - - FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = - reinterpret_cast(data.fused_cross_attention_kernel); - - // When there is no bias, we can directly use q and packed kv from inputs. - void const* query = q; - void const* packed_kv = k; - if (data.value == nullptr && data.bias == nullptr) { - query = data.query; - packed_kv = data.key; - } - - run_fused_cross_attention( - query, // Q - packed_kv, // packed KV - q_sequence_offset, // cumulated sequence length of Q - kv_sequence_offset, // cumulated sequence length of KV - data.output, // output - cross_attention_kernel, // kernels - batch_size, // batch size - num_heads, // number of heads - qk_head_size, // head size of Q/K/V - sequence_length, // sequence length of Q - kv_sequence_length, // sequence length of KV - stream); - - DUMP_TENSOR("trt cross output", data.output, batch_size, sequence_length, num_heads, v_head_size); - return Status::OK(); - } - - // Run TRT fused attention. - if (use_fused_kernel || use_fused_causal) { - int* sequence_offset = reinterpret_cast(scratch1); - if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { - DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); - LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); - } else { - sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - sequence_offset); - } - DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 
2 : 1) * batch_size + 1); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); - - FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(fused_runner); - - const int S = use_fused_causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length); - - // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed. - const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size); - - fused_fp16_runner->setup(S, B); - - if (use_fused_kernel) { - assert(qkv_format == AttentionQkvFormat::QKV_BSN3H); - - // When there is no bias, we can directly use packed qkv from inputs. - void const* packed_qkv = qkv; - if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { - packed_qkv = data.query; - } - - fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); - DUMP_TENSOR("fused output", data.output, batch_size, sequence_length, num_heads, v_head_size); - } else { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); - fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_TENSOR("fused causal output", data.output, batch_size, sequence_length, num_heads, v_head_size); - } - return Status::OK(); - } - - // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. - const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) - : parameters.scale; - -#if USE_FLASH_ATTENTION - if (data.use_flash_attention) { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - assert(nullptr == data.mask_index); - assert(nullptr == data.relative_position_bias); - assert(parameters.head_size == parameters.v_head_size); - - void* query = reinterpret_cast(q); - void* key = reinterpret_cast(k); - void* value = reinterpret_cast(v); - // For packed KV, we can use query input directly. - if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { - query = reinterpret_cast(const_cast(data.query)); - } - - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); - - constexpr bool is_causal = false; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(scratch1), - parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, is_causal)); - - DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, v_head_size); - - return Status::OK(); - } -#endif - -#if USE_MEMORY_EFFICIENT_ATTENTION - if (data.use_memory_efficient_attention) { - // We only enable fused cross attention when there is no key padding mask. - // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - - const void* query = q; - const void* key = k; - const void* value = v; - // For packed KV, we can use query input directly. 
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { - assert(data.bias == nullptr); - query = data.query; - } - - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); - - MemoryEfficientAttentionParams p; - p.sm = device_prop.major * 10 + device_prop.minor; - p.is_half = sizeof(T) == 2; - p.batch_size = parameters.batch_size; - p.num_heads = parameters.num_heads; - p.sequence_length = parameters.sequence_length; - p.kv_sequence_length = parameters.total_sequence_length; - p.qk_head_size = parameters.head_size; - p.v_head_size = parameters.v_head_size; - p.causal = parameters.is_unidirectional; - p.scale = scale; - p.seqlen_k_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast(data.mask_index)); - p.seqstart_q_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast(data.mask_index + batch_size)); - p.seqstart_k_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast(data.mask_index + 2 * batch_size + 1)); - p.query = query; - p.key = key; - p.value = value; - p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; - p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; - p.output = data.output; - p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) - ? scratch1 - : nullptr; - p.stream = stream; - run_memory_efficient_attention(p); - DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, v_head_size); - - return Status::OK(); - } -#endif - - // The following are unfused attention. - assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH); const int* mask_index = data.mask_index; gsl::span& mask_index_dims = data.mask_index_dims; // Raw attention mask could be 2D (BxT) or 3D (BxSxT) or 4D(Bx1xMxM), where M is the max sequence length. bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2); - // Compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxT + // Compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch: BxNxSxT // Q: BxNxSxH, K (present_k): BxNxTxH, Q*K': BxNxSxT float one = 1.0f; float zero = 0.f; @@ -1077,22 +440,31 @@ Status QkvToContext( cublasSetStream(cublas, stream); - DUMP_TENSOR_D("q[BNSH]", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k[BNSH]", k, batch_size, num_heads, total_sequence_length, qk_head_size); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q[BNSH]", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k[BNSH]", data.k, batch_size, num_heads, total_sequence_length, qk_head_size); + + const int present_sequence_length = parameters.past_present_share_buffer + ? 
parameters.max_sequence_length + : total_sequence_length; + const int present_size_per_batch_k = present_sequence_length * qk_head_size; + const int present_size_per_batch_v = present_sequence_length * v_head_size; + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, total_sequence_length, sequence_length, qk_head_size, - &alpha, k, qk_head_size, present_size_per_batch_k, - q, qk_head_size, sequence_length * qk_head_size, - &zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); + &alpha, data.k, qk_head_size, present_size_per_batch_k, + data.q, qk_head_size, sequence_length * qk_head_size, + &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); - DUMP_TENSOR_D("Q", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("K", k, batch_size, num_heads, qk_head_size, sequence_length); - DUMP_TENSOR_D("QK", scratch1, batch_size, num_heads, sequence_length, total_sequence_length); + DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); + DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length); + constexpr size_t element_size = sizeof(T); const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, total_sequence_length); - T* scratch2 = scratch1 + (bytes / element_size); + T* scratch2 = data.scratch + (bytes / element_size); // Apply softmax and store result R to scratch2: BxNxSxT if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask @@ -1102,14 +474,15 @@ Status QkvToContext( const TransformerOptions* options = TransformerOptions::GetInstance(); bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); - T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. + // replace Q*K' in place with masked score for persistent softmax. + T* persistent_softmax_workspace = data.scratch; ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask( ort_stream, total_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias, - scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension, + data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension, parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, - mask_filter_value)); + parameters.mask_filter_value)); } else if (nullptr != mask_index) { // 1d mask index assert(mask_index_dims.size() == 1); // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the latter one has start positions.
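For context, the hunks around this point cover the unfused path: a strided-batched GEMM for Q*K' scaled by 1/sqrt(H), a softmax over the total sequence length (with an optional raw or 1-D mask), and then a second GEMM against V in the next hunk. The following is a minimal single-batch, single-head CPU sketch of the same math, for reference only; the function name and the plain float/std::vector types are assumptions and not part of this change.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Single-head reference: out = softmax(Q * K^T / sqrt(H)) * V
// Q: S x H, K: T x H, V: T x H (H_v == H here), out: S x H, all row-major.
// A key-padding mask would be applied by adding a large negative value
// (mask_filter_value) to masked scores before the softmax.
std::vector<float> ScaledDotProductAttentionRef(const std::vector<float>& Q,
                                                const std::vector<float>& K,
                                                const std::vector<float>& V,
                                                int S, int T, int H) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(H));
  std::vector<float> out(static_cast<size_t>(S) * H, 0.0f);
  std::vector<float> probs(T);
  for (int s = 0; s < S; ++s) {
    // Scores for query row s: scaled dot products against every key position.
    float max_score = -std::numeric_limits<float>::infinity();
    for (int t = 0; t < T; ++t) {
      float dot = 0.0f;
      for (int h = 0; h < H; ++h) dot += Q[s * H + h] * K[t * H + h];
      probs[t] = dot * scale;
      max_score = std::max(max_score, probs[t]);
    }
    // Numerically stable softmax over the key dimension.
    float sum = 0.0f;
    for (int t = 0; t < T; ++t) {
      probs[t] = std::exp(probs[t] - max_score);
      sum += probs[t];
    }
    for (int t = 0; t < T; ++t) probs[t] /= sum;
    // Weighted sum of value rows.
    for (int t = 0; t < T; ++t)
      for (int h = 0; h < H; ++h) out[s * H + h] += probs[t] * V[t * H + h];
  }
  return out;
}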
@@ -1117,277 +490,123 @@ Status QkvToContext( ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D( stream, total_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, data.relative_position_bias, parameters.broadcast_res_pos_bias, - scratch1, scratch2, parameters.is_unidirectional)); + data.scratch, scratch2, parameters.is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR( ComputeSoftmax( stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, - parameters.broadcast_res_pos_bias, scratch1, scratch2, parameters.is_unidirectional)); + parameters.broadcast_res_pos_bias, data.scratch, scratch2, parameters.is_unidirectional)); } DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length); - DUMP_TENSOR_D("V", v, batch_size, num_heads, sequence_length, v_head_size); + DUMP_TENSOR_D("V", data.v, batch_size, num_heads, sequence_length, v_head_size); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v - T* temp_output = qkv; + T* temp_output = data.q; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, v_head_size, sequence_length, total_sequence_length, - &one, v, v_head_size, present_size_per_batch_v, + &one, data.v, v_head_size, present_size_per_batch_v, scratch2, total_sequence_length, sequence_length * total_sequence_length, &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, temp_output, data.output); + device_prop.maxThreadsPerBlock, false, temp_output, data.output); DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size); return result; } template -Status DecoderQkvToContext( +Status QkvToContext( const cudaDeviceProp& device_prop, - Stream* ort_stream, cublasHandle_t& cublas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const T* gemm_query_buffer, - const T* gemm_kv_buffer, - const bool* key_padding_mask, - const T* key_cache, - const T* value_cache, - T* qkv_buffer, - T* workspace_buffer, - T* output, - T* new_key_cache, - T* new_value_cache) { + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); const int max_threads_per_block = device_prop.maxThreadsPerBlock; - const int BN = batch_size * num_heads; - const int BHN = BN * head_size; - const int BNS = BN * sequence_length; - const int k_buffer_offset = sequence_length * BHN; - const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; - T* temp_qkv_buffer = workspace_buffer; - auto stream = static_cast(ort_stream->GetHandle()); + // At most one fused kernel is enabled. 
+ assert((int(data.use_flash_attention) + + int(data.use_memory_efficient_attention) + + int(fused_runner != nullptr) + + int(data.fused_cross_attention_kernel != nullptr)) <= 1); - const T* q = qkv_buffer; - // transpose q and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); - - const T* k = qkv_buffer + k_buffer_offset; - const T* v = qkv_buffer + v_buffer_offset; - if (!has_layer_state || !use_past) { - if (!static_kv) { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } else { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } - } else { - if (!static_kv) { - // transpose kv and copy them to temp_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); - // concat cache-k with k and copy to qkv_buffer - if (nullptr != key_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, - sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, 1, - key_cache, - temp_qkv_buffer, - qkv_buffer + k_buffer_offset)); - } - // concat cache-v with v and copy to qkv_buffer - if (nullptr != value_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, - sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, 1, - value_cache, - temp_qkv_buffer + k_buffer_offset, - qkv_buffer + v_buffer_offset)); - } + ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block)); + + if (!parameters.past_present_share_buffer) { + ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, + sequence_length, total_sequence_length, parameters.pass_past_in_kv, + stream, max_threads_per_block, data)); + + } else { // past_present_share_buffer + assert(qk_head_size == v_head_size); + assert(data.fused_cross_attention_kernel == nullptr); + assert(nullptr == fused_runner || parameters.is_unidirectional); + assert(data.gemm_buffer != nullptr); + assert(!data.use_memory_efficient_attention); + assert(!data.use_flash_attention); + assert(data.has_qkv_workspace); + + if (nullptr != data.past_key || nullptr != data.present_key) { + // TODO: support this case. + ORT_THROW("buffer sharing for no bias case between past and present is not supported yet."); } - } - if (has_layer_state) { - if (use_past && static_kv) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - } else { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); + if (data.present != data.past) { + // For easy testing. Production should better avoid this path. 
+ int64_t kv_size = 2LL * (int64_t)batch_size * num_heads * parameters.max_sequence_length * qk_head_size; + cudaMemcpyAsync(data.present, data.past, kv_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); } - } - // scratch1: BxNxSxL buffer - // scratch2: BxNxSxL buffer - // scratch3: BxNxSxH buffer - T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; - T* scratch2 = scratch1 + BNS * kv_sequence_length; - T* scratch3 = scratch2 + BNS * kv_sequence_length; - - // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL - // Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL - const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); - const int temp_matrix_size = sequence_length * kv_sequence_length; - float one = 1.0f; - float zero = 0.f; + // For fused causal, bias has been added to gemm_buffer. + const T* bias = (nullptr != fused_runner && parameters.is_unidirectional) ? nullptr : data.bias; - float alpha = rsqrt_head_size; - const int strideA = kv_sequence_length * head_size; - const int strideB = sequence_length * head_size; - if (use_past && static_kv) { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, - kv_sequence_length, sequence_length, head_size, - &alpha, key_cache, head_size, strideA, - q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); - } else { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, - kv_sequence_length, sequence_length, head_size, - &alpha, k, head_size, strideA, - q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + // append last k v to present + ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( + stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, max_threads_per_block, + bias, data.gemm_buffer, data.present)); + + data.k = data.present; + data.v = data.present + batch_size * num_heads * parameters.max_sequence_length * qk_head_size; } - constexpr bool is_unidirectional = false; - const T* add_before_softmax = nullptr; - if (has_key_padding_mask) { - constexpr int mask_dimension = 2; - constexpr int max_sequence_length = 0; - ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( - ort_stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, key_padding_mask, add_before_softmax, - false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, - 1.0f, mask_dimension, max_sequence_length, false, nullptr, - mask_filter_value)); - } else { - ORT_RETURN_IF_ERROR(ComputeSoftmax( - stream, kv_sequence_length, sequence_length, batch_size, num_heads, - add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, - is_unidirectional)); + // Q, K and V are ready now + if (data.fused_cross_attention_kernel != nullptr) { + return FusedTrtCrossAttention(stream, parameters, data); } - // compute P*V (as V*P), and store in scratch3: BxNxSxH - if (use_past && static_kv) { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, - head_size, sequence_length, kv_sequence_length, - &one, value_cache, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); - } else { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, - head_size, sequence_length, kv_sequence_length, - &one, v, head_size, strideA, - scratch2, 
kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + // Run TRT fused attention. + if (nullptr != fused_runner) { + return FusedTrtSelfAttention(stream, parameters, data); } - // scratch3 is BxNxSxH, transpose to output SxBxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, scratch3, output); -} + // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) + : parameters.scale; -Status LaunchDecoderAttentionKernel( - const cudaDeviceProp& device_prop, - Stream* stream, - cublasHandle_t& cublas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const void* gemm_query_buffer, - const void* gemm_kv_buffer, - const bool* key_padding_mask, - const void* key_cache, - const void* value_cache, - void* qkv_buffer, - void* workspace_buffer, - void* output, - void* new_key_cache, - void* new_value_cache) { - if (element_size == 2) { - return DecoderQkvToContext( - device_prop, - stream, - cublas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } else { - return DecoderQkvToContext( - device_prop, - stream, - cublas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, stream, parameters, data, scale); + } +#endif + + return UnfusedAttention(device_prop, cublas, ort_stream, parameters, data, scale); } // Template Instantiation diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index af7373dd9fa1b..d0a5fb51a25d6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -2,11 +2,12 @@ // Licensed under the MIT License. 
#pragma once -#include "core/providers/cuda/shared_inc/cuda_utils.h" + #include #include -#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/common/gsl.h" #include "core/framework/allocator.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace contrib { @@ -49,39 +50,52 @@ size_t GetAttentionWorkspaceSize( template struct AttentionData { - T* gemm_buffer; - const T* bias; + T* gemm_buffer = nullptr; + const T* bias = nullptr; - const T* query; - const T* key; - const T* value; - const int* mask_index; + const T* query = nullptr; + const T* key = nullptr; + const T* value = nullptr; + const int* mask_index = nullptr; gsl::span mask_index_dims; - const T* past; - const T* past_key; - const T* past_value; - const T* relative_position_bias; - - bool has_qkv_workspace; - T* workspace; - T* temp_k_workspace; - T* temp_v_workspace; - - T* output; - T* present; - T* present_key; - T* present_value; - - void* fused_runner; - const void* fused_cross_attention_kernel; - - bool use_flash_attention; - bool use_memory_efficient_attention; - - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache; + const T* past = nullptr; + const T* past_key = nullptr; + const T* past_value = nullptr; + const T* relative_position_bias = nullptr; + + bool has_qkv_workspace = false; + T* workspace = nullptr; + T* temp_k_workspace = nullptr; + T* temp_v_workspace = nullptr; + + T* output = nullptr; + T* present = nullptr; + T* present_key = nullptr; + T* present_value = nullptr; + + void* fused_runner = nullptr; + const void* fused_cross_attention_kernel = nullptr; + + bool use_flash_attention = false; + bool use_memory_efficient_attention = false; + + mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr; + mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr; + + // Intermediate data + T* q = nullptr; + T* k = nullptr; + T* v = nullptr; + T* scratch = nullptr; + AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; }; +template +Status PrepareQkv(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block); + template Status QkvToContext( const cudaDeviceProp& device_prop, @@ -90,33 +104,6 @@ Status QkvToContext( contrib::AttentionParameters& parameters, AttentionData& data); -Status LaunchDecoderAttentionKernel( - const cudaDeviceProp& prop, // Device Properties - Stream* stream, // ORT Stream - cublasHandle_t& cublas, // Cublas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* 
workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - // BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true) Status LaunchTransCtx(cudaStream_t stream, const int sequence_length, const int batch_size, const int head_size, const int num_heads, @@ -161,33 +148,32 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, const half* tensor_add, half* tensor_out); -Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* past, - const float* k_v, - float* present); - -Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* past, - const half* k_v, - half* present); +template +Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data); + +template +Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int past_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const T* biases, + const T* qkv_buffer, + T* present); template Status LaunchStridedCopy(cudaStream_t stream, const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) T* out, longlong4 out_strides, // coord (b,n,s,h) int max_threads_per_block); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu new file mode 100644 index 0000000000000..89be0f1115f41 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -0,0 +1,466 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
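The new attention_kv_cache.cu below moves the KV-cache kernels out of attention_impl.cu. Its ConcatTensorToTensor kernels append a freshly computed K or V chunk of length L to a past cache of length P, writing a present buffer of length T = P + L. The host-side sketch below paraphrases that index math for a single chunk; the helper name and flat std::vector buffers are assumptions for illustration, not code from this patch.

#include <vector>

// Reference for the concat index math (one chunk, chunk_id = 0):
//   past:    B x N x P x H
//   added:   B x N x L x H
//   present: B x N x T x H, where T = P + L (present must be sized B*N*T*H)
void ConcatPastAndNewRef(const std::vector<float>& past,
                         const std::vector<float>& added,
                         std::vector<float>& present,
                         int B, int N, int P, int L, int H) {
  const int T = P + L;
  for (int b = 0; b < B; ++b) {
    for (int n = 0; n < N; ++n) {
      for (int s = 0; s < T; ++s) {
        for (int h = 0; h < H; ++h) {
          // Destination offset in the BxNxTxH present buffer.
          const size_t out = (((size_t)b * N + n) * T + s) * H + h;
          if (s < P) {
            // Positions before P copy from the past cache (BxNxPxH).
            present[out] = past[(((size_t)b * N + n) * P + s) * H + h];
          } else {
            // Remaining positions copy from the newly appended chunk (BxNxLxH).
            present[out] = added[(((size_t)b * N + n) * L + (s - P)) * H + h];
          }
        }
      }
    }
  }
}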
+ +#include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void ConcatTensorToTensor(const int tensor_add_sequence_length, + const T* tensor_in, + const T* tensor_add, + T* tensor_out) { + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + const int chunk_id = blockIdx.z; + + const int all_sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + // K: number of identical tensors + // tensor_in: K x BxNxPxH + // tensor_add: K x BxNxLxH + // tensor_out: K x BxNxTxH, where T = P + L + const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length; + + const int present_SH = all_sequence_length * H; + const int present_NSH = num_heads * present_SH; + int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size); + if (s < tensor_in_sequence_length) { + const int past_SH = tensor_in_sequence_length * H; + const int past_NSH = num_heads * past_SH; + const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size); + tensor_out[out_offset] = tensor_in[in_offset]; + } else if (s < all_sequence_length) { + const int SH = tensor_add_sequence_length * H; + const int NSH = num_heads * SH; + const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size); + tensor_out[out_offset] = tensor_add[in_offset]; + } +} + +template +__global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length, + const int H, + const T* tensor_in, + const T* tensor_add, + T* tensor_out) { + // Use when (H*)*num_heads > 1024 + int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + const int chunk_id = blockIdx.z; + + const int all_sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int num_heads = blockDim.y; + const int stride = blockDim.x; + + // K: number of identical tensor + // tensor_in: K x BxNxPxH + // tensor_add: K x BxNxLxH + // tensor_out: K x BxNxTxH + const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length; + + const int present_SH = all_sequence_length * H; + const int present_NSH = num_heads * present_SH; + while (h < H) { + int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size); + if (s < tensor_in_sequence_length) { + const int past_SH = tensor_in_sequence_length * H; + const int past_NSH = num_heads * past_SH; + const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size); + tensor_out[out_offset] = tensor_in[in_offset]; + } else if (s < all_sequence_length) { + const int SH = tensor_add_sequence_length * H; + const int NSH = num_heads * SH; + const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size); + tensor_out[out_offset] = tensor_add[in_offset]; + } + + h += stride; + } +} + +Status LaunchConcatTensorToTensor(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const float* tensor_in, + const float* 
tensor_add, + float* tensor_out) { + const dim3 grid(all_sequence_length, batch_size, matrix_num); + if (0 == (head_size & 1)) { + const int H = head_size / 2; + if (H * num_heads <= max_threads_per_block) { + const dim3 block(H, num_heads, 1); + ConcatTensorToTensor<<>>(sequence_length, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<<>>(sequence_length, + H, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } + } else { + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + ConcatTensorToTensor<<>>(sequence_length, tensor_in, tensor_add, tensor_out); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<<>>(sequence_length, + head_size, + tensor_in, + tensor_add, + tensor_out); + } + } + return CUDA_CALL(cudaGetLastError()); +} + +Status LaunchConcatTensorToTensor(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const half* tensor_in, + const half* tensor_add, + half* tensor_out) { + const dim3 grid(all_sequence_length, batch_size, matrix_num); + if (0 == (head_size % 4)) { + const int H = head_size / 4; + if (H * num_heads <= max_threads_per_block) { + const dim3 block(H, num_heads, 1); + ConcatTensorToTensor<<>>(sequence_length, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<<>>(sequence_length, + H, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } + } else if (0 == (head_size & 1)) { + const int H = head_size / 2; + if (H * num_heads <= max_threads_per_block) { + const dim3 block(H, num_heads, 1); + ConcatTensorToTensor<<>>(sequence_length, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<<>>(sequence_length, + H, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } + } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel. 
+ if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + ConcatTensorToTensor<<>>(sequence_length, tensor_in, tensor_add, tensor_out); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<<>>(sequence_length, + head_size, + tensor_in, + tensor_add, + tensor_out); + } + } + return CUDA_CALL(cudaGetLastError()); +} + +Status LaunchConcatPastToPresent(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const float* past, + const float* k_v, + float* present) { + return LaunchConcatTensorToTensor( + stream, + all_sequence_length, + sequence_length, + batch_size, + head_size, + num_heads, + max_threads_per_block, + 2, + past, + k_v, + present); +} + +Status LaunchConcatPastToPresent(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const half* past, + const half* k_v, + half* present) { + return LaunchConcatTensorToTensor( + stream, + all_sequence_length, + sequence_length, + batch_size, + head_size, + num_heads, + max_threads_per_block, + 2, + past, + k_v, + present); +} + +#ifndef USE_ROCM // exclude the following from hipify since they are not used in ROCM EP + +template +Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data) { + // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. + // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) + // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) + // When there is past state, the head size for Q/K/V shall be same: H == H_v. + + if (nullptr != data.present) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + + ORT_RETURN_IF_ERROR( + LaunchConcatPastToPresent( + stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, data.past, data.k, data.present)); + + // Update pointers to present_k and present_v. 
+ data.k = data.present; + data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; + } else if (nullptr != data.past_key || nullptr != data.present_key) { + if (nullptr != data.past_key && nullptr == data.present_key) { + data.k = const_cast(data.past_key); + data.v = const_cast(data.past_value); + } else if (nullptr == data.past_key && nullptr != data.present_key) { + if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + data.k = data.present_key; + data.v = data.present_value; + } else { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + data.k = data.temp_k_workspace; + data.v = data.temp_v_workspace; + } + } else if (pass_past_in_kv) { + // past_key and past_value are used directly as key and value in attention computations + data.k = const_cast(data.past_key); + data.v = const_cast(data.past_value); + + // This path has a memory copy from past_key and past_value to present_key and present_value + // Avoid this path since the memory copy is unnecessary because past_key == present_key and + // past_value == present_value + int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; + int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size; + cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } else { + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, + max_threads_per_block, 1, data.past_key, data.k, data.present_key)); + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, v_head_size, num_heads, + max_threads_per_block, 1, data.past_value, data.v, data.present_value)); + // Update pointers to present_k and present_v. 
+ data.k = data.present_key; + data.v = data.present_value; + } + } + + return CUDA_CALL(cudaGetLastError()); +} + +// Template Instantiation +template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data); + +template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data); + +// ---------------------------------------------------------------------------------- +// Below kernels are for past and present sharing buffer +// ---------------------------------------------------------------------------------- + +template +__global__ void AddBiasTransAppendKvToPresentSmall( + const T* qkv, const T* biases, T* present, + const int head_size, const int past_sequence_length, const int max_sequence_length) { + // Input: BxSxMxNxH (Format 1) + // Output: (2, B, N, [P..P+S) of MaxS, H), + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + const int N = blockDim.y; + const int S = gridDim.x; + const int B = gridDim.y; + + constexpr int M = 3; // Matrix count in qkv + const int m = blockIdx.z + 1; // k = 1, v = 2 + + const int NH = N * head_size; + const int NHS = NH * S; + + qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); + if (biases) { + biases += (m * NH + n * head_size); + } + + const int MsH = max_sequence_length * head_size; + const int NMsH = N * MsH; + const int BNMsH = B * NMsH; + present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + + for (int h = threadIdx.x; h < head_size; h += blockDim.x) { + T bias = (biases ? biases[h] : (T)0.0f); + present[h] = qkv[h] + bias; + } +} + +template +__global__ void AddBiasTransAppendKvToPresent( + const T* qkv, const T* biases, T* present, + const int head_size, const int past_sequence_length, const int max_sequence_length) { + // Input: BxSxMxNxH (Format 1) + // Output: (2, B, N, [P..P+S) of MaxS, H), + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + const int n = blockIdx.x; + const int s = blockIdx.y; + const int b = (blockIdx.z >> 1); + const int N = gridDim.x; + const int S = gridDim.y; + const int B = (gridDim.z >> 1); + + constexpr int M = 3; // Matrix count in qkv + const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2 + + const int NH = N * head_size; + const int NHS = NH * S; + + qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); + if (biases) { + biases += (m * NH + n * head_size); + } + + const int MsH = max_sequence_length * head_size; + const int NMsH = N * MsH; + const int BNMsH = B * NMsH; + present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + + for (int h = threadIdx.x; h < head_size; h += blockDim.x) { + T bias = (biases ? biases[h] : (T)0.0f); + present[h] = qkv[h] + bias; + } +} + +// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3. 
+// bias is of shape (3, NxH) or nullptr +// append to present of (2, B, N, (P..T) of M, H), +template +Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int past_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const T* biases, + const T* qkv_buffer, + T* present) { + assert(head_size <= (1 << 30)); + + int64_t nh = (int64_t)head_size * num_heads; + if (nh <= max_threads_per_block) { + const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + + AddBiasTransAppendKvToPresentSmall<<>>( + qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + } else { + const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v + const dim3 block(std::min(head_size, max_threads_per_block), 1, 1); + AddBiasTransAppendKvToPresent<<>>( + qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + } + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const float* bias, + const float* qkv_buffer, + float* present); + +template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const half* bias, + const half* qkv_buffer, + half* present); +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu new file mode 100644 index 0000000000000..5c65a30918ece --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -0,0 +1,492 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
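The new attention_prepare_qkv.cu below collects the PrepareQkv_* helpers that pick a QKV memory layout per kernel path. When bias is present, the selection in PrepareQkv_Attention reduces to the small sketch that follows; it is a standalone paraphrase of the nested conditional further down, with a local enum that mirrors the AttentionQkvFormat values used there, not a drop-in replacement.

// Local stand-in for contrib::AttentionQkvFormat, mirroring the values used below.
enum class QkvFormat {
  Q_K_V_BNSH,           // unfused kernels: 3 x BxNxSxH
  Q_K_V_BSNH,           // flash / memory-efficient attention: 3 x BxSxNxH
  QKV_BSN3H,            // fused TRT (non-causal) attention: BxSxNx3xH
  Q_K_V_BNSH_QKV_BS3NH  // fused causal: BNSH workspace plus bias-added BxSx3xNxH gemm buffer
};

QkvFormat ChooseQkvFormat(bool use_fused_kernel,
                          bool use_flash_or_efficient_attention,
                          bool use_fused_causal) {
  if (use_fused_kernel) return QkvFormat::QKV_BSN3H;
  if (use_flash_or_efficient_attention) return QkvFormat::Q_K_V_BSNH;
  if (use_fused_causal) return QkvFormat::Q_K_V_BNSH_QKV_BS3NH;
  return QkvFormat::Q_K_V_BNSH;
}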
+ +#include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/bert/add_bias_transpose.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + const bool past_present_share_buffer = parameters.past_present_share_buffer; + void* fused_runner = data.fused_runner; + bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); + + if (data.bias == nullptr) { + assert(nullptr == fused_runner); + // For quantized attention, bias has been added so only need transpose here. + // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH + assert(qk_head_size == v_head_size); + int matrix_to_trans = (past_present_share_buffer ? 1 : 3); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.gemm_buffer, qkv, 3)); + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } else { + // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) + // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) + // For unfused kernel, transpose to 3xBxNxSxH (format 1) + // For fused causal kernel, use format 1 since we need have K and V to update present state, + // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. + const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); + qkv_format = use_fused_kernel + ? AttentionQkvFormat::QKV_BSN3H + : (use_flash_or_efficient_attention + ? AttentionQkvFormat::Q_K_V_BSNH + : (use_fused_causal + ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH + : AttentionQkvFormat::Q_K_V_BNSH)); + + // For fused causal, we will update gemm_buffer with bias directly. + T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; + + int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 
1 : 3); + // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v + // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) + LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, + 3, parameters.do_rotary, parameters.past_sequence_length); + } + return Status::OK(); +} + +// For MultiHeadAttention with past state +template +Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + + if (data.bias == nullptr) { + // Below logic does not support fused attention with past without bias + // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past. + + // cross attention with past state + if (data.past_key != nullptr && data.present_key == nullptr) { + assert(data.past_value != nullptr); + assert(data.query != nullptr); + assert(data.key == nullptr); + assert(data.value == nullptr); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + } + // cross attention with present state or self attention with present state + else if (data.past_key == nullptr && data.present_key != nullptr) { + assert(data.past_value == nullptr); + assert(data.present_value != nullptr); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + + // TODO: supporting packed qkv for self attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + + // TODO: supporting packed kv for cross attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.present_key)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, data.present_value)); + } + // self attention with past and present state + else { + assert(data.past_key != nullptr); + assert(data.past_value != nullptr); + assert(data.present_key != nullptr); + assert(data.present_value != nullptr); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + // TODO: supporting packed qkv for self attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, k)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, v)); + } + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } +#if USE_MEMORY_EFFICIENT_ATTENTION || 
USE_FLASH_ATTENTION + // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value + else if ((data.use_memory_efficient_attention || data.use_flash_attention) && + data.past_key != nullptr && + data.past_value != nullptr && + parameters.pass_past_in_kv) { + // Transpose past_key and past_value to use memory efficient attention + + // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) + ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.past_key, data.temp_k_workspace)); + // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) + ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.past_value, data.temp_v_workspace)); + + // query => q, temp_k_workspace => k, temp_v_workspace => v + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + + data.past_key = nullptr; + data.past_value = nullptr; + } + // When there is no past_key/past_value and there is present_key/present_value + // (e.g. get initial kv to use as past_kv in the next iteration) + else if ((data.use_memory_efficient_attention || data.use_flash_attention) && + data.present_key != nullptr && + data.present_value != nullptr) { + // Use memory efficient attention kernel + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); + + // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.temp_k_workspace, data.present_key)); + + // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.temp_v_workspace, data.present_value)); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } +#endif + else { + // Use unfused kernel for Q, use unfused kernel for K and V if needed + constexpr int format = 0; + // Query (BxSxNxH) => Q (BxNxSxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, q, + true, -1); + + if (!parameters.pass_past_in_kv) { + T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k; + T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? 
data.present_value : v; + + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, data.bias + num_heads * qk_head_size, k_dest, + true, -1); + + // Value (BxLxNxH_v) => V (BxNxLxH_v) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, + true, -1); + + DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size); + } + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with packed QKV inputs +template +Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); + + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, qkv, + true, v_head_size, qkv_add_bias, 3); + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (!use_fused_kernel) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); + } + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with packed KV inputs +template +Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. + // CheckInputs verified this constraint. 
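+ // Illustrative reading of the packed layout (see the dump below): data.key holds K and V
+ // interleaved as B x L x N x 2 x H, so for a given (b, s, n) the K vector is at [..., 0, :]
+ // and the V vector at [..., 1, :]; format 4 of LaunchAddBiasTranspose splits them into the
+ // separate B x L x N x H tensors k and v (no bias is added here since data.bias is nullptr).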
+ assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); + + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // unpack kv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); + LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, kv_bias, k, + true, v_head_size, qkv_add_bias, 2); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (data.fused_cross_attention_kernel == nullptr) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); + } + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with Q, K and V inputs +template +Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); + + // gemm_buffer == nullptr and not packed + assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); + +#if DUMP_TENSOR_LEVEL > 1 + if (data.bias != nullptr) { + DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); + DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); + DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + } +#endif + + if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { + DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, + num_heads, sequence_length, kv_sequence_length); + } + + if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { + DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); + } + + if (data.fused_cross_attention_kernel != nullptr) { + assert(qk_head_size == v_head_size); + + // For fused cross attention, besides adding bias, K and V needed to be packed: + // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + 
num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length);
+
+ qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
+ }
+#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
+ else if (data.use_memory_efficient_attention || data.use_flash_attention) {
+ LaunchAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, kv_sequence_length,
+ num_heads, qk_head_size, v_head_size,
+ data.bias, data.query, data.key, data.value, q, k, v);
+
+ DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
+ qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ }
+#endif
+ else if (use_fused_kernel) {
+ assert(qk_head_size == v_head_size);
+
+ // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
+ LaunchAddBiasTransposeTrt(
+ stream, max_threads_per_block,
+ batch_size, sequence_length,
+ num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length);
+ DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size);
+
+ qkv_format = AttentionQkvFormat::QKV_BSN3H;
+ } else { // unfused kernel
+ ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal");
+
+ // Query (BxSxNxH) => Q (BxNxSxH)
+ constexpr int format = 0;
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, sequence_length, num_heads, qk_head_size,
+ data.query, data.bias, q,
+ true, -1);
+
+ // Key (BxLxNxH) => K (BxNxLxH)
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
+ true, -1);
+
+ // Value (BxLxNxH_v) => V (BxNxLxH_v)
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, v_head_size,
+ data.value, nullptr == data.bias ?
nullptr : data.bias + 2 * num_heads * qk_head_size, v, + true, -1); + + DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} + +template +Status PrepareQkv(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + data.scratch = data.workspace; + if (data.has_qkv_workspace) { + const int size_per_batch_q = parameters.sequence_length * parameters.head_size; + const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size; + const int size_per_batch_v = parameters.kv_sequence_length * parameters.v_head_size; + const int batches = parameters.batch_size * parameters.num_heads; + const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); + const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); + const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); + data.q = data.workspace; + data.k = data.workspace + elements_q; + data.v = data.k + elements_k; + data.scratch = data.v + elements_v; + } + + if (nullptr != data.gemm_buffer) { // Attention operator + ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, + data.qkv_format)); + } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, + data.q, data.k, data.v, data.qkv_format)); + } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, + data.q, data.k, data.v, data.qkv_format)); + } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, + data.q, data.k, data.v, data.qkv_format)); + } else { // multihead attention operator, no past, separated Q/K/V inputs + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, + data.q, data.k, data.v, data.qkv_format)); + } + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + return Status::OK(); +} + +// Template Instantiation +template Status PrepareQkv( + contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block); + +template Status PrepareQkv( + contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index e7d2255fb46b8..01ea02f48d3ab 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -18,7 +18,6 @@ limitations under the License. 
*/ #include -#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index f907d300607f9..3f703ae3d05e6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/decoder_attention.h" +#include "contrib_ops/cuda/bert/decoder_attention_impl.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" #include "core/framework/op_kernel.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" @@ -85,7 +85,8 @@ Status CheckInputs(const TensorShape& query_shape, } if (kv_weights_dims[0] != hidden_size || kv_weights_dims[1] != 2 * static_cast(hidden_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "kv_weights shall have shape (hidden size, 2 * hidden size)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_weights shall have shape (hidden size, 2 * hidden size)"); } const auto& bias_dims = bias_shape.GetDims(); @@ -137,7 +138,8 @@ Status CheckInputs(const TensorShape& query_shape, const auto& value_cache_dims = value_cache->Shape().GetDims(); if (value_cache_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value_cache' is expected to have 4 dimension, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value_cache' is expected to have 4 dimension, got ", value_cache_dims.size()); } @@ -353,10 +355,12 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { } } - size_t bytes = element_size * batch_size * (static_cast(sequence_length) + static_cast(2) * kv_sequence_length) * hidden_size; + size_t bytes = element_size * batch_size * + (static_cast(sequence_length) + static_cast(2) * kv_sequence_length) * hidden_size; auto qkv_buffer_p = GetScratchBuffer(bytes, context->GetComputeStream()); - bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (static_cast(2) * head_size + static_cast(kv_sequence_length)); + bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * + (static_cast(2) * head_size + static_cast(kv_sequence_length)); auto workspace_p = GetScratchBuffer(bytes, context->GetComputeStream()); Tensor* output(context->Output(0, query_shape)); diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu new file mode 100644 index 0000000000000..1dc22a9c8ea98 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
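+
+// Buffer layout used by DecoderQkvToContext below (editorial sketch derived from the offsets it
+// computes; B = batch, S = q sequence, L = kv sequence, N = heads, H = head size, BHN = B*N*H):
+//   qkv_buffer      : [Q : S*BHN] [K : L*BHN] [V : L*BHN], each region in B x N x {S,L} x H order.
+//   workspace_buffer: [temp_qkv : 3*S*BHN] [scratch1 : B*N*S*L] [scratch2 : B*N*S*L] [scratch3 : B*N*S*H].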
+ +#include "contrib_ops/cuda/bert/decoder_attention_impl.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/attention_softmax.h" + +using namespace onnxruntime::contrib::attention_softmax_cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status DecoderQkvToContext( + const cudaDeviceProp& device_prop, + Stream* ort_stream, + cublasHandle_t& cublas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const float mask_filter_value, + const T* gemm_query_buffer, + const T* gemm_kv_buffer, + const bool* key_padding_mask, + const T* key_cache, + const T* value_cache, + T* qkv_buffer, + T* workspace_buffer, + T* output, + T* new_key_cache, + T* new_value_cache) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int BN = batch_size * num_heads; + const int BHN = BN * head_size; + const int BNS = BN * sequence_length; + const int k_buffer_offset = sequence_length * BHN; + const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; + + T* temp_qkv_buffer = workspace_buffer; + auto stream = static_cast(ort_stream->GetHandle()); + + const T* q = qkv_buffer; + // transpose q and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); + + const T* k = qkv_buffer + k_buffer_offset; + const T* v = qkv_buffer + v_buffer_offset; + if (!has_layer_state || !use_past) { + if (!static_kv) { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } else { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } + } else { + if (!static_kv) { + // transpose kv and copy them to temp_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); + // concat cache-k with k and copy to qkv_buffer + if (nullptr != key_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, + sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, 1, + key_cache, + temp_qkv_buffer, + qkv_buffer + k_buffer_offset)); + } + // concat cache-v with v and copy to qkv_buffer + if (nullptr != value_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, + sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, 1, + value_cache, + temp_qkv_buffer + k_buffer_offset, + qkv_buffer + v_buffer_offset)); + } + } + } + + if (has_layer_state) { + if (use_past && static_kv) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + } else { + 
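+ // Otherwise the K/V regions of qkv_buffer (either freshly transposed above, or past
+ // concatenated with the current step) become the new caches.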
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + } + } + + // scratch1: BxNxSxL buffer + // scratch2: BxNxSxL buffer + // scratch3: BxNxSxH buffer + T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; + T* scratch2 = scratch1 + BNS * kv_sequence_length; + T* scratch3 = scratch2 + BNS * kv_sequence_length; + + // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL + // Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL + const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); + const int temp_matrix_size = sequence_length * kv_sequence_length; + float one = 1.0f; + float zero = 0.f; + + float alpha = rsqrt_head_size; + const int strideA = kv_sequence_length * head_size; + const int strideB = sequence_length * head_size; + if (use_past && static_kv) { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + kv_sequence_length, sequence_length, head_size, + &alpha, key_cache, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + } else { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + kv_sequence_length, sequence_length, head_size, + &alpha, k, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + } + + constexpr bool is_unidirectional = false; + const T* add_before_softmax = nullptr; + if (has_key_padding_mask) { + constexpr int mask_dimension = 2; + constexpr int max_sequence_length = 0; + ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( + ort_stream, kv_sequence_length, sequence_length, batch_size, + num_heads, nullptr, key_padding_mask, add_before_softmax, + false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, + 1.0f, mask_dimension, max_sequence_length, false, nullptr, + mask_filter_value)); + } else { + ORT_RETURN_IF_ERROR(ComputeSoftmax( + stream, kv_sequence_length, sequence_length, batch_size, num_heads, + add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, + is_unidirectional)); + } + + // compute P*V (as V*P), and store in scratch3: BxNxSxH + if (use_past && static_kv) { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, kv_sequence_length, + &one, value_cache, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN, device_prop)); + } else { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, kv_sequence_length, + &one, v, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN, device_prop)); + } + + // scratch3 is BxNxSxH, transpose to output SxBxNxH + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, scratch3, output); +} + +Status LaunchDecoderAttentionKernel( + const cudaDeviceProp& device_prop, + Stream* stream, + cublasHandle_t& cublas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, 
+ const bool has_key_padding_mask, + const float mask_filter_value, + const void* gemm_query_buffer, + const void* gemm_kv_buffer, + const bool* key_padding_mask, + const void* key_cache, + const void* value_cache, + void* qkv_buffer, + void* workspace_buffer, + void* output, + void* new_key_cache, + void* new_value_cache) { + if (element_size == 2) { + return DecoderQkvToContext( + device_prop, + stream, + cublas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } else { + return DecoderQkvToContext( + device_prop, + stream, + cublas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h new file mode 100644 index 0000000000000..9db9ccb45e330 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
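+
+// Sizing note (illustrative, derived from the caller in decoder_attention.cc):
+//   qkv_buffer       needs batch_size * (sequence_length + 2 * kv_sequence_length) * num_heads * head_size elements,
+//   workspace_buffer needs 2 * batch_size * sequence_length * num_heads * (2 * head_size + kv_sequence_length) elements.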
+ +#pragma once + +#include "contrib_ops/cuda/bert/attention_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +Status LaunchDecoderAttentionKernel( + const cudaDeviceProp& prop, // Device Properties + Stream* stream, // ORT Stream + cublasHandle_t& cublas, // Cublas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const float mask_filter_value, // Mask filter value + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor +); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 8f1252f863ef6..25f3f59165e43 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -263,14 +263,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; - data.gemm_buffer = nullptr; data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); data.key = (nullptr == key || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(key->Data()); data.value = (nullptr == value || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(value->Data()); data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); - data.past = nullptr; data.past_key = pass_key_value_as_past ? reinterpret_cast(key->Data()) : (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); @@ -283,7 +281,6 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.temp_k_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_k_work_space.get()) : nullptr; data.temp_v_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_v_work_space.get()) : nullptr; data.output = reinterpret_cast(output->MutableData()); - data.present = nullptr; data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? 
nullptr : reinterpret_cast(present_value->MutableData()); data.fused_runner = reinterpret_cast(fused_runner); diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index b0556512de0b7..705f2d49fe2bf 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -195,28 +195,21 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); - data.bias = nullptr; // bias has been added - data.query = nullptr; - data.key = nullptr; - data.value = nullptr; - data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); - data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); - data.past = (nullptr == past_tensor) ? nullptr : reinterpret_cast(past_tensor->Data()); - data.past_key = nullptr; - data.past_value = nullptr; - data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention + if (nullptr != mask_index) { + data.mask_index = mask_index->Data(); + data.mask_index_dims = mask_index->Shape().GetDims(); + } + + if (nullptr != past_tensor) { + data.past = reinterpret_cast(past_tensor->Data()); + } + data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); - data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData()); - data.present_key = nullptr; - data.present_value = nullptr; - data.fused_runner = fused_runner; - data.fused_cross_attention_kernel = nullptr; - data.use_flash_attention = use_flash_attention; - data.use_memory_efficient_attention = use_memory_efficient_attention; - data.cumulated_sequence_length_q_cache = nullptr; - data.cumulated_sequence_length_kv_cache = nullptr; + if (nullptr != present) { + data.present = reinterpret_cast(present->MutableData()); + } return QkvToContext(GetDeviceProp(), cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index 9a150c9e6cd77..b0ed3ff82226a 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -30,6 +30,7 @@ limitations under the License. 
#include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/rocm/bert/attention_impl.h" #include "contrib_ops/rocm/bert/attention_softmax.h" +#include "contrib_ops/rocm/bert/decoder_attention_impl.h" using namespace onnxruntime::rocm; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 19b2bc34efaec..3164e8c211099 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -28,34 +28,6 @@ size_t GetAttentionWorkspaceSize( int sequence_length, int past_sequence_length); -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, // Device Properties - RocmTuningContext* tuning_ctx, // context for tuning - Stream* stream, // ORT Stream - rocblas_handle& rocblas, // Rocblas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden layer size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - Status LaunchTransCtx(hipStream_t stream, const int sequence_length, const int batch_size, const int head_size, const int num_heads, const int max_threads_per_block, const bool reversed_bs, const float* input, float* output); diff --git a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h new file mode 100644 index 0000000000000..d71c6d8440a44 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/providers/rocm/shared_inc/rocm_utils.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +Status LaunchDecoderAttentionKernel( + const hipDeviceProp_t& prop, // Device Properties + RocmTuningContext* tuning_ctx, // context for tuning + Stream* stream, // ORT Stream + rocblas_handle& rocblas, // Rocblas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden layer size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const float mask_filter_value, // Mask filter value + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor +); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index b7d26d87f2705..f0e5fbbd38721 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1030,7 +1030,10 @@ Status SessionState::CreateSubgraphSessionState() { for (auto& node : graph_.Nodes()) { for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { const auto& ep = node.GetExecutionProviderType(); - if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && ep != kRocmExecutionProvider && ep != kDmlExecutionProvider) { + if (!ep.empty() && + ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && + ep != kRocmExecutionProvider && ep != kDmlExecutionProvider && + ep != kJsExecutionProvider) { // SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow // node containing the subgraph it will create whatever state it needs internally. continue; diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index aa2ad9f1ff6b1..4313fae767fe5 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -136,8 +136,9 @@ Performs element-wise binary {name} on 8 bit data types (with Numpy-style broadc static const char* QuantizeLinear_ver1_doc = R"DOC( The linear quantization operator. It consumes a full precision data, a scale, a zero point to compute the low precision / quantized tensor. -The quantization formula is y = saturate ((x / y_scale) + y_zero_point).For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8. -For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details. 
+The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, [-128, 127] if it's int8, +[0, 65,535] if it's uint16, and [-32,768, 32,767] if it's int16. For (x / y_scale), it's rounding to nearest ties to even. +Refer to https://en.wikipedia.org/wiki/Rounding for details. Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -161,8 +162,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T2", OpSchema::Optional) .Output(0, "y", "N-D quantized output tensor. It has same shape as input 'x'.", "T2") .TypeConstraint("T1", {"tensor(float16)", "tensor(float)"}, "Constrain 'x', 'y_scale' to float tensors.") - .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, - "Constrain 'y_zero_point' and 'y' to 8-bit integer tensors.") + .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", "tensor(uint16)"}, + "Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.") .SetDoc(QuantizeLinear_ver1_doc) .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { if (ctx.getNumInputs() == 3 && ctx.getInputType(2) != nullptr) { @@ -202,9 +203,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(DequantizeLinear, 1, "T1", OpSchema::Optional) .Output(0, "y", "N-D full precision output tensor. It has same shape as input 'x'.", "T2") - .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)", "tensor(int32)"}, - "Constrain 'x' and 'x_zero_point' to 8-bit integer tensors or 32-bit " - "signed integer tensors.") + .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", + "tensor(uint16)", "tensor(int32)"}, + "Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, " + "16-bit integer tensors, or 32-bit signed integer tensors.") .TypeConstraint("T2", {"tensor(float16)", "tensor(float)"}, "Constrain 'y', 'x_scale' to float tensors.") .SetDoc(DequantizeLinear_ver1_doc) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index d4164681f2bba..383c1d689d3c3 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4145,8 +4145,6 @@ Status Graph::InlineFunction(Node& callnode) { // std::cout << "Graph after inlining\n\n" << *this << std::endl << std::flush; - ORT_RETURN_IF_ERROR(this->Resolve()); - return Status::OK(); } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index f517be185b3fa..b6ac4a1ca1d6c 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -633,6 +633,24 @@ void int8_t ZeroPoint ); +typedef +void +(MLASCALL MLAS_QUANTIZE_LINEAR_U16_KERNEL)( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint); + +typedef +void +(MLASCALL MLAS_QUANTIZE_LINEAR_S16_KERNEL)( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint); + template struct MLAS_QUANT_KERNEL { @@ -749,6 +767,8 @@ extern "C" { MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8Kernel; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8Kernel; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8Kernel; + MLAS_QUANTIZE_LINEAR_S16_KERNEL MlasQuantizeLinearS16Kernel; + MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel; #if defined(MLAS_TARGET_AMD64) MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; @@ -959,6 +979,8 @@ struct MLAS_PLATFORM { const MLAS_GEMM_QUANT_DISPATCH* GemmU8X8Dispatch; 
MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel; MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; + MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel; + MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; #endif #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; @@ -986,6 +1008,8 @@ struct MLAS_PLATFORM { MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL* ReduceMinimumMaximumF32Kernel; MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel; MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; + MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel; + MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; int32_t MaximumThreadCount; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 86b7450a7c4e5..7e2b117d6f249 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -230,6 +230,8 @@ Return Value: this->QLinearAddU8Kernel = MlasQLinearAddU8Kernel; this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel; this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; + this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; + this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; this->NchwcBlockSize = 8; this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; @@ -475,6 +477,8 @@ Return Value: this->GemmDoubleKernel = MlasDgemmKernel; this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel; this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; + this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; + this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; #if defined(__linux__) unsigned long hwcap2 = getauxval(AT_HWCAP2); diff --git a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp index 0d38288c6d42c..830a3a6a492db 100644 --- a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp +++ b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp @@ -1,3 +1,4 @@ +#include #include "mlasi.h" #include @@ -82,8 +83,15 @@ Return Value: auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); - auto CharVector = vec_pack(ShortVector0, ShortVector1); - vec_xst(CharVector, 0, (int8_t *) Output); + + if constexpr (std::is_same_v || std::is_same_v) { + auto CharVector = vec_pack(ShortVector0, ShortVector1); + vec_xst(CharVector, 0, Output); + } else { + static_assert(std::is_same_v || std::is_same_v); + vec_xst(ShortVector0, 0, Output); + vec_xst(ShortVector1, 0, &Output[8]); + } Output += 16; Input += 16; @@ -124,3 +132,30 @@ MlasQuantizeLinearS8Kernel( { MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); } + +void +MLASCALL +MlasQuantizeLinearU16Kernel( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS16Kernel( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index c6e8af38c0020..133ad79594c55 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -21,6 +21,7 @@ Module Name: #include "mlasi.h" #if defined(MLAS_NEON64_INTRINSICS) || 
defined(MLAS_SSE2_INTRINSICS) +#include // // QuantizeLinear implementation using NEON or SSE2 intrinsics. @@ -79,6 +80,20 @@ MlasQuantizeLinearPackBytes( MLAS_INT32X4 IntegerVector ); +template +void +MlasQuantizeLinearStore4PackedValues( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ); + +template +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ); + #if defined(MLAS_NEON64_INTRINSICS) template @@ -100,6 +115,104 @@ MlasQuantizeLinearPackBytes( return vreinterpretq_s32_u8(ByteVector); } +template<> +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ + // + // Swizzle the least significant u16 from each int32_t element to the + // bottom eight bytes of the vector register. + // + + uint16x8_t WordVector = vreinterpretq_u16_s32(IntegerVector); + WordVector = vuzp1q_u16(WordVector, WordVector); + return vreinterpretq_s32_u16(WordVector); +} + +template<> +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ + // + // Swizzle the least significant u16 from each int32_t element to the + // bottom eight bytes of the vector register. + // + + int16x8_t WordVector = vreinterpretq_s16_s32(IntegerVector); + WordVector = vuzp1q_s16(WordVector, WordVector); + return vreinterpretq_s32_s16(WordVector); +} + +template +MLAS_FORCEINLINE +void +MlasQuantizeLinearStore4PackedValues( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ) +{ + // Copies the lower 4 packed elements of the vector into memory (Output). + + if constexpr (std::is_same_v || std::is_same_v) { + vst1q_lane_s32(reinterpret_cast(Output), IntegerVector, 0); + } else { + static_assert(std::is_same_v || std::is_same_v); + vst1q_lane_s64(reinterpret_cast(Output), vreinterpretq_s64_s32(IntegerVector), 0); + } +} + +template <> +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + uint8_t* Output + ) +{ + // Copies the lower 8-bit element of the vector into memory (Output). + vst1q_lane_u8(Output, vreinterpretq_u8_s32(IntegerVector), 0); +} + +template <> +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + int8_t* Output + ) +{ + // Copies the lower 8-bit element of the vector into memory (Output). + vst1q_lane_s8(Output, vreinterpretq_s8_s32(IntegerVector), 0); +} + +template <> +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + uint16_t* Output + ) +{ + // Copies the lower 16-bit element of the vector into memory (Output). + vst1q_lane_u16(Output, vreinterpretq_u16_s32(IntegerVector), 0); +} + +template <> +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + int16_t* Output + ) +{ + // Copies the lower 16-bit element of the vector into memory (Output). + vst1q_lane_s16(Output, vreinterpretq_s16_s32(IntegerVector), 0); +} + #else template<> @@ -128,6 +241,86 @@ MlasQuantizeLinearPackBytes( return IntegerVector; } +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ +#if defined(MLAS_SSE41_INTRINSICS) + IntegerVector = _mm_packus_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes. +#else + // Cannot use _mm_packus_epi32 because that was not available until SSE4.1. + // Instead, emulate by sign-extending the first 16-bits of each packed 32-bit element. + // Afterwards, can use _mm_packs_epi32, which is available on SSE2. 
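+ // Worked example (illustrative): a clamped value 65520 (0x0000FFF0) becomes 0xFFF00000 after the
+ // left shift and 0xFFFFFFF0 (-16) after the arithmetic right shift; _mm_packs_epi32 then emits
+ // 0xFFF0, which is the same bit pattern as the intended uint16_t value 65520.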
+ // See: https://stackoverflow.com/a/11028244 + + IntegerVector = _mm_slli_epi32(IntegerVector, 16); + IntegerVector = _mm_srai_epi32(IntegerVector, 16); // Sign-extend: undo left shift with right arithmetic shift + IntegerVector = _mm_packs_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes. +#endif // defined(MLAS_SSE41_INTRINSICS) + + return IntegerVector; +} + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ + IntegerVector = _mm_packs_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes. + + return IntegerVector; +} + +template +MLAS_FORCEINLINE +void +MlasQuantizeLinearStore4PackedValues( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ) +{ + // Copies the lower 4 packed elements of the vector into memory (Output). + + if constexpr (std::is_same_v || std::is_same_v) { + *(reinterpret_cast(Output)) = _mm_cvtsi128_si32(IntegerVector); + } else { + static_assert(std::is_same_v || std::is_same_v); + +#if defined(MLAS_TARGET_IX86) + // x86 does not support _mm_cvtsi128_si64, so use _mm_maskmoveu_si128 instead. + constexpr uint32_t bytes_high_bit = 0x80808080; + const __m128i first_8_bytes_mask = _mm_set_epi32(0, 0, bytes_high_bit, bytes_high_bit); + _mm_maskmoveu_si128(IntegerVector, first_8_bytes_mask, reinterpret_cast(Output)); +#else + *(reinterpret_cast(Output)) = _mm_cvtsi128_si64(IntegerVector); +#endif // defined(MLAS_TARGET_IX86) + } +} + +template +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ) +{ + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v); + + // Copies the lower element of the vector into memory (Output). + // Expects that the 32-bit element in lane 0 is already within the valid numerical + // range of the OutputType. 
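+ // For example, for uint16_t the value has already been clamped to [0, 65535] by
+ // MlasQuantizeLinearVector in the caller, so the narrowing cast below cannot change it.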
+ *Output = static_cast(_mm_cvtsi128_si32(IntegerVector)); +} + #endif template @@ -180,12 +373,7 @@ Return Value: MinimumValueVector, MaximumValueVector, ZeroPointVector); IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); - -#if defined(MLAS_NEON64_INTRINSICS) - vst1q_lane_s32((int32_t*)Output, IntegerVector, 0); -#else - *((int32_t*)Output) = _mm_cvtsi128_si32(IntegerVector); -#endif + MlasQuantizeLinearStore4PackedValues(IntegerVector, Output); Input += 4; Output += 4; @@ -202,11 +390,7 @@ Return Value: auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, MinimumValueVector, MaximumValueVector, ZeroPointVector); -#if defined(MLAS_NEON64_INTRINSICS) - vst1q_lane_u8((uint8_t*)Output + n, vreinterpretq_u8_s32(IntegerVector), 0); -#else - *((uint8_t*)Output + n) = (uint8_t)_mm_cvtsi128_si32(IntegerVector); -#endif + MlasQuantizeLinearStoreSingleValue(IntegerVector, &Output[n]); } } @@ -236,6 +420,32 @@ MlasQuantizeLinearU8Kernel( MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); } +void +MLASCALL +MlasQuantizeLinearU16Kernel( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint +) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS16Kernel( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint +) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + template<> void MLASCALL @@ -274,6 +484,44 @@ MlasQuantizeLinear( Input, Output, N, Scale, ZeroPoint); } +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearU16Kernel( +#else + MlasQuantizeLinearU16Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearS16Kernel( +#else + MlasQuantizeLinearS16Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + #else #if defined(MLAS_TARGET_POWER) @@ -306,6 +554,34 @@ MlasQuantizeLinear( GetMlasPlatform().QuantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint); } +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ) +{ + GetMlasPlatform().QuantizeLinearS16Kernel(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ) +{ + GetMlasPlatform().QuantizeLinearU16Kernel(Input, Output, N, Scale, ZeroPoint); +} + #endif // @@ -381,6 +657,29 @@ MlasQuantizeLinear( float Scale, uint8_t ZeroPoint ); + +template +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ); + +template +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ); + #endif #endif diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc index b67f6d6ec0794..624679e7b1b4b 100644 --- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc +++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc @@ -1,131 +1,37 @@ // Copyright (c) Microsoft Corporation. 
All rights reserved. // Licensed under the MIT License. #include "core/optimizer/double_qdq_pairs_remover.h" +#include #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" namespace onnxruntime { -Status DoubleQDQPairsRemover::ApplyImpl( - Graph& graph, - bool& modified, - int /*graph_level*/, - const logging::Logger& /*logger*/) const { - const GraphViewer graph_viewer(graph); - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - - for (const auto& self_index : node_topology_list) { - NodeIndex parent_index = 0; - NodeIndex child_index = 0; - NodeIndex grandchild_index = 0; - if (IsNodeRemovable(graph, self_index, parent_index, child_index, grandchild_index)) { - graph.RemoveEdge(parent_index, self_index, 0, 0); - graph.RemoveEdge(self_index, child_index, 0, 0); - graph.RemoveEdge(child_index, grandchild_index, 0, 0); - graph_utils::ReplaceNodeInput(*graph.GetNode(grandchild_index), 0, *graph.GetNode(self_index)->MutableInputDefs()[0]); - graph.AddEdge(parent_index, grandchild_index, 0, 0); - graph.RemoveNode(child_index); - graph.RemoveNode(self_index); - modified = true; - } - } - return Status::OK(); -} - -bool DoubleQDQPairsRemover::IsNodeRemovable( - Graph& graph, - const NodeIndex& self_index, - NodeIndex& parent_index, - NodeIndex& child_index, - NodeIndex& grandchild_index) { - // Check if the self is a DQ, and have one parent and one child, and cannot be a graph output - Node* self = graph.GetNode(self_index); - if (self == nullptr || - self->OpType() != "DequantizeLinear" || - self->GetInputEdgesCount() != 1 || - self->GetOutputEdgesCount() != 1 || - self->InputDefs().size() != InputIndex::TOTAL_COUNT || - graph.NodeProducesGraphOutput(*self)) { - return false; - } - - // Type is either "tensor(uint8)" or "tensor(int8)" - const auto& self_zp_type = *self->InputDefs()[InputIndex::ZERO_POINT_ID]->Type(); - // child should be a Q, and have only one child, have the same type as self, and cannot be a graph output - child_index = self->OutputEdgesBegin()->GetNode().Index(); - const Node* child = graph.GetNode(child_index); - if (child == nullptr || - child->OpType() != "QuantizeLinear" || - child->GetOutputEdgesCount() != 1 || - child->InputDefs().size() != InputIndex::TOTAL_COUNT || - *child->InputDefs()[InputIndex::ZERO_POINT_ID]->Type() != self_zp_type || - graph.NodeProducesGraphOutput(*child)) { - return false; - } - - // parent should be a Q, and have only one output, and cannot be a graph output - parent_index = self->InputEdgesBegin()->GetNode().Index(); - Node* parent = graph.GetNode(parent_index); - if (parent == nullptr || - parent->GetOutputEdgesCount() != 1 || - parent->OpType() != "QuantizeLinear" || - graph.NodeProducesGraphOutput(*parent)) { - return false; - } - - // grandchild should be a DQ - grandchild_index = child->OutputEdgesBegin()->GetNode().Index(); - Node* grandchild = graph.GetNode(grandchild_index); - if (grandchild == nullptr || - grandchild->OpType() != "DequantizeLinear") { - return false; - } - const auto get_constant_initializer = [&graph](const std::string& initializer_name) { - return graph.GetConstantInitializer(initializer_name, true); - }; - if (!QDQ::IsQDQPairSupported(*parent, *self, get_constant_initializer, graph.ModelPath()) || - !QDQ::IsQDQPairSupported(*child, *grandchild, get_constant_initializer, graph.ModelPath())) { - return false; - } - bool skip_reset = false; - float new_scale = 0.0f; - if (self_zp_type == "tensor(uint8)") { - 
uint8_t new_zero_point = 0; - if (!FindNewZeroPointAndScale(graph, *self, *child, new_scale, new_zero_point, skip_reset)) { - return false; - } - if (skip_reset) { - return true; - } - ApplyNewInputValue(graph, *grandchild, InputIndex::SCALE_ID, new_scale); - ApplyNewInputValue(graph, *parent, InputIndex::SCALE_ID, new_scale); - ApplyNewInputValue(graph, *grandchild, InputIndex::ZERO_POINT_ID, new_zero_point); - ApplyNewInputValue(graph, *parent, InputIndex::ZERO_POINT_ID, new_zero_point); - } else { - int8_t new_zero_point = 0; - if (!FindNewZeroPointAndScale(graph, *self, *child, new_scale, new_zero_point, skip_reset)) { - return false; - } - if (skip_reset) { - return true; - } - ApplyNewInputValue(graph, *grandchild, InputIndex::SCALE_ID, new_scale); - ApplyNewInputValue(graph, *parent, InputIndex::SCALE_ID, new_scale); - ApplyNewInputValue(graph, *grandchild, InputIndex::ZERO_POINT_ID, new_zero_point); - ApplyNewInputValue(graph, *parent, InputIndex::ZERO_POINT_ID, new_zero_point); - } - return true; +// Applies a new zero point or scale as the input for a Q/DQ node. +template +static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index, T value) { + const auto* input_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[index]->Name()); + Initializer input_init{*input_tensor, graph.ModelPath()}; + ONNX_NAMESPACE::TensorProto new_input_tensor(*input_tensor); + input_init.data()[0] = value; + input_init.ToProto(new_input_tensor); + auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name()); + new_input_tensor.set_name(new_name); + NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor); + graph_utils::ReplaceNodeInput(node, index, new_input); } +// Returns a new zero point and scale value for the given Q/DQ nodes. 
template -bool DoubleQDQPairsRemover::FindNewZeroPointAndScale(const Graph& graph, const Node& node1, const Node& node2, - float& new_scale, T& new_zero_point, bool& skip_reset) { +static bool FindNewZeroPointAndScale(const Graph& graph, const Node& node1, const Node& node2, + float& new_scale, T& new_zero_point, bool& skip_reset) { // scale & zero point share same initializer, no need to reset the value - const std::string& node1_scale_name = node1.InputDefs()[InputIndex::SCALE_ID]->Name(); - const std::string& node2_scale_name = node2.InputDefs()[InputIndex::SCALE_ID]->Name(); - const std::string& node1_zp_name = node1.InputDefs()[InputIndex::ZERO_POINT_ID]->Name(); - const std::string& node2_zp_name = node2.InputDefs()[InputIndex::ZERO_POINT_ID]->Name(); + const std::string& node1_scale_name = node1.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name(); + const std::string& node2_scale_name = node2.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name(); + const std::string& node1_zp_name = node1.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name(); + const std::string& node2_zp_name = node2.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name(); skip_reset = false; if (node1_scale_name == node2_scale_name && node1_zp_name == node2_zp_name) { skip_reset = true; @@ -175,16 +81,141 @@ bool DoubleQDQPairsRemover::FindNewZeroPointAndScale(const Graph& graph, const N return true; } -template -void DoubleQDQPairsRemover::ApplyNewInputValue(Graph& graph, Node& node, const InputIndex& index, T value) { - const auto* input_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[index]->Name()); - Initializer input_init{*input_tensor, graph.ModelPath()}; - TensorProto new_input_tensor(*input_tensor); - input_init.data()[0] = value; - input_init.ToProto(new_input_tensor); - auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name()); - new_input_tensor.set_name(new_name); - NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor); - graph_utils::ReplaceNodeInput(node, index, new_input); +// Recomputes the zero point and scale of the outer Q/DQ nodes (i.e., Q1 and DQ2). This is necessary because +// the original two QDQ pairs may have different zero-points and scales. Ex: Q1 -> DQ1 -> Q2 -> DQ2, where +// the first pair has (zp1, scale1) and the second pair has (zp2, scale2). +// After removing the middle two nodes, the zero point and scale of the final (outer) ops must be recomputed +// for correctness. +template +static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Node& dq1, const Node& q2, Node& dq2) { + bool skip_reset = false; + float new_scale = 0.0f; + ZeroPointType new_zero_point = 0; + if (!FindNewZeroPointAndScale(graph, dq1, q2, new_scale, new_zero_point, skip_reset)) { + return false; + } + if (skip_reset) { + return true; + } + ApplyNewInputValue(graph, dq2, QDQ::InputIndex::SCALE_ID, new_scale); + ApplyNewInputValue(graph, q1, QDQ::InputIndex::SCALE_ID, new_scale); + ApplyNewInputValue(graph, dq2, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point); + ApplyNewInputValue(graph, q1, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point); + + return true; +} + +// Checks if the provided node index (dq1_index) is a part of a valid double QDQ pair sequence +// (i.e., Q1 -> DQ1 -> Q2 -> DQ2) that can be reduced to the outer Q/DQ nodes (i.e., Q1 -> DQ2). 
+// If so, the zero point and scale of the outer Q/DQ nodes are recomputed and the node indices of the other nodes +// in the sequence (i.e., Q1, Q2, and DQ2) are returned via output parameters. +static bool IsReducibleDoubleQDQSequence(Graph& graph, NodeIndex& q1_index, NodeIndex dq1_index, + NodeIndex& q2_index, NodeIndex& dq2_index) { + // Ensure that dq1 is a DQ operator, has one parent and one child, and is not a graph output + Node* dq1 = graph.GetNode(dq1_index); + if (dq1 == nullptr || + dq1->OpType() != "DequantizeLinear" || + dq1->GetInputEdgesCount() != 1 || + dq1->GetOutputEdgesCount() != 1 || + graph.NodeProducesGraphOutput(*dq1)) { + return false; + } + + // Ensure that q2 is a Q operator, has only one child, and is not a graph output + q2_index = dq1->OutputEdgesBegin()->GetNode().Index(); + const Node* q2 = graph.GetNode(q2_index); + if (q2 == nullptr || + q2->OpType() != "QuantizeLinear" || + q2->GetOutputEdgesCount() != 1 || + graph.NodeProducesGraphOutput(*q2)) { + return false; + } + + // Ensure that q1 is a Q operator, has only one output, and is not a graph output + q1_index = dq1->InputEdgesBegin()->GetNode().Index(); + Node* q1 = graph.GetNode(q1_index); + if (q1 == nullptr || + q1->GetOutputEdgesCount() != 1 || + q1->OpType() != "QuantizeLinear" || + graph.NodeProducesGraphOutput(*q1)) { + return false; + } + + // Ensure the dq2 is a DQ operator. + dq2_index = q2->OutputEdgesBegin()->GetNode().Index(); + Node* dq2 = graph.GetNode(dq2_index); + if (dq2 == nullptr || + dq2->OpType() != "DequantizeLinear") { + return false; + } + + const auto get_constant_initializer = [&graph](const std::string& initializer_name) { + return graph.GetConstantInitializer(initializer_name, true); + }; + + // Each QDQ pair (i.e., q1 -> dq1, q2 -> dq2) has to meet the following additional requirements: + // - Scalar/constant zero-point and scale. + // - The DQ and Q ops within a pair must have the same scale and zero-point. + // However, each pair is allowed to have different scales and zero-points. + // + // TODO: IsQDQPairSupported() requires an explicit zero-point input, but technically a default + // value of 0 could be fine. + if (!QDQ::IsQDQPairSupported(*q1, *dq1, get_constant_initializer, graph.ModelPath()) || + !QDQ::IsQDQPairSupported(*q2, *dq2, get_constant_initializer, graph.ModelPath())) { + return false; + } + + const auto& dq1_input_defs = dq1->InputDefs(); + const ONNX_NAMESPACE::TensorProto* dq1_zp_tensor_proto = graph.GetConstantInitializer( + dq1_input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Name(), true); + + assert(dq1_zp_tensor_proto != nullptr); // IsQDQPairSupported should have checked that this exists. 
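// --- Illustrative aside (not part of this diff) -----------------------------
// A scalar sketch of the kind of recomputation RecomputeOuterQDQZeroPointAndScale
// relies on. Assumption (for illustration only): the merged outer pair should
// cover the intersection of the real-valued ranges representable by the two
// original (scale, zero_point) pairs. This is a sketch of the idea, not
// necessarily the exact formula used by FindNewZeroPointAndScale above.
#include <algorithm>
#include <cmath>
#include <limits>

template <typename ZeroPointT>
void MergeQuantParamsSketch(float scale1, ZeroPointT zp1,
                            float scale2, ZeroPointT zp2,
                            float& new_scale, ZeroPointT& new_zp) {
    constexpr float qmin = static_cast<float>(std::numeric_limits<ZeroPointT>::lowest());
    constexpr float qmax = static_cast<float>(std::numeric_limits<ZeroPointT>::max());

    // Real-valued range each pair can represent: real = (q - zero_point) * scale.
    const float min1 = (qmin - zp1) * scale1, max1 = (qmax - zp1) * scale1;
    const float min2 = (qmin - zp2) * scale2, max2 = (qmax - zp2) * scale2;

    // Only values that survive both quantizations are representable: take the overlap.
    // (No handling of degenerate/empty overlaps here; it's a sketch.)
    const float real_min = std::max(min1, min2);
    const float real_max = std::min(max1, max2);

    new_scale = (real_max - real_min) / (qmax - qmin);
    new_zp = static_cast<ZeroPointT>(std::round(qmin - real_min / new_scale));
}

// Example: two uint16 pairs with zero_point 0 and scales 1/256 and 1/512 merge
// into a pair with scale ~1/512, since the narrower second range dominates.
// -----------------------------------------------------------------------------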
+ + auto dq1_zp_type = dq1_zp_tensor_proto->data_type(); + + if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { + return RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, *dq2); + } + + if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) { + return RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, *dq2); + } + + if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) { + return RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, *dq2); + } + + if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) { + return RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, *dq2); + } + + return false; // Unsupported zero-point type +} + +Status DoubleQDQPairsRemover::ApplyImpl( + Graph& graph, + bool& modified, + int /*graph_level*/, + const logging::Logger& /*logger*/) const { + const GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (const auto& dq1_index : node_topology_list) { + NodeIndex q1_index = 0; + NodeIndex q2_index = 0; + NodeIndex dq2_index = 0; + if (IsReducibleDoubleQDQSequence(graph, q1_index, dq1_index, q2_index, dq2_index)) { + graph.RemoveEdge(q1_index, dq1_index, 0, 0); + graph.RemoveEdge(dq1_index, q2_index, 0, 0); + graph.RemoveEdge(q2_index, dq2_index, 0, 0); + graph_utils::ReplaceNodeInput(*graph.GetNode(dq2_index), 0, *graph.GetNode(dq1_index)->MutableInputDefs()[0]); + graph.AddEdge(q1_index, dq2_index, 0, 0); + graph.RemoveNode(q2_index); + graph.RemoveNode(dq1_index); + modified = true; + } + } + return Status::OK(); } + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.h b/onnxruntime/core/optimizer/double_qdq_pairs_remover.h index c016f7181b7fe..1833b007674fd 100644 --- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.h +++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.h @@ -3,19 +3,16 @@ #pragma once -#include "core/common/common.h" #include "core/optimizer/graph_transformer.h" -#include "core/optimizer/qdq_transformer/qdq_util.h" namespace onnxruntime { -using ONNX_NAMESPACE::TensorProto; -using ONNX_NAMESPACE::TensorProto_DataType; -using QDQ::InputIndex; - /** * @Class DoubleQDQPairsRemover * @brief Remove one pair of Q-DQ from Double Q-DQ pairs. + * Specifically, this transformer converts the sequence Q1 -> DQ1 -> Q2 -> DQ2, where the first pair has (zp1, scale1) + * and the second pair has (zp2, scale2), into the sequence Q1 -> DQ2 by removing the middle two nodes. The zero-point + * and scale of the final QDQ pair is recomputed to preserve equality to the original sequence. 
*/ class DoubleQDQPairsRemover : public GraphTransformer { public: @@ -27,28 +24,5 @@ class DoubleQDQPairsRemover : public GraphTransformer { bool& modified, int graph_level, const logging::Logger& logger) const override; - - static bool IsNodeRemovable( - Graph& graph, - const NodeIndex& self_index, - NodeIndex& parent_index, - NodeIndex& child_index, - NodeIndex& grandchild_index); - - template - static bool FindNewZeroPointAndScale( - const Graph& graph, - const Node& node1, - const Node& node2, - float& new_scale, - T& new_zero_point, - bool& skip_reset); - - template - static void ApplyNewInputValue( - Graph& graph, - Node& node, - const InputIndex& index, - T value); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index d7039cb4b7cfc..0e383c3031ca6 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include #include "core/mlas/inc/mlas.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" @@ -32,7 +33,8 @@ void SplitQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { // create rules for ops that don't change the data void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { // 3 nodes. DQ, target, Q. Merge into target and remove DQ and Q. - const std::string action_name{"drop"}; + const std::string drop_action_name{"drop"}; + const std::string drop_action_no_int16_name{"drop_no_int16_support"}; NTO::NodeLocation dq{NTO::NodeType::kInput, 0}; NTO::NodeLocation q{NTO::NodeType::kOutput, 0}; @@ -42,22 +44,33 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { MoveToSlot(dq, ArgType::kInput, 0, ArgType::kInput, 0), MoveToSlot(q, ArgType::kOutput, 0, ArgType::kOutput, 0)}; - std::unique_ptr action = std::make_unique(std::move(moves)); + std::unique_ptr drop_action_no_int16 = std::make_unique( + std::vector(moves)); // Copy before std::move(moves) + std::unique_ptr drop_action = std::make_unique(std::move(moves)); #if !defined(ORT_MINIMAL_BUILD) - std::unique_ptr selector = std::make_unique(); - qdq_selector_action_registry.RegisterSelectorAndAction(action_name, + // Use a separate selector + action that disallows 16-bit types for MaxPool and Resize. + // int16 MaxPool is not supported by the ONNX specification. + // int16 Resize is not supported by the ORT implementation (although allowed by ONNX). 
+ std::unique_ptr selector_disallow_16bit = std::make_unique(false); + qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name, + {{"MaxPool", {12}}, + {"Resize", {}}}, + std::move(selector_disallow_16bit), + std::move(drop_action_no_int16)); + + std::unique_ptr selector = std::make_unique(true); + qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_name, {{"Gather", {}}, {"Reshape", {}}, {"Transpose", {}}, - {"MaxPool", {12}}, - {"Resize", {}}, {"Squeeze", {}}, {"Unsqueeze", {}}}, std::move(selector), - std::move(action)); + std::move(drop_action)); #else - qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); + qdq_selector_action_registry.RegisterAction(drop_action_no_int16_name, std::move(drop_action_no_int16)); + qdq_selector_action_registry.RegisterAction(drop_action_name, std::move(drop_action)); #endif } @@ -74,6 +87,7 @@ void DropDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(std::move(moves)); #if !defined(ORT_MINIMAL_BUILD) + // TODO: Enable 16-bit types in selector when ArgMax supports 16-bit integer input tensors. std::unique_ptr selector = std::make_unique(); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"ArgMax", {}}}, @@ -91,6 +105,7 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(kMSDomain); #if !defined(ORT_MINIMAL_BUILD) + // TODO: Enable 16-bit types in selector when unary QLinear* ops support 16-bit. std::unique_ptr selector = std::make_unique(); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"AveragePool", {}}, @@ -112,6 +127,7 @@ void BinaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(kMSDomain); #if !defined(ORT_MINIMAL_BUILD) + // TODO: Enable 16-bit types in selector when binary QLinear* ops support 16-bit. std::unique_ptr selector = std::make_unique(); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Add", {}}, @@ -131,6 +147,7 @@ void VariadicOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(kMSDomain); #if !defined(ORT_MINIMAL_BUILD) + // TODO: Enable 16-bit types in selector when QLinearConcat supports 16-bit. std::unique_ptr selector = std::make_unique(); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, @@ -152,6 +169,7 @@ void ConvQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool is_ std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) + // TODO: Enable 16-bit types in selector when QLinearConv supports 16-bit. std::unique_ptr selector = std::make_unique(is_int8_allowed); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, @@ -174,6 +192,7 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) + // TODO: Enable 16-bit types in selector when QLinearMatMul and MatMulInteger support 16-bit. std::unique_ptr selector = std::make_unique(is_int8_allowed); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"MatMul", {}}}, @@ -195,6 +214,7 @@ void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) + // TODO: Enable 16-bit types in selector when QGemm supports 16-bit. 
std::unique_ptr selector = std::make_unique(); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Gemm", {}}}, @@ -215,6 +235,7 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) + // TODO: Enable 16-bit types in selector when QLinearWhere supports 16-bit. std::unique_ptr selector = std::make_unique(); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Where", {}}}, diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 02a7fb733813c..16c7bd5fce960 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -14,6 +14,12 @@ namespace onnxruntime { namespace QDQ { namespace { + +constexpr bool Is16BitIntType(int32_t data_type) { + return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16) || + (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16); +} + // adjust for an optional input/output that has an entry but does not exist int NumActualValues(const Node& node, bool input) { const auto& defs = input ? node.InputDefs() : node.OutputDefs(); @@ -110,6 +116,17 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + if (dt_input != dt_output) { + return false; + } + + if (!allow_16bit_ && Is16BitIntType(dt_input)) { + return false; + } + const Node& dq_node = *dq_nodes.front(); const Node& q_node = *q_nodes.front(); @@ -124,7 +141,7 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - int num_dq_inputs = NumActualValues(node, true); + constexpr int num_dq_inputs = 1; if (num_dq_inputs != gsl::narrow_cast(dq_nodes.size())) { return false; } @@ -136,6 +153,12 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, (void)q_nodes; const Node& dq_node = *dq_nodes.front(); + const int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + // 16-bit int types must be explicitly allowed. + if (!allow_16bit_ && Is16BitIntType(dt_input)) { + return false; + } auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { return graph_viewer.GetConstantInitializer(initializer_name, true); @@ -154,7 +177,16 @@ bool UnaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - return dt_input == dt_output; + if (dt_input != dt_output) { + return false; + } + + // 16-bit int types must be explicitly allowed. 
+ if (!allow_16bit_ && Is16BitIntType(dt_input)) { + return false; + } + + return true; } bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, @@ -168,8 +200,18 @@ bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, int32_t dt_input_1 = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); int32_t dt_input_2 = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - return dt_input_1 == dt_input_2 && - dt_input_1 == dt_output; + + // All input and output types must match. + if (dt_input_1 != dt_input_2 || dt_input_1 != dt_output) { + return false; + } + + // 16-bit int types must be explicitly allowed. + if (!allow_16bit_ && Is16BitIntType(dt_input_1)) { + return false; + } + + return true; } bool VariadicNodeGroupSelector::Check(const GraphViewer& graph_viewer, @@ -194,7 +236,17 @@ bool VariadicNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } } - return dt_input == dt_output; + + if (dt_input != dt_output) { + return false; + } + + // 16-bit int types must be explicitly allowed. + if (!allow_16bit_ && Is16BitIntType(dt_input)) { + return false; + } + + return true; } void InputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { @@ -227,12 +279,19 @@ bool ConvNodeGroupSelector::Check(const GraphViewer& graph_viewer, } } - if (dq_nodes.size() < 3) { // no bias - return true; + if (dq_nodes.size() == 3) { // has bias + int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) { + return false; + } } - int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - return dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32; + // 16-bit int types must be explicitly allowed. + if (!allow_16bit_ && (Is16BitIntType(dt_input) || Is16BitIntType(dt_weight))) { + return false; + } + + return true; } void ConvSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { @@ -256,6 +315,11 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, } } + // 16-bit int types must be explicitly allowed. + if (!allow_16bit_ && (Is16BitIntType(dt_input) || Is16BitIntType(dt_weight))) { + return false; + } + // potential match for QLinearMatMul or MatMulIntegerToFloat bool qlinear = !q_nodes.empty(); @@ -299,6 +363,11 @@ bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, } } + // 16-bit int types must be explicitly allowed. + if (!allow_16bit_ && (Is16BitIntType(dt_A) || Is16BitIntType(dt_B))) { + return false; + } + if (dq_nodes.size() < 3) { // no bias return true; } @@ -326,8 +395,18 @@ bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& const int32_t dt_input_1 = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); const int32_t dt_input_2 = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); const int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - return dt_input_1 == dt_input_2 && - dt_input_1 == dt_output; + + // All input and output types must match. + if (dt_input_1 != dt_input_2 || dt_input_1 != dt_output) { + return false; + } + + // 16-bit int types must be explicitly allowed. 
+ if (!allow_16bit_ && Is16BitIntType(dt_input_1)) { + return false; + } + + return true; } bool PadNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 58ebf81508962..d8fefdd8dc3d9 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -52,45 +52,75 @@ class NodeGroupSelector { // Single DQ -> node that does not change data -> Q. // Zero point and scale are constant scalars and must match class DropQDQNodeGroupSelector : public NodeGroupSelector { + public: + explicit DropQDQNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + + private: bool Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; + + bool allow_16bit_; }; // Single DQ -> node. class DropDQNodeGroupSelector : public NodeGroupSelector { + public: + explicit DropDQNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + + private: bool Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; + + bool allow_16bit_; }; // single input. default is to only support uint8. class UnaryNodeGroupSelector : public NodeGroupSelector { + public: + explicit UnaryNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + + private: bool Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; + + bool allow_16bit_; }; // 2 DQ nodes providing input -> node -> Q class BinaryNodeGroupSelector : public NodeGroupSelector { + public: + explicit BinaryNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + + private: bool Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; + + bool allow_16bit_; }; // Variadic DQ nodes -> node -> Q class VariadicNodeGroupSelector : public NodeGroupSelector { + public: + explicit VariadicNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + private: bool Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; + + bool allow_16bit_; }; // DQ nodes for X, W and optionally B -> node -> Q class ConvNodeGroupSelector : public NodeGroupSelector { public: // default to 'true' - ConvNodeGroupSelector(bool int8_allowed = true) : int8_allowed_(int8_allowed) {} + ConvNodeGroupSelector(bool int8_allowed = true, bool allow_16bit = true) + : int8_allowed_(int8_allowed), allow_16bit_(allow_16bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -98,16 +128,20 @@ class ConvNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool int8_allowed_; + bool allow_16bit_; }; class WhereNodeGroupSelector : public NodeGroupSelector { public: - WhereNodeGroupSelector() = default; + explicit WhereNodeGroupSelector(bool allow_16bit = true) + : allow_16bit_(allow_16bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; + + bool allow_16bit_; }; class PadNodeGroupSelector : public NodeGroupSelector { @@ 
-125,9 +159,11 @@ class PadNodeGroupSelector : public NodeGroupSelector { class MatMulNodeGroupSelector : public NodeGroupSelector { public: MatMulNodeGroupSelector(bool int8_allowed = true, - bool matmulintegertofloat_allowed = false) + bool matmulintegertofloat_allowed = false, + bool allow_16bit = true) : int8_allowed_(int8_allowed), - matmulintegertofloat_allowed_(matmulintegertofloat_allowed) { + matmulintegertofloat_allowed_(matmulintegertofloat_allowed), + allow_16bit_(allow_16bit) { } private: @@ -136,15 +172,21 @@ class MatMulNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool int8_allowed_; bool matmulintegertofloat_allowed_; + bool allow_16bit_; }; // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmNodeGroupSelector : public NodeGroupSelector { + public: + explicit GemmNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + private: bool Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; + + bool allow_16bit_; }; // Input: DQ nodes for input, scale, and B @@ -207,28 +249,33 @@ class BaseSelector : public NodeSelector { class DropQDQNodesSelector : public BaseSelector { public: - DropQDQNodesSelector() : BaseSelector(std::make_unique()) {} + explicit DropQDQNodesSelector(bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit)) {} }; class DropDQNodesSelector : public BaseSelector { public: - DropDQNodesSelector() : BaseSelector(std::make_unique()) {} + explicit DropDQNodesSelector(bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit)) {} }; class UnarySelector : public BaseSelector { public: - UnarySelector() : BaseSelector(std::make_unique()) {} + explicit UnarySelector(bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit)) {} }; class BinarySelector : public BaseSelector { public: - BinarySelector() : BaseSelector(std::make_unique()) {} + explicit BinarySelector(bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit)) {} }; // Variadic DQ nodes -> node -> Q class InputVariadicSelector : public BaseSelector { public: - InputVariadicSelector() : BaseSelector(std::make_unique()) {} + explicit InputVariadicSelector(bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; @@ -244,46 +291,36 @@ class OutputVariadicSelector : public BaseSelector { // DQ nodes for X, W and optionally B -> node -> Q class ConvSelector : public BaseSelector { public: - ConvSelector(bool int8_allowed = false) : BaseSelector(std::make_unique(int8_allowed)) {} + ConvSelector(bool int8_allowed = false, bool allow_16bit = false) + : BaseSelector(std::make_unique(int8_allowed, allow_16bit)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; + class WhereSelector : public BaseSelector { public: - WhereSelector() : BaseSelector(std::make_unique()) {} + explicit WhereSelector(bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit)) {} }; + // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not class MatMulSelector : public BaseSelector { public: - MatMulSelector(bool int8_allowed) - : BaseSelector(std::make_unique(int8_allowed, /*matmulintegertofloat_allowed*/ true)) {} + MatMulSelector(bool int8_allowed, bool allow_16bit = false) + : BaseSelector(std::make_unique(int8_allowed, 
/*matmulintegertofloat_allowed*/ true, + allow_16bit)) {} }; // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmSelector : public BaseSelector { public: - GemmSelector() - : BaseSelector(std::make_unique()) {} + explicit GemmSelector(bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; -// Input: DQ nodes for input, scale, and B (bias) -// Output: Q node for output -class InstanceNormalizationSelector : public BaseSelector { - public: - InstanceNormalizationSelector() - : BaseSelector(std::make_unique()) {} -}; - -// DQ nodes for X, W and optionally B, (mean, var not required) -> node -> Q -class BatchNormalizationSelector : public BaseSelector { - public: - BatchNormalizationSelector(bool int8_allowed = false) - : BaseSelector(std::make_unique(int8_allowed)) {} -}; - } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 3723ee6032582..2c11bf144999e 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -1195,7 +1195,7 @@ bool TransposeQuantizeDequantizeAxis(const api::GraphRef& graph, const std::vect static bool HandleQuantizeDequantizeAxis(const api::GraphRef& graph, const std::vector& perm, api::NodeRef& node, int64_t opset) { if (opset < 13) { - // no `axis` value until opset 13 + // no `axis` attribute until opset 13 return true; } diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index 67a9a5991939a..a0d75e8cc0e69 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -5,13 +5,47 @@ #include "core/framework/element_type_lists.h" #include "core/framework/float8.h" #include "core/framework/float16.h" -#include "core/providers/cpu/quantization/quantize_linear.h" +#include "core/framework/op_kernel.h" #include "core/providers/common.h" #include "core/mlas/inc/mlas.h" #include "core/util/qmath.h" namespace onnxruntime { +template +class DequantizeLinear final : public OpKernel { + public: + explicit DequantizeLinear(const OpKernelInfo& info) : OpKernel(info) { + if (!info.GetAttr("axis", &axis_).IsOK()) { + axis_ = 1; + } + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t axis_; +}; + +template +class QuantizeLinear final : public OpKernel { + public: + explicit QuantizeLinear(const OpKernelInfo& info) : OpKernel(info) { + if (!info.GetAttr("axis", &axis_).IsOK()) { + axis_ = 1; + } + if (!info.GetAttr("saturate", &saturate_).IsOK()) { + saturate_ = 1; + } + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t axis_; + int64_t saturate_; +}; + static void PrepareForQDQ(const TensorShape& input_shape, const Tensor& scale, const Tensor* zero_point_ptr, @@ -86,6 +120,59 @@ REGISTER_DEQUANTIZELINEAR_VERSIONED(int8_t) REGISTER_DEQUANTIZELINEAR_VERSIONED(uint8_t) REGISTER_DEQUANTIZELINEAR_VERSIONED(int32_t) +#if !defined(DISABLE_CONTRIB_OPS) +namespace contrib { + +// Register alternate MS domain versions of the DequantizeLinear kernel. +// The MS domain versions additionally support 16-bit integer quantization types. 
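// --- Illustrative aside (not part of this diff) -----------------------------
// Scalar sketch of the per-element math behind the 16-bit kernels registered
// below: quantize = saturate(round_half_to_even(x / scale) + zero_point) and
// dequantize = (q - zero_point) * scale. Illustration only; the registered
// kernels dispatch through MLAS and handle per-axis scales and blocking.
#include <algorithm>
#include <cmath>
#include <cstdint>

inline uint16_t QuantizeU16Sketch(float x, float scale, uint16_t zero_point) {
    // std::nearbyint uses the current rounding mode (round-to-nearest-even by default).
    const float q = std::nearbyint(x / scale) + static_cast<float>(zero_point);
    return static_cast<uint16_t>(std::clamp(q, 0.0f, 65535.0f));
}

inline float DequantizeU16Sketch(uint16_t q, float scale, uint16_t zero_point) {
    return (static_cast<float>(q) - static_cast<float>(zero_point)) * scale;
}

// Example: with scale = 0.5 and zero_point = 32768, x = 1.25 quantizes to
// saturate(round(2.5) + 32768) = 32770 (2.5 rounds to 2, ties to even) and
// dequantizes back to 1.0.
// -----------------------------------------------------------------------------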
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + DequantizeLinear, + 1, + uint8_t, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + DequantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + DequantizeLinear, + 1, + int8_t, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + DequantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + DequantizeLinear, + 1, + uint16_t, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + DequantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + DequantizeLinear, + 1, + int16_t, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + DequantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + DequantizeLinear, + 1, + int32_t, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + DequantizeLinear); + +} // namespace contrib +#endif // !defined(DISABLE_CONTRIB_OPS) + template struct DequantizeLinearApply { void op(int64_t N, int64_t broadcast_dim, int64_t block_size, const T* input, const OutT* scale, OutT* output, const T* zero_point) { @@ -220,6 +307,49 @@ REGISTER_QUANTIZELINEAR(Float8E5M2FNUZ) REGISTER_QUANTIZELINEAR_VERSIONED(int8_t) REGISTER_QUANTIZELINEAR_VERSIONED(uint8_t) +#if !defined(DISABLE_CONTRIB_OPS) +namespace contrib { + +// Register alternate MS domain versions of the QuantizeLinear kernel. +// The MS domain versions additionally support 16-bit integer quantization types. +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + QuantizeLinear, + 1, + uint8_t, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QuantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + QuantizeLinear, + 1, + int8_t, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QuantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + QuantizeLinear, + 1, + uint16_t, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QuantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + QuantizeLinear, + 1, + int16_t, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QuantizeLinear); +} // namespace contrib +#endif // !defined(DISABLE_CONTRIB_OPS) + template void ParQuantizeLinear(const InputType* Input, OutputType* Output, @@ -279,5 +409,4 @@ Status QuantizeLinear::Compute(OpKernelContext* ctx) const { return Status::OK(); } - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.h b/onnxruntime/core/providers/cpu/quantization/quantize_linear.h deleted file mode 100644 index 60e9d09665ab2..0000000000000 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/common/common.h" -#include "core/framework/op_kernel.h" -#include "core/util/math_cpuonly.h" - -namespace onnxruntime { - -template -class DequantizeLinear final : public OpKernel { - public: - DequantizeLinear(const OpKernelInfo& info) : OpKernel(info) { - if (!info.GetAttr("axis", &axis_).IsOK()) { - axis_ = 1; - } - } - - Status Compute(OpKernelContext* context) const override; - - private: - int64_t axis_; -}; - -template -class QuantizeLinear final : public OpKernel { - public: - QuantizeLinear(const OpKernelInfo& info) : OpKernel(info) { - if (!info.GetAttr("axis", &axis_).IsOK()) { - axis_ = 1; - } - if (!info.GetAttr("saturate", &saturate_).IsOK()) { - saturate_ = 1; - } - } - - Status Compute(OpKernelContext* context) const override; - - private: - int64_t axis_; - int64_t saturate_; -}; -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 9dccd7c47fbb6..0674fe02d093d 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -318,7 +318,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad); @@ -327,6 +326,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -580,15 +584,17 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/binary.cc b/onnxruntime/core/providers/js/operators/binary.cc index 98f7ca6e613b0..e61cb1094736d 100644 --- a/onnxruntime/core/providers/js/operators/binary.cc +++ b/onnxruntime/core/providers/js/operators/binary.cc @@ -6,14 +6,13 @@ namespace onnxruntime { namespace js { -#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ - ONNX_OPERATOR_KERNEL_EX( \ - OP_TYPE, \ - kOnnxDomain, \ - VERSION, \ - kJsExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ - 
DataTypeImpl::GetTensorType()}), \ +#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()), \ KERNEL_CLASS); #define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ @@ -22,8 +21,7 @@ namespace js { kOnnxDomain, \ VERSION_FROM, VERSION_TO, \ kJsExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType()}), \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()), \ KERNEL_CLASS); JSEP_KERNEL_IMPL(Add, Add) diff --git a/onnxruntime/core/providers/js/operators/if.cc b/onnxruntime/core/providers/js/operators/if.cc new file mode 100644 index 0000000000000..ef072bb1635dd --- /dev/null +++ b/onnxruntime/core/providers/js/operators/if.cc @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "if.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); +// output shape rules requiring the output shapes of the 'THEN' and 'ELSE' +// branches to be the same were relaxed in opset-11 +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 11, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +// opset-13 supports sequence type for If's subgraph outputs +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 13, 18, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // Support sequence/optional tensors when all JSEP infra + // (including tests runner) supports it + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +// opset-19 supports float8 +ONNX_OPERATOR_KERNEL_EX(If, + kOnnxDomain, + 19, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // Support sequence/optional tensors when all JSEP infra + // (including tests runner) supports it + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +Status If::Compute(OpKernelContext* ctx) const { + // call the base CPU version. + return onnxruntime::If::Compute(ctx); +} + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/if.h b/onnxruntime/core/providers/js/operators/if.h new file mode 100644 index 0000000000000..d060444ccc1d2 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/if.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include + +#include "core/providers/js/js_kernel.h" +#include "core/common/common.h" +#include "core/providers/cpu/controlflow/if.h" + +namespace onnxruntime { +class SessionState; + +namespace js { + +// Use the CPU implementation for the logic +class If final : public onnxruntime::If { + public: + If(const OpKernelInfo& info) : onnxruntime::If(info) {} + + Status Compute(OpKernelContext* ctx) const override; +}; +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index 5e972e43e4566..e9bbfabcf86bd 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -6,22 +6,29 @@ namespace onnxruntime { namespace js { -#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ +#define JSEP_ELEMENTWISE_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ ONNX_OPERATOR_KERNEL_EX( \ OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ KERNEL_CLASS); -#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ - OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), \ + KERNEL_CLASS); + +#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), \ KERNEL_CLASS); #define JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ ONNX_OPERATOR_KERNEL_EX( \ OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \ KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType()}), \ KERNEL_CLASS); @@ -29,6 +36,7 @@ namespace js { ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \ KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType()}), \ KERNEL_CLASS); // math @@ -42,115 +50,115 @@ JSEP_ELEMENTWISE_MULTI_TYPED_VERSIONED_KERNEL(Neg, 6, 12, Neg) JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(Neg, 13, Neg) JSEP_KERNEL_IMPL(Floor, Floor) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, float, Floor) -JSEP_ELEMENTWISE_KERNEL(Floor, 13, float, Floor) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, Floor) +JSEP_ELEMENTWISE_KERNEL(Floor, 13, Floor) JSEP_KERNEL_IMPL(Ceil, Ceil) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, float, Ceil) -JSEP_ELEMENTWISE_KERNEL(Ceil, 13, float, Ceil) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, Ceil) +JSEP_ELEMENTWISE_KERNEL(Ceil, 13, Ceil) JSEP_KERNEL_IMPL(Reciprocal, Reciprocal) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, float, Reciprocal) -JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, float, Reciprocal) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, Reciprocal) +JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, Reciprocal) JSEP_KERNEL_IMPL(Sqrt, Sqrt) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 
12, float, Sqrt) -JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, float, Sqrt) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, Sqrt) +JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, Sqrt) JSEP_KERNEL_IMPL(Exp, Exp) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, float, Exp) -JSEP_ELEMENTWISE_KERNEL(Exp, 13, float, Exp) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, Exp) +JSEP_ELEMENTWISE_KERNEL(Exp, 13, Exp) JSEP_KERNEL_IMPL(Erf, Erf) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, float, Erf) -JSEP_ELEMENTWISE_KERNEL(Erf, 13, float, Erf) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, Erf) +JSEP_ELEMENTWISE_KERNEL(Erf, 13, Erf) JSEP_KERNEL_IMPL(Sigmoid, Sigmoid) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, float, Sigmoid) -JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, float, Sigmoid) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid) +JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid) JSEP_KERNEL_IMPL(Log, Log) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, float, Log) -JSEP_ELEMENTWISE_KERNEL(Log, 13, float, Log) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log) +JSEP_ELEMENTWISE_KERNEL(Log, 13, Log) JSEP_KERNEL_IMPL(Sin, Sin) -JSEP_ELEMENTWISE_KERNEL(Sin, 7, float, Sin) +JSEP_ELEMENTWISE_KERNEL(Sin, 7, Sin) JSEP_KERNEL_IMPL(Cos, Cos) -JSEP_ELEMENTWISE_KERNEL(Cos, 7, float, Cos) +JSEP_ELEMENTWISE_KERNEL(Cos, 7, Cos) JSEP_KERNEL_IMPL(Tan, Tan) -JSEP_ELEMENTWISE_KERNEL(Tan, 7, float, Tan) +JSEP_ELEMENTWISE_KERNEL(Tan, 7, Tan) JSEP_KERNEL_IMPL(Asin, Asin) -JSEP_ELEMENTWISE_KERNEL(Asin, 7, float, Asin) +JSEP_ELEMENTWISE_KERNEL(Asin, 7, Asin) JSEP_KERNEL_IMPL(Acos, Acos) -JSEP_ELEMENTWISE_KERNEL(Acos, 7, float, Acos) +JSEP_ELEMENTWISE_KERNEL(Acos, 7, Acos) JSEP_KERNEL_IMPL(Atan, Atan) -JSEP_ELEMENTWISE_KERNEL(Atan, 7, float, Atan) +JSEP_ELEMENTWISE_KERNEL(Atan, 7, Atan) JSEP_KERNEL_IMPL(Sinh, Sinh) -JSEP_ELEMENTWISE_KERNEL(Sinh, 9, float, Sinh) +JSEP_ELEMENTWISE_KERNEL(Sinh, 9, Sinh) JSEP_KERNEL_IMPL(Cosh, Cosh) -JSEP_ELEMENTWISE_KERNEL(Cosh, 9, float, Cosh) +JSEP_ELEMENTWISE_KERNEL(Cosh, 9, Cosh) JSEP_KERNEL_IMPL(Asinh, Asinh) -JSEP_ELEMENTWISE_KERNEL(Asinh, 9, float, Asinh) +JSEP_ELEMENTWISE_KERNEL(Asinh, 9, Asinh) JSEP_KERNEL_IMPL(Acosh, Acosh) -JSEP_ELEMENTWISE_KERNEL(Acosh, 9, float, Acosh) +JSEP_ELEMENTWISE_KERNEL(Acosh, 9, Acosh) JSEP_KERNEL_IMPL(Atanh, Atanh) -JSEP_ELEMENTWISE_KERNEL(Atanh, 9, float, Atanh) +JSEP_ELEMENTWISE_KERNEL(Atanh, 9, Atanh) JSEP_KERNEL_IMPL(Tanh, Tanh) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, float, Tanh) -JSEP_ELEMENTWISE_KERNEL(Tanh, 13, float, Tanh) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, Tanh) +JSEP_ELEMENTWISE_KERNEL(Tanh, 13, Tanh) JSEP_KERNEL_IMPL(Not, Not) -JSEP_ELEMENTWISE_KERNEL(Not, 1, bool, Not) +JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not) // activation JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, float, ClipV10) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10) JSEP_KERNEL_IMPL(Clip, Clip) ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2), Clip); ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2), Clip); 
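// --- Illustrative aside (not part of this diff) -----------------------------
// These registrations replace the single-type constraint
// DataTypeImpl::GetTensorType<float>() with JsepSupportedFloatTypes(), so one
// macro covers every float type the JS EP accepts. The sketch below is an
// assumption about what such a helper plausibly returns (float plus MLFloat16);
// it is a hypothetical stand-in, not ORT's actual definition.
#include <vector>
#include "core/framework/data_types.h"

namespace onnxruntime {
namespace js {

inline const std::vector<MLDataType>& JsepSupportedFloatTypesSketch() {
  static const std::vector<MLDataType> types{
      DataTypeImpl::GetTensorType<float>(),
      DataTypeImpl::GetTensorType<MLFloat16>(),  // assumption: fp16 is included
  };
  return types;
}

}  // namespace js
}  // namespace onnxruntime
// -----------------------------------------------------------------------------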
ONNX_OPERATOR_KERNEL_EX(Clip, kOnnxDomain, 13, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2), Clip); JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(Elu, Elu, alpha, 1.0) -JSEP_ELEMENTWISE_KERNEL(Elu, 6, float, Elu) +JSEP_ELEMENTWISE_KERNEL(Elu, 6, Elu) JSEP_KERNEL_IMPL(Relu, Relu) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, float, Relu) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, float, Relu) -JSEP_ELEMENTWISE_KERNEL(Relu, 14, float, Relu) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, Relu) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, Relu) +JSEP_ELEMENTWISE_KERNEL(Relu, 14, Relu) JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(LeakyRelu, LeakyRelu, alpha, 0.01) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, float, LeakyRelu) -JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, float, LeakyRelu) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, LeakyRelu) +JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, LeakyRelu) JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(ThresholdedRelu, ThresholdedRelu, alpha, 1.0) -JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, float, ThresholdedRelu) +JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, ThresholdedRelu) } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 556a86bb1519b..8081033c35618 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -30,6 +30,12 @@ class SimpleOpBuilder : public BaseOpBuilder { private: Status ExplicitOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const; + Status ProcessSigmoidOrTanhOutput(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + std::vector&& param_tensor_names, + const logging::Logger& logger, + bool do_op_validation) const ORT_MUST_USE_RESULT; static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest"}; static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; @@ -279,10 +285,120 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w ORT_RETURN_IF_ERROR(ProcessGridSampleAttributes(qnn_model_wrapper, node_unit, param_tensor_names)); } - ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, - std::move(input_names), - std::move(param_tensor_names), - logger, do_op_validation, GetQnnOpType(op_type))); + if (op_type == "Sigmoid" || op_type == "Tanh") { + // QNN requires 16-bit QDQ Sigmoid and Tanh to use specific output scale and zero-point values + // regardless of floating-point range. + return ProcessSigmoidOrTanhOutput(qnn_model_wrapper, + node_unit, + std::move(input_names), + std::move(param_tensor_names), + logger, do_op_validation); + } + + return ProcessOutputs(qnn_model_wrapper, node_unit, + std::move(input_names), + std::move(param_tensor_names), + logger, do_op_validation, GetQnnOpType(op_type)); +} + +/** + * Overrides offset and scale quantization parameters for operators (e.g., Sigmoid or Tanh) that require + * specific values. Returns true if the quantization parameters were overridden. + * + * \param op_type The ONNX operator type. + * \param qnn_data_type The QNN tensor data type. 
+ * \param quant_params Output scale/offset parameter that may be overridden. + * \return True if the offset and scale were overridden. + */ +static bool OverrideQuantParams(const std::string& op_type, Qnn_DataType_t qnn_data_type, + Qnn_ScaleOffset_t& quant_params) { + const int32_t orig_offset = quant_params.offset; + const float orig_scale = quant_params.scale; + + if (op_type == "Sigmoid") { + switch (qnn_data_type) { + case QNN_DATATYPE_UFIXED_POINT_16: + quant_params.offset = 0; + quant_params.scale = 1.0f / 65536.0f; + break; + case QNN_DATATYPE_SFIXED_POINT_16: + quant_params.offset = 0; + quant_params.scale = 1.0f / 32768.0f; + break; + default: + break; // Do nothing. + } + } + + if (op_type == "Tanh") { + switch (qnn_data_type) { + case QNN_DATATYPE_UFIXED_POINT_16: + quant_params.offset = -32768; + quant_params.scale = 1.0f / 32768.0f; + break; + case QNN_DATATYPE_SFIXED_POINT_16: + quant_params.offset = 0; + quant_params.scale = 1.0f / 32768.0f; + break; + default: + break; // Do nothing. + } + } + + return quant_params.offset != orig_offset || quant_params.scale != orig_scale; +} + +/** + * Processes the output for Sigmoid or Tanh operators and creates the corresponding QNN operator. + * These operator types are handled separately because QNN requires 16-bit QDQ Sigmoid and Tanh operators to use + * specific scale and zero-point values regardless of floating-point range. + * + * \param qnn_model_wrapper The QNN model wrapper object. + * \param node_unit The QDQ node unit for the Sigmoid or Tanh node. + * \param input_names List of input names. + * \param param_tensor_names List of param tensor names. + * \param logger Logger used to report information. + * \param do_op_validation True if the new QNN node should be validated. + */ +Status SimpleOpBuilder::ProcessSigmoidOrTanhOutput(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + std::vector&& param_tensor_names, + const logging::Logger& logger, + bool do_op_validation) const { + const std::string& op_type = node_unit.OpType(); + const auto& output = node_unit.Outputs()[0]; + const std::string& output_name = output.node_arg.Name(); + + OnnxInputInfo output_info = {}; + + // TODO(adrianlizarraga): Rename GetOnnxInputInfo() since it can be used for outputs as well. + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(output, output_info)); + + if (output_info.quant_param.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { + if (OverrideQuantParams(op_type, output_info.qnn_data_type, output_info.quant_param.scaleOffsetEncoding)) { + const int32_t offset = output_info.quant_param.scaleOffsetEncoding.offset; + const float scale = output_info.quant_param.scaleOffsetEncoding.scale; + + LOGS(logger, VERBOSE) << "QNN requires that 16-bit quantized " << op_type << " operators use offset/scale values " + << "of <" << offset << ", " << scale << ">. QNN EP will override the original values."; + } + } + + Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? 
QNN_TENSOR_TYPE_APP_READ + : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, output_info.qnn_data_type, output_info.quant_param, + std::move(output_info.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit), + QNN_OP_PACKAGE_NAME_QTI_AISW, + GetQnnOpType(op_type), + std::move(input_names), + {output_name}, + std::move(param_tensor_names), + do_op_validation), + "Failed to add node."); + return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index eebe75d839b12..9d339387b0a43 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -301,6 +301,16 @@ bool QnnModelWrapper::ProcessOffset(const std::string& offset_name, offset_value = 0 - (uint8_span.data()[0]); break; } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + auto uint16_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); + offset_value = -static_cast(uint16_span.data()[0]); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + auto int16_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); + offset_value = -static_cast(int16_span.data()[0]); + break; + } case ONNX_NAMESPACE::TensorProto_DataType_INT32: { auto int32_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); offset_value = -(int32_span.data()[0]); diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 16d458a3401b6..43f6f6360633c 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -420,8 +420,10 @@ def __init__( except (ValueError, RuntimeError) as e: if self._enable_fallback: try: + print("*************** EP Error ***************") print(f"EP Error {e} when using {providers}") print(f"Falling back to {self._fallback_providers} and retrying.") + print("****************************************") self._create_inference_session(self._fallback_providers, None) # Fallback only once. self.disable_fallback() @@ -434,11 +436,26 @@ def __init__( def _create_inference_session(self, providers, provider_options, disabled_optimizers=None): available_providers = C.get_available_providers() - # Tensorrt can fall back to CUDA. All others fall back to CPU. + # Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU. if "TensorrtExecutionProvider" in available_providers: - self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + if any( + provider == "CUDAExecutionProvider" + or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider") + for provider in providers + ): + self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + else: + self._fallback_providers = ["CPUExecutionProvider"] + # MIGraphX can fall back to ROCM if it's explicitly assigned. All others fall back to CPU. 
elif "MIGraphXExecutionProvider" in available_providers: - self._fallback_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] + if any( + provider == "ROCMExecutionProvider" + or (isinstance(provider, tuple) and provider[0] == "ROCMExecutionProvider") + for provider in providers + ): + self._fallback_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] + else: + self._fallback_providers = ["CPUExecutionProvider"] else: self._fallback_providers = ["CPUExecutionProvider"] diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 924d4c72b6390..2d1e418f9d2b4 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -104,7 +104,7 @@ def __init__( ) self.q_matmul_const_b_only = "MatMulConstBOnly" in self.extra_options and self.extra_options["MatMulConstBOnly"] self.is_weight_symmetric = ( - weight_qType in (QuantType.QInt8, QuantType.QFLOAT8E4M3FN) + weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN) if "WeightSymmetric" not in self.extra_options else self.extra_options["WeightSymmetric"] ) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index f87a9d8228bac..e595b580b20df 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -25,6 +25,7 @@ add_quant_output_suffix, add_quant_suffix, find_by_name, + ms_domain, ) from .registry import CreateQDQQuantizer @@ -119,6 +120,20 @@ def __init__( else extra_options["QDQOpTypePerChannelSupportToAxis"] ) + self.qdq_op_domain = ms_domain if extra_options.get("UseQDQContribOps", False) else None + + # The ONNX spec does not yet support 16-bit Q/DQ ops. So, must override the Q/DQ op domain to 'com.microsoft' + # if the activation or weight types are 16-bit integers. + # TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support. + int16_types = (TensorProto.UINT16, TensorProto.INT16) + if not self.qdq_op_domain and (self.activation_qType in int16_types or self.weight_qType in int16_types): + logging.warning( + "ONNX QuantizeLinear and DequantizeLinear operators do not support 16-bit integer quantization types. " + f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to " + "enable support." 
+ ) + self.qdq_op_domain = ms_domain + def _is_tensor_quantizable(self, tensor_name): """ Check if tensor can be quantized @@ -249,6 +264,7 @@ def _create_qdq_nodes( [q_output], quant_node_name, axis=axis, + domain=self.qdq_op_domain, ) dequant_node = onnx.helper.make_node( DEQUANT_OP_NAME, @@ -256,6 +272,7 @@ def _create_qdq_nodes( [dq_output], dequant_node_name, axis=axis, + domain=self.qdq_op_domain, ) self.model.add_nodes([qlinear_node, dequant_node]) @@ -300,6 +317,7 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): [weight_dequant_output], add_dequant_suffix(weight_name), axis=axis, + domain=self.qdq_op_domain, ) self.model.add_node(dequant_node) @@ -443,6 +461,7 @@ def _quantize_bias_tensors(self): [bias_name], node_name, axis=quant_value.axis, + domain=self.qdq_op_domain, ) else: dequant_node = onnx.helper.make_node( @@ -450,6 +469,7 @@ def _quantize_bias_tensors(self): inputs, [bias_name], node_name, + domain=self.qdq_op_domain, ) else: raise RuntimeError(f"Unexpected operator type {quant_value.node_type!r}.") diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 4d5bcca29618f..74e54c3f1fa37 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -72,6 +72,8 @@ class QuantType(Enum): QInt8 = 0 QUInt8 = 1 QFLOAT8E4M3FN = 2 + QInt16 = 3 + QUInt16 = 4 def __str__(self): return self.name @@ -89,6 +91,10 @@ def tensor_type(self): return TensorProto.INT8 if self == QuantType.QUInt8: return TensorProto.UINT8 + if self == QuantType.QUInt16: + return TensorProto.UINT16 + if self == QuantType.QInt16: + return TensorProto.INT16 if self == QuantType.QFLOAT8E4M3FN: return TensorProto.FLOAT8E4M3FN raise ValueError(f"Unexpected value qtype={self!r}.") @@ -112,12 +118,35 @@ def from_string(format): ONNX_TYPE_TO_NP_TYPE = { onnx_proto.TensorProto.INT8: numpy.dtype("int8"), onnx_proto.TensorProto.UINT8: numpy.dtype("uint8"), + onnx_proto.TensorProto.INT16: numpy.dtype("int16"), + onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"), onnx_proto.TensorProto.FLOAT8E4M3FN: float8e4m3fn, } +ONNX_INT_TYPE_RANGE = { + onnx_proto.TensorProto.UINT8: (0, 255), + onnx_proto.TensorProto.INT8: (-128, 127), + onnx_proto.TensorProto.UINT16: (0, 65535), + onnx_proto.TensorProto.INT16: (-32768, 32767), +} + +ONNX_INT_TYPE_SYMMETRIC_RANGE = { + onnx_proto.TensorProto.INT8: (-127, 127), + onnx_proto.TensorProto.INT16: (-32767, 32767), +} + +ONNX_INT_TYPE_REDUCED_RANGE = { + onnx_proto.TensorProto.UINT8: (0, 127), + onnx_proto.TensorProto.INT8: (-64, 64), + onnx_proto.TensorProto.UINT16: (0, 32767), + onnx_proto.TensorProto.INT16: (-16384, 16384), +} + def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): - assert qType in ONNX_TYPE_TO_NP_TYPE, f"Unexpected data type {qType} requested. Only INT8 and UINT8 are supported." + assert ( + qType in ONNX_TYPE_TO_NP_TYPE + ), f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported." 
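The range tables above feed both scale/zero-point selection and the clipping bounds that quantize_nparray now derives through get_qmin_qmax_for_qType instead of hard-coded 8-bit limits. A rough NumPy sketch of asymmetric uint16 quantization with the new [0, 65535] range (an illustration of the formula only, not a call into the quantization tool's internals; the tool's compute_scale_zp additionally handles symmetric and reduced ranges):

```python
# Rough sketch of asymmetric uint16 quantization using the [0, 65535] range added above.
# Plain NumPy; values and variable names are illustrative.
import numpy as np

qmin, qmax = 0, 65535                      # ONNX_INT_TYPE_RANGE[TensorProto.UINT16]
data = np.array([-1.5, 0.0, 0.25, 3.0], dtype=np.float32)

rmin, rmax = min(data.min(), 0.0), max(data.max(), 0.0)  # range must include zero
scale = (rmax - rmin) / (qmax - qmin)
zero_point = int(round(qmin - rmin / scale))

# round-to-nearest-even, shift by the zero point, then saturate to the integer range
quantized = np.clip(np.round(data / scale) + zero_point, qmin, qmax).astype(np.uint16)
dequantized = (quantized.astype(np.float32) - zero_point) * scale
print(scale, zero_point, quantized, dequantized)
```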
if qType in ( onnx_proto.TensorProto.FLOAT8E4M3FN, onnx_proto.TensorProto.FLOAT8E4M3FNUZ, @@ -146,8 +175,10 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): return ref.run(None, {"X": arr.astype(numpy.float32), "scale": scale.astype(numpy.float32)})[0] else: dtype = ONNX_TYPE_TO_NP_TYPE[qType] - cliplow = max(0 if dtype == numpy.uint8 else -127, -127 if low is None else low) - cliphigh = min(255 if dtype == numpy.uint8 else 127, 255 if high is None else high) + (qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True) + + cliplow = max(qmin, low) if low is not None else qmin + cliphigh = min(qmax, high) if high is not None else qmax arr_fp32 = numpy.asarray((arr.astype(numpy.float32) / scale).round() + zero_point) numpy.clip(arr_fp32, cliplow, cliphigh, out=arr_fp32) return arr_fp32.astype(dtype) @@ -267,7 +298,7 @@ def quantize_data(data, qType, symmetric, reduce_range=False): ) return rmin, rmax, zero_point, scale, quantized_data - if qType in (TensorProto.INT8, TensorProto.UINT8): + if qType in (TensorProto.INT8, TensorProto.UINT8, TensorProto.INT16, TensorProto.UINT16): if len(data): qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric) zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric) @@ -283,18 +314,22 @@ def get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False): # noqa :parameter qType: onnx.onnx_pb.TensorProto.UINT8 or onnx.onnx_pb.TensorProto.UINT8 :return: qmin, qmax """ - if qType == onnx_proto.TensorProto.UINT8: - (qmin, qmax) = (0, 127) if reduce_range else (0, 255) - elif qType == onnx_proto.TensorProto.INT8: - if symmetric: - (qmin, qmax) = (-64, 64) if reduce_range else (-127, 127) - else: - (qmin, qmax) = (-64, 64) if reduce_range else (-128, 127) - elif qType == onnx_proto.TensorProto.FLOAT8E4M3FN: + if qType == onnx_proto.TensorProto.FLOAT8E4M3FN: raise NotImplementedError("This function is not implemented for float 8 as not needed.") + + qrange = None + + if reduce_range: + qrange = ONNX_INT_TYPE_REDUCED_RANGE.get(qType) + elif symmetric and qType in ONNX_INT_TYPE_SYMMETRIC_RANGE: + qrange = ONNX_INT_TYPE_SYMMETRIC_RANGE[qType] else: - raise ValueError(f"Unexpected data type {qType} requested. Only INT8 and UINT8 are supported.") - return qmin, qmax + qrange = ONNX_INT_TYPE_RANGE.get(qType) + + if not qrange: + raise ValueError(f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported.") + + return qrange def get_qrange_for_qType(qType, reduce_range=False, symmetric=False): # noqa: N802 diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 6b1646aec9679..706047fe32400 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -240,6 +240,11 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua f"weight_type={weight_type}!=QuantType.QFLOAT8E4M3FN" ) + q16_types = [QuantType.QInt16, QuantType.QUInt16] + + if (activation_type in q16_types or weight_type in q16_types) and quant_format != QuantFormat.QDQ: + raise ValueError("Only QuantFormat.QDQ supports 16-bit quantization types.") + if activation_type == QuantType.QInt8 and weight_type == QuantType.QInt8 and quant_format != QuantFormat.QDQ: logging.warning( "Please use QuantFormat.QDQ for activation type QInt8 and weight type QInt8. 
" @@ -356,6 +361,11 @@ def quantize_static( SmoothQuantFolding = True/False : Default is True. It only works if SmoothQuant is True. If enabled, inserted Mul ops during SmoothQuant will be folded into the previous op if the previous op is foldable. + UseQDQContribOps = True/False : + Default is False. If enabled, the inserted QuantizeLinear and DequantizeLinear ops will have the + `com.microsoft` domain, which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear + contrib op implementations. The contrib op implementations may support features not standardized + into the ONNX specification (e.g., 16-bit quantization types). """ if activation_type == QuantType.QFLOAT8E4M3FN or weight_type == QuantType.QFLOAT8E4M3FN: if calibrate_method != CalibrationMethod.Distribution: diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index f4d3f2fa1c317..67d3c95922a87 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -8,7 +8,10 @@ import logging import os import random +import sys +import time import timeit +from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from datetime import datetime from enum import Enum @@ -439,68 +442,127 @@ def get_gpu_info() -> Optional[List[Dict[str, Any]]]: return None -def measure_memory(is_gpu, func): - class MemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring +class MemoryMonitor(ABC): + def __init__(self, keep_measuring=True): + self.keep_measuring = keep_measuring - def measure_cpu_usage(self): - import psutil + def measure_cpu_usage(self): + import psutil - max_usage = 0 + max_usage = 0 + while True: + max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2) + sleep(0.005) # 5ms + if not self.keep_measuring: + break + return max_usage + + @abstractmethod + def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: + raise NotImplementedError() + + +class CudaMemoryMonitor(MemoryMonitor): + def __init__(self, keep_measuring=True): + super().__init__(keep_measuring) + + def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: + from py3nvml.py3nvml import ( + NVMLError, + nvmlDeviceGetCount, + nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, + nvmlDeviceGetName, + nvmlInit, + nvmlShutdown, + ) + + max_gpu_usage = [] + gpu_name = [] + try: + nvmlInit() + device_count = nvmlDeviceGetCount() + if not isinstance(device_count, int): + logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}") + return None + + max_gpu_usage = [0 for i in range(device_count)] + gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)] while True: - max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2) + for i in range(device_count): + info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)) + if isinstance(info, str): + logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}") + return None + max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2) sleep(0.005) # 5ms if not self.keep_measuring: break - return max_usage - - def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: - from py3nvml.py3nvml import ( - NVMLError, - nvmlDeviceGetCount, - nvmlDeviceGetHandleByIndex, - nvmlDeviceGetMemoryInfo, - nvmlDeviceGetName, - nvmlInit, - nvmlShutdown, - ) + nvmlShutdown() + return [ + { + "device_id": i, + 
"name": gpu_name[i], + "max_used_MB": max_gpu_usage[i], + } + for i in range(device_count) + ] + except NVMLError as error: + logger.error("Error fetching GPU information using nvml: %s", error) + return None - max_gpu_usage = [] - gpu_name = [] - try: - nvmlInit() - device_count = nvmlDeviceGetCount() - if not isinstance(device_count, int): - logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}") - return None - - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)] - while True: - for i in range(device_count): - info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)) - if isinstance(info, str): - logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}") - return None - max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2) - sleep(0.005) # 5ms - if not self.keep_measuring: - break - nvmlShutdown() - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] - except NVMLError as error: - logger.error("Error fetching GPU information using nvml: %s", error) - return None - monitor = MemoryMonitor(False) +class RocmMemoryMonitor(MemoryMonitor): + def __init__(self, keep_measuring=True): + super().__init__(keep_measuring) + rocm_smi_path = "/opt/rocm/libexec/rocm_smi" + if os.path.exists(rocm_smi_path): + if rocm_smi_path not in sys.path: + sys.path.append(rocm_smi_path) + try: + import rocm_smi + + self.rocm_smi = rocm_smi + self.rocm_smi.initializeRsmi() + except ImportError: + self.rocm_smi = None + + def get_used_memory(self, dev): + if self.rocm_smi is None: + return -1 + return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024 + + def measure_gpu_usage(self): + if self.rocm_smi is None: + return None + + device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0 + max_gpu_usage = [0 for i in range(device_count)] + gpu_name = [f"GPU{i}" for i in range(device_count)] + while True: + for i in range(device_count): + max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i)) + time.sleep(0.005) # 2ms + if not self.keep_measuring: + break + return [ + { + "device_id": i, + "name": gpu_name[i], + "max_used_MB": max_gpu_usage[i], + } + for i in range(device_count) + ] + + +def measure_memory(is_gpu, func, monitor_type="cuda"): + memory_monitor_type = None + if monitor_type == "rocm": + memory_monitor_type = RocmMemoryMonitor + else: + memory_monitor_type = CudaMemoryMonitor + + monitor = memory_monitor_type(False) if is_gpu: memory_before_test = monitor.measure_gpu_usage() @@ -508,7 +570,7 @@ def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: return None with ThreadPoolExecutor() as executor: - monitor = MemoryMonitor() + monitor = memory_monitor_type() mem_thread = executor.submit(monitor.measure_gpu_usage) try: fn_thread = executor.submit(func) diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 63c991167d235..c0cabbb5e9759 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -993,7 +993,11 @@ def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto encoder.remove_duplicated_initializer(signature_cache1) decoder.remove_duplicated_initializer(signature_cache2) initializers = remove_shared_initializers( - decoder.model.graph, encoder.model.graph, "s_", 
signature_cache1, signature_cache2 + decoder.model.graph, + encoder.model.graph, + shared_prefix="s_", + signature_cache1=signature_cache1, + signature_cache2=signature_cache2, ) return initializers diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 07995f0a38e26..283528bea7465 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -11,7 +11,7 @@ import psutil import torch import whisper -from benchmark_helper import setup_logger +from benchmark_helper import measure_memory, setup_logger from onnxruntime_extensions import get_library_path from optimum.onnxruntime import ORTModelForSpeechSeq2Seq from torch.profiler import ProfilerActivity, profile, record_function @@ -19,7 +19,6 @@ from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor import onnxruntime as ort -from onnxruntime.transformers.benchmark_helper import measure_memory logger = logging.getLogger(__name__) @@ -123,6 +122,9 @@ def get_model(args: argparse.Namespace): if args.verbose: sess_options.log_verbosity_level = 1 sess_options.log_severity_level = 1 + if args.tune: + ort.set_default_logger_severity(0) + ort.set_default_logger_verbosity(0) else: raise Exception(f"Cannot recognize {args.benchmark_type}") @@ -159,6 +161,9 @@ def get_model(args: argparse.Namespace): def time_fn(args, fn, inputs): + warmup_inputs = inputs[0] if type(inputs) is tuple else inputs + benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs + # Warm up warmup_range = ( range(args.warmup_runs) @@ -167,11 +172,11 @@ def time_fn(args, fn, inputs): ) if args.verbose: - outputs = fn(inputs) + outputs = fn(warmup_inputs) logger.info(outputs) for _ in warmup_range: - fn(inputs) + fn(warmup_inputs) # Benchmark if args.device != "cpu": @@ -184,7 +189,7 @@ def time_fn(args, fn, inputs): else trange(args.num_runs, file=sys.stdout, desc="Benchmark") ) for _ in bench_range: - fn(inputs) + fn(benchmark_inputs) if args.device != "cpu": torch.cuda.synchronize() @@ -244,7 +249,7 @@ def measure_fn(args, fn, inputs): # Measure memory usage gc.collect() torch.cuda.empty_cache() - measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs)) + measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type) # Flush output so memory usage is printed sys.stdout.flush() @@ -255,7 +260,7 @@ def run_hf_inference(args, inputs, model): def get_pred_ids(inputs): # Inference pass with predicted token ids generation predicted_ids = model.generate(**inputs) - return predicted_ids, [""] + return predicted_ids def gen_and_dec(inputs): # Inference pass with generation and decoding @@ -315,7 +320,7 @@ def gen_and_dec(inputs): def run_ort_inference(args, inputs, model): - def prepare_ort_inputs(inputs): + def prepare_ort_inputs(inputs, warmup=False): # Check that all model inputs will be provided model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs())) user_inputs = set(inputs.keys()) @@ -324,6 +329,9 @@ def prepare_ort_inputs(inputs): logger.error(f"The following model inputs are missing: {missing_inputs}") raise Exception("There are missing inputs to the model. 
Please add them and try again.") + if warmup and args.tune: + inputs["min_length"] = inputs["max_length"] + # Remove unnecessary inputs from model inputs unnecessary_inputs = user_inputs - model_inputs if len(unnecessary_inputs): @@ -352,6 +360,13 @@ def without_io_binding(inputs): outputs = model.run(None, inputs) return outputs + def handle_output(output): + if args.eos_token_id in output: + first_end = np.where(output == args.eos_token_id)[0][0] + return output[: first_end + 1] + + return output + generate_fn = with_io_binding if args.device != "cpu" else without_io_binding ort_inputs = prepare_ort_inputs(inputs) @@ -367,7 +382,12 @@ def without_io_binding(inputs): # ORT evaluation logger.info("\nEvaluating ONNX Runtime...") - time_fn(args, generate_fn, ort_inputs) + ort_evaluate_inputs = ort_inputs + if args.tune: + ort_warmup_inputs = prepare_ort_inputs(inputs, warmup=True) + ort_evaluate_inputs = (ort_warmup_inputs, ort_inputs) + + time_fn(args, generate_fn, ort_evaluate_inputs) ort_outputs = generate_fn(ort_inputs) if args.device != "cpu": ort_outputs = ort_outputs.copy_outputs_to_cpu() @@ -378,7 +398,10 @@ def without_io_binding(inputs): logger.info(f"Transcription: {ort_outputs[0][0]}") else: # convert_to_onnx model produces generated ids - logger.info(f"Generated token length: {len(ort_outputs[0][0])} tokens") + actual_output = handle_output(ort_outputs[0][0]) + logger.info(f"Generated token length: {len(actual_output)} tokens") + transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0] + logger.info(f"Transcription: {transcription}") measure_fn(args, generate_fn, ort_inputs) @@ -483,6 +506,12 @@ def parse_args(): parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display") parser.add_argument("--verbose", default=False, action="store_true") parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files") + parser.add_argument( + "--tune", + default=False, + action="store_true", + help="Only used by ROCm EP, enable TunableOp tuning to select fastest kernel", + ) args = parser.parse_args() @@ -490,13 +519,21 @@ def parse_args(): np.random.seed(args.seed) torch.manual_seed(args.seed) + args.monitor_type = args.device # Set runtime properties if "ort" in args.benchmark_type: args.execution_provider = f"{args.device.upper()}ExecutionProvider" if args.execution_provider == "CUDAExecutionProvider": args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) elif args.execution_provider == "ROCMExecutionProvider": - args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) + args.execution_provider = ( + args.execution_provider, + { + "device_id": args.device_id, + "tunable_op_enable": 1, + "tunable_op_tuning_enable": 1 if args.tune else 0, + }, + ) args.device = "cuda" # Check that model paths have been specified for any benchmarking with ORT @@ -527,6 +564,7 @@ def main(): setattr(args, "target_device", target_device) # noqa: B010 setattr(args, "use_fp16", use_fp16) # noqa: B010 setattr(args, "has_audio_stream", False) # noqa: B010 + setattr(args, "eos_token_id", config.eos_token_id) # noqa: B010 logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}") diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index f12723f1af2df..08d7befec3cfd 100644 --- 
a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -109,6 +109,8 @@ def get_args(): help="Number of mins to attempt the benchmark before moving on", ) + parser.add_argument("--tune", default=False, action="store_true") + args = parser.parse_args() setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010 @@ -292,6 +294,7 @@ def main(): ort_decoder_input_ids_cmd = ( ["--decoder-input-ids", str(ort_forced_decoder_ids)] if args.language and args.task else [] ) + ort_tune_cmd = ["--tune"] if args.tune else [] all_results = [] for audio_file in os.listdir(args.audio_path): @@ -395,31 +398,35 @@ def main(): # Benchmark ONNX Runtime if args.ort_model_path: - benchmark_cmd = [ # noqa: RUF005 - "python3", - "-m", - "models.whisper.benchmark", - "--audio-path", - audio_path, - "--benchmark-type", - "ort", - "--ort-model-path", - args.ort_model_path, - "--model-name", - args.model_name, - "--precision", - args.precision, - "--device", - args.device, - "--device-id", - str(args.device_id), - "--warmup-runs", - str(args.warmup_runs), - "--num-runs", - str(args.num_runs), - "--log-folder", - args.log_folder, - ] + ort_decoder_input_ids_cmd + benchmark_cmd = ( + [ # noqa: RUF005 + "python3", + "-m", + "models.whisper.benchmark", + "--audio-path", + audio_path, + "--benchmark-type", + "ort", + "--ort-model-path", + args.ort_model_path, + "--model-name", + args.model_name, + "--precision", + args.precision, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + + ort_decoder_input_ids_cmd + + ort_tune_cmd + ) logger.info("Benchmark ONNX Runtime") results = benchmark(args, benchmark_cmd, "onnxruntime", audio_file, duration) all_results.extend(results) diff --git a/onnxruntime/test/contrib_ops/quantize_ops_test.cc b/onnxruntime/test/contrib_ops/quantize_ops_test.cc index af29f972a64cf..64a97ed4f945b 100644 --- a/onnxruntime/test/contrib_ops/quantize_ops_test.cc +++ b/onnxruntime/test/contrib_ops/quantize_ops_test.cc @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" namespace onnxruntime { namespace test { @@ -40,7 +41,31 @@ TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// Scalar zero & scale with int32 +// Test int16 com.microsoft.DequantizeLinear (per tensor) +TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int16_cpu) { + OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{4}; + test.AddInput("x", dims, {-300, -30, -1025, 1270}); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {-1024}, true); + test.AddOutput("y", dims, {1448.0f, 1988.0f, -2.0f, 4588.0f}); + // Disable Tensorrt EP due to error: unsupported data type + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test uint16 com.microsoft.DequantizeLinear (per tensor) +TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_uint16_cpu) { + OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{4}; + test.AddInput("x", dims, {30000, 31000, 32768, 33000}); + test.AddInput("scale", {}, {2.0f}, true); 
+ test.AddInput("zero_point", {}, {32767}, true); + test.AddOutput("y", dims, {-5534.0f, -3534.0f, 2.0f, 466.0f}); + // Disable Tensorrt EP due to error: unsupported data type + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test int32 DequantizeLinear with scalar zero-point & scale. TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int32_cpu) { OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain); std::vector dims{4}; @@ -256,6 +281,60 @@ TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_int8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// Test uint16 com.microsoft.QuantizeLinear (per tensor) +TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_uint16) { + OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{12}; + test.AddInput("x", dims, { + 0.f, -128.f, 3.f, -3.f, // rounding half to even + 2.9f, -2.9f, // round < .5 + 3.1f, -3.1f, // round > .5 + 65536.f, -65534.f, // critical point + 70000.f, -70000.f // saturate case + }); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {32767}, true); + test.AddOutput("y", dims, + {32767, 32703, + 32769, 32765, + 32768, 32766, + 32769, 32765, + 65535, 0, + 65535, 0}); + + // Disable Tensorrt EP due to error: unsupported data type + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test int16 com.microsoft.QuantizeLinear (per tensor) +TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_int16) { + OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{16}; + test.AddInput("x", dims, { + 0.f, -514.f, 3.f, -3.f, // rounding half to even + 2.9f, -2.9f, // round < .5 + 3.1f, -3.1f, // round > .5 + 65022.f, -66046.f, // critical point + 65023.f, -66047.f, // critical point + 65024.f, -66048.f, // critical point + 70000.f, -70000.f // saturate case + }); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {256}, true); + test.AddOutput("y", dims, + {256, -1, + 258, 254, + 257, 255, + 258, 254, + 32767, -32767, + 32767, -32768, + 32767, -32768, + 32767, -32768}); + + // Disable Tensorrt EP due to error: unsupported data type + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + #ifdef USE_CUDA TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_half_uint8) { OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain); diff --git a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp index 55d1a2f4f3608..2832598fef1a9 100644 --- a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp +++ b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp @@ -3,26 +3,26 @@ #include "test_util.h" -template +template class MlasQuantizeLinearTest : public MlasTestBase { private: MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; - void GenerateReference(const float* Input, xint8_t* OutputReference, size_t N, float Scale, xint8_t ZeroPoint) { + void GenerateReference(const float* Input, QuantInt* OutputReference, size_t N, float Scale, QuantInt ZeroPoint) { for (size_t n = 0; n < N; n++) { float FloatValue = std::nearbyintf(Input[n] / Scale) + float(ZeroPoint); - FloatValue = std::max(FloatValue, float(std::numeric_limits::min())); - 
FloatValue = std::min(FloatValue, float(std::numeric_limits::max())); - OutputReference[n] = (xint8_t)FloatValue; + FloatValue = std::max(FloatValue, static_cast(std::numeric_limits::min())); + FloatValue = std::min(FloatValue, static_cast(std::numeric_limits::max())); + OutputReference[n] = static_cast(FloatValue); } } void Test(size_t N) { float* Input = BufferInput.GetBuffer(N); - xint8_t* Output = BufferOutput.GetBuffer(N); - xint8_t* OutputReference = BufferOutputReference.GetBuffer(N); + QuantInt* Output = BufferOutput.GetBuffer(N); + QuantInt* OutputReference = BufferOutputReference.GetBuffer(N); std::default_random_engine generator(static_cast(N)); @@ -34,8 +34,9 @@ class MlasQuantizeLinearTest : public MlasTestBase { float Scale = (MaximumValue - MinimumValue) / 512.f; - std::uniform_int_distribution zp_distribution(std::numeric_limits::min(), std::numeric_limits::max()); - xint8_t ZeroPoint = static_cast(zp_distribution(generator)); + std::uniform_int_distribution zp_distribution(std::numeric_limits::min(), + std::numeric_limits::max()); + QuantInt ZeroPoint = static_cast(zp_distribution(generator)); std::uniform_real_distribution distribution(MinimumValue, MaximumValue); for (size_t n = 0; n < N; n++) { @@ -52,8 +53,15 @@ class MlasQuantizeLinearTest : public MlasTestBase { public: static const char* GetTestSuiteName() { - static const std::string suite_name(std::is_signed::value ? "QuantizeLinearS8" : "QuantizeLinearU8"); - return suite_name.c_str(); + if constexpr (std::is_same_v) { + return "QuantizeLinearS8"; + } else if (std::is_same_v) { + return "QuantizeLinearU8"; + } else if (std::is_same_v) { + return "QuantizeLinearS16"; + } else { + return "QuantizeLinearU16"; + } } void ExecuteShort(void) override { @@ -67,12 +75,18 @@ template <> MlasQuantizeLinearTest* MlasTestFixture>::mlas_tester(nullptr); template <> MlasQuantizeLinearTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasQuantizeLinearTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasQuantizeLinearTest* MlasTestFixture>::mlas_tester(nullptr); static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { count += MlasDirectShortExecuteTests>::RegisterShortExecute(); count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); } return count; }); diff --git a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc index d0ce4898a472c..feff607703341 100644 --- a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc +++ b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc @@ -20,15 +20,17 @@ struct GraphConfig { bool has_subgraph_consumer{false}; }; -auto GetGraphBuilder(const GraphConfig& config, bool use_ms_domain_qdq_ops) { +template +std::function GetGraphBuilder(const GraphConfig& config, bool use_ms_domain_qdq_ops) { return [config, use_ms_domain_qdq_ops](ModelTestBuilder& builder) { const auto input_shape = std::vector{1, 2, 4}; constexpr float scale = 0.5f; - constexpr uint8_t zero_point = 0; + constexpr QuantType zero_point = 0; - auto* dq_input = builder.MakeInput(input_shape, uint8_t{0}, uint8_t{255}); + auto* dq_input = builder.MakeInput(input_shape, std::numeric_limits::min(), + std::numeric_limits::max()); auto* dq_output = config.has_graph_output ? 
builder.MakeOutput() : builder.MakeIntermediate(); - builder.AddDequantizeLinearNode(dq_input, scale, zero_point, dq_output, use_ms_domain_qdq_ops); + builder.AddDequantizeLinearNode(dq_input, scale, zero_point, dq_output, use_ms_domain_qdq_ops); for (size_t i = 0; i < config.num_explicit_consumer_nodes; ++i) { // use Concat for the explicit consumer node as it supports a variadic number of inputs @@ -71,10 +73,12 @@ auto GetGraphBuilder(const GraphConfig& config, bool use_ms_domain_qdq_ops) { } void RunEnsureUniqueDQForNodeUnitTest(const GraphConfig& config, int expected_dq_count) { - auto run_tests = [config, expected_dq_count](bool use_ms_domain_qdq_ops) { + auto run_tests = [config, expected_dq_count](bool use_ms_domain_qdq_ops, bool use_16bit_qdq_ops) { constexpr int opset_version = 12; const char* dequantize_linear_key = use_ms_domain_qdq_ops ? "com.microsoft.DequantizeLinear" : "DequantizeLinear"; - std::function graph_builder_fn = GetGraphBuilder(config, use_ms_domain_qdq_ops); + std::function graph_builder_fn = use_16bit_qdq_ops + ? GetGraphBuilder(config, use_ms_domain_qdq_ops) + : GetGraphBuilder(config, use_ms_domain_qdq_ops); { SCOPED_TRACE("test with standalone transformer"); @@ -117,9 +121,10 @@ void RunEnsureUniqueDQForNodeUnitTest(const GraphConfig& config, int expected_dq } }; - run_tests(false); + run_tests(false, false); #if !defined(DISABLE_CONTRIB_OPS) - run_tests(true); // Use contrib QDQ ops. + run_tests(true, false); // Use contrib QDQ ops. + run_tests(true, true); // Use 16-bit contrib QDQ ops. #endif } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 553fcca92aa78..dce1f2d40e8b9 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -83,6 +83,7 @@ #include "test/util/include/test_utils.h" #include "core/optimizer/pre_shape_node_elimination.h" #include "core/optimizer/double_qdq_pairs_remover.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" #ifdef ENABLE_TRAINING #include "orttraining/core/optimizer/bitmask_dropout_replacement.h" #endif @@ -155,44 +156,43 @@ TEST_F(GraphTransformationTests, IdentityWithSharedNodeArgNotEliminated) { ASSERT_TRUE(op_to_count["Add"] == 1); } +// Runs a model to ensure that common subexpression elimination does not eliminate +// DequantizeLinear nodes. TEST_F(GraphTransformationTests, DequantizeLinearNodeNotEliminated) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "qdq_with_multi_consumer_dq_nodes.fixed.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["DequantizeLinear"], 25); + auto test_case = [](const ORTCHAR_T* model_uri, + bool use_contrib_qdq, + const logging::Logger& logger) { + const char* dq_key = use_contrib_qdq ? 
"com.microsoft.DequantizeLinear" : "DequantizeLinear"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count[dq_key], 25); - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), - TransformerLevel::Level1)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger)); - // CommonSubexpressionElimination should skip the DequantizeLinear nodes - op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["DequantizeLinear"], 25); -} + // CommonSubexpressionElimination should skip the DequantizeLinear nodes + op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count[dq_key], 25); + }; + test_case(MODEL_FOLDER "qdq_with_multi_consumer_dq_nodes.fixed.onnx", + false, // use_contrib_qdq + *logger_); #if !defined(DISABLE_CONTRIB_OPS) -// Test that com.microsoft.DequantizeLinear is not eliminated in CommonSubexpressionElimination -TEST_F(GraphTransformationTests, MsDomainDequantizeLinearNodeNotEliminated) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "qdq_with_multi_consumer_dq_nodes.fixed.qdq_contrib.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["com.microsoft.DequantizeLinear"], 25); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), - TransformerLevel::Level1)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); - - // CommonSubexpressionElimination should skip the DequantizeLinear nodes - op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["com.microsoft.DequantizeLinear"], 25); -} + // Test with 8-bit com.microsoft.DequantizeLinear + test_case(MODEL_FOLDER "qdq_with_multi_consumer_dq_nodes.fixed.qdq_contrib.onnx", + true, // use_contrib_qdq + *logger_); + // Test with 16-bit com.microsoft.DequantizeLinear + test_case(MODEL_FOLDER "qdq_with_multi_consumer_dq_nodes.fixed.qdq16_contrib.onnx", + true, // use_contrib_qdq + *logger_); #endif // !defined(DISABLE_CONTRIB_OPS) +} TEST_F(GraphTransformationTests, IdentityInputIsGraphOutputNotEliminated) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "scan9_sum.onnx"; @@ -836,158 +836,120 @@ static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_map model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["QuantizeLinear"] == 1); - ASSERT_TRUE(op_to_count["DequantizeLinear"] == 3); - ASSERT_TRUE(op_to_count["Conv"] == 1); - - std::unordered_map expected_op_counts = {{"QuantizeLinear", 1}, - {"DequantizeLinear", 3}, - {"Conv", 1}}; - - SessionOptions session_options; - // Check DequantizeLinear aren't constant folded for default setting. 
- VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); - - // set kOrtSessionOptionsDisableQuantQDQ to enable it explicitly - ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "0")); - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); + auto test_case = [](const ORTCHAR_T* model_uri, + bool use_contrib_qdq, + const logging::Logger& logger) { + const char* q_key = use_contrib_qdq ? "com.microsoft.QuantizeLinear" : "QuantizeLinear"; + const char* dq_key = use_contrib_qdq ? "com.microsoft.DequantizeLinear" : "DequantizeLinear"; - // set SessionOptionsEnableQuantQDQ to disable it - expected_op_counts["DequantizeLinear"] = 1; - ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); -} + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count[q_key] == 1); + ASSERT_TRUE(op_to_count[dq_key] == 3); + ASSERT_TRUE(op_to_count["Conv"] == 1); -#if !defined(DISABLE_CONTRIB_OPS) -// Test constant folding with a com.microsoft.DequantizeLinear node -TEST_F(GraphTransformationTests, ConstantFoldingWithMsDomainDequantizeLinear) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_dequantizelinear.qdq_contrib.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["com.microsoft.QuantizeLinear"], 1); - ASSERT_EQ(op_to_count["com.microsoft.DequantizeLinear"], 3); - ASSERT_EQ(op_to_count["Conv"], 1); + std::unordered_map expected_op_counts = {{q_key, 1}, + {dq_key, 3}, + {"Conv", 1}}; - std::unordered_map expected_op_counts = {{"com.microsoft.QuantizeLinear", 1}, - {"com.microsoft.DequantizeLinear", 3}, - {"Conv", 1}}; + SessionOptions session_options; + // Check DequantizeLinear aren't constant folded for default setting. + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, logger); - SessionOptions session_options; - // Check DequantizeLinear aren't constant folded for default setting. 
- VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); + // set kOrtSessionOptionsDisableQuantQDQ to enable it explicitly + ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "0")); + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, logger); - // set kOrtSessionOptionsDisableQuantQDQ to enable it explicitly - ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "0")); - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); + // set SessionOptionsEnableQuantQDQ to disable it + expected_op_counts[dq_key] = 1; + ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, logger); + }; - // set SessionOptionsEnableQuantQDQ to disable it - expected_op_counts["com.microsoft.DequantizeLinear"] = 1; - ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); -} + test_case(MODEL_FOLDER "fusion/constant_folding_dequantizelinear.onnx", + false, *logger_); +#if !defined(DISABLE_CONTRIB_OPS) + // Test with 8-bit contrib QDQ ops + test_case(MODEL_FOLDER "fusion/constant_folding_dequantizelinear.qdq_contrib.onnx", + true, *logger_); + // Test with 16-bit contrib QDQ ops + test_case(MODEL_FOLDER "fusion/constant_folding_dequantizelinear.qdq16_contrib.onnx", + true, *logger_); #endif // !defined(DISABLE_CONTRIB_OPS) +} // model with 2 QDQ node units that can be constant folded as they are simple DQ -> Node -> Q where DQ and Node have // single consumer and do not produce graph outputs. Node is deterministic. // there are also other DQ nodes that should be ignored. TEST_F(GraphTransformationTests, ConstantFoldingQDQNodeUnit) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["QuantizeLinear"] == 3); - ASSERT_TRUE(op_to_count["DequantizeLinear"] == 4); - ASSERT_TRUE(op_to_count["Unsqueeze"] == 1); - ASSERT_TRUE(op_to_count["Transpose"] == 1); + auto test_case = [](const ORTCHAR_T* model_uri, bool use_contrib_qdq, const logging::Logger& logger) { + const char* q_key = use_contrib_qdq ? "com.microsoft.QuantizeLinear" : "QuantizeLinear"; + const char* dq_key = use_contrib_qdq ? 
"com.microsoft.DequantizeLinear" : "DequantizeLinear"; - SessionOptions session_options; - - // 2 QDQ node units should be constant folded and go away - std::unordered_map expected_op_counts = {{"QuantizeLinear", 1}, - {"DequantizeLinear", 2}, - {"Transpose", 0}, - {"Unsqueeze", 0}}; - - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); -} + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count[q_key] == 3); + ASSERT_TRUE(op_to_count[dq_key] == 4); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 1); + ASSERT_TRUE(op_to_count["Transpose"] == 1); -#if !defined(DISABLE_CONTRIB_OPS) -// model with 2 (com.microsoft) QDQ node units that can be constant folded as they are simple DQ -> Node -> Q where -// DQ and Node have single consumer and do not produce graph outputs. Node is deterministic. -// there are also other DQ nodes that should be ignored. -TEST_F(GraphTransformationTests, ConstantFoldingMsDomainQDQNodeUnit) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.qdq_contrib.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["com.microsoft.QuantizeLinear"], 3); - ASSERT_EQ(op_to_count["com.microsoft.DequantizeLinear"], 4); - ASSERT_EQ(op_to_count["Unsqueeze"], 1); - ASSERT_EQ(op_to_count["Transpose"], 1); + SessionOptions session_options; - SessionOptions session_options; + // 2 QDQ node units should be constant folded and go away + std::unordered_map expected_op_counts = {{q_key, 1}, + {dq_key, 2}, + {"Transpose", 0}, + {"Unsqueeze", 0}}; - // 2 QDQ node units should be constant folded and go away - std::unordered_map expected_op_counts = {{"com.microsoft.QuantizeLinear", 1}, - {"com.microsoft.DequantizeLinear", 2}, - {"Transpose", 0}, - {"Unsqueeze", 0}}; + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, logger); + }; - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); -} + test_case(MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.onnx", false, *logger_); +#if !defined(DISABLE_CONTRIB_OPS) + // Test with 8-bit com.microsoft.Q/DQ + test_case(MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.qdq_contrib.onnx", true, *logger_); + // Test with 16-bit com.microsoft.Q/DQ + test_case(MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.qdq16_contrib.onnx", true, *logger_); #endif // !defined(DISABLE_CONTRIB_OPS) +} // Simple QDQ Node Unit but shouldn't be constant folded as the node in the middle produces a graph output TEST_F(GraphTransformationTests, ConstantFoldingQDQNodeUnitGraphOutput) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.graph_output.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["QuantizeLinear"] == 2); - ASSERT_TRUE(op_to_count["DequantizeLinear"] == 3); - ASSERT_TRUE(op_to_count["Unsqueeze"] == 1); + auto test_case = [](const ORTCHAR_T* model_uri, bool use_contrib_qdq, const logging::Logger& logger) { + const char* q_key = use_contrib_qdq ? 
"com.microsoft.QuantizeLinear" : "QuantizeLinear"; + const char* dq_key = use_contrib_qdq ? "com.microsoft.DequantizeLinear" : "DequantizeLinear"; - std::unordered_map expected_op_counts = {{"QuantizeLinear", 2}, - {"DequantizeLinear", 3}, - {"Unsqueeze", 1}}; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count[q_key] == 2); + ASSERT_TRUE(op_to_count[dq_key] == 3); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 1); - SessionOptions session_options; - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); -} + std::unordered_map expected_op_counts = {{q_key, 2}, + {dq_key, 3}, + {"Unsqueeze", 1}}; -#if !defined(DISABLE_CONTRIB_OPS) -// Simple (com.microsoft) QDQ Node Unit but shouldn't be constant folded as the node in the middle produces a -// graph output -TEST_F(GraphTransformationTests, ConstantFoldingMsDomainQDQNodeUnitGraphOutput) { - constexpr const ORTCHAR_T* model_uri = - MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.graph_output.qdq_contrib.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["com.microsoft.QuantizeLinear"], 2); - ASSERT_EQ(op_to_count["com.microsoft.DequantizeLinear"], 3); - ASSERT_EQ(op_to_count["Unsqueeze"], 1); + SessionOptions session_options; + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, logger); + }; - std::unordered_map expected_op_counts = {{"com.microsoft.QuantizeLinear", 2}, - {"com.microsoft.DequantizeLinear", 3}, - {"Unsqueeze", 1}}; + test_case(MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.graph_output.onnx", false, *logger_); +#if !defined(DISABLE_CONTRIB_OPS) + // Test with 8-bit contrib QDQ ops + test_case(MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.graph_output.qdq_contrib.onnx", true, *logger_); - SessionOptions session_options; - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); -} + // Test with 16-bit contrib QDQ ops + test_case(MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.graph_output.qdq16_contrib.onnx", true, *logger_); #endif // !defined(DISABLE_CONTRIB_OPS) +} TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConstantFoldedNode) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_remove_dangling_inputs.onnx"; @@ -3898,12 +3860,12 @@ TEST_F(GraphTransformationTests, DoublQDQRemover_RemoveDupQDQ) { std::string zp_name_after_reshape_node; for (auto& node : graph.Nodes()) { if (node.Name() == "dq_2") { - dq_scale_name_before_reshape_node = node.InputDefs()[InputIndex::SCALE_ID]->Name(); - zp_name_before_reshape_node = node.InputDefs()[InputIndex::ZERO_POINT_ID]->Name(); + dq_scale_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name(); + zp_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name(); } if (node.Name() == "q_3") { - dq_scale_name_after_reshape_node = node.InputDefs()[InputIndex::SCALE_ID]->Name(); - zp_name_after_reshape_node = node.InputDefs()[InputIndex::ZERO_POINT_ID]->Name(); + dq_scale_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name(); + zp_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name(); } } 
EXPECT_EQ(dq_scale_name_before_reshape_node.empty(), false); diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 743faee3ee2a5..63577131480c6 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -39,9 +39,21 @@ namespace test { template struct IsTypeQuantLinearCompatible : utils::IsByteType {}; +template <> +struct IsTypeQuantLinearCompatible : std::true_type {}; + +template <> +struct IsTypeQuantLinearCompatible : std::true_type {}; + template struct IsTypeDequantLinearCompatible : utils::IsByteType {}; +template <> +struct IsTypeDequantLinearCompatible : std::true_type {}; + +template <> +struct IsTypeDequantLinearCompatible : std::true_type {}; + template <> struct IsTypeDequantLinearCompatible : std::true_type {}; diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 0dfeb599d0ae3..a438a61cb9b36 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -891,37 +891,139 @@ TEST(QDQTransformerTests, Gemm_S8S8U8) { QDQTransformerGemmTests(); } +// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> Gather -> Q. +template +static void RunGatherDropQDQTestCase(const std::vector& input1_shape, + const std::vector& weights_shape, + bool use_contrib_qdq = false) { + auto build_test_case = [input1_shape, weights_shape, use_contrib_qdq](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, 0, weights_shape[0] - 1); + auto* output_arg = builder.MakeOutput(); + + // add Gather + auto* weight = builder.MakeInitializer(weights_shape, std::numeric_limits::min(), + std::numeric_limits::max()); + auto* dq_w_output = builder.MakeIntermediate(); + auto* gather_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(weight, .003f, 1, dq_w_output, use_contrib_qdq); + builder.AddNode("Gather", {dq_w_output, input1_arg}, {gather_output}); + + // add Q + builder.AddQuantizeLinearNode(gather_output, .003f, 1, output_arg, use_contrib_qdq); + }; + + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["Gather"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); +} + +// Checks that Q/DQ nodes are dropped from DQ -> Gather -> Q. Uses 8-bit and 16-bit Q/DQ ops. 
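The graph_transform_test_builder.h hunk above widens the IsTypeQuantLinearCompatible / IsTypeDequantLinearCompatible whitelists with explicit specializations, presumably for int16_t and uint16_t given the 16-bit support this PR introduces. A minimal standalone sketch of that trait-specialization pattern, using hypothetical names rather than the ORT headers:

#include <cstdint>
#include <type_traits>

// Hypothetical trait: by default only the 8-bit types count as Q/DQ-compatible.
template <typename T>
struct IsQuantLinearCompatible
    : std::integral_constant<bool, std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value> {};

// Explicit specializations opt additional types in, mirroring the pattern in the test-builder header.
template <>
struct IsQuantLinearCompatible<int16_t> : std::true_type {};

template <>
struct IsQuantLinearCompatible<uint16_t> : std::true_type {};

static_assert(IsQuantLinearCompatible<uint8_t>::value, "8-bit types are allowed by default");
static_assert(IsQuantLinearCompatible<uint16_t>::value, "enabled by the specialization above");
static_assert(!IsQuantLinearCompatible<int64_t>::value, "anything else stays rejected");

int main() { return 0; }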
TEST(QDQTransformerTests, Gather) { - auto test_case = [&](const std::vector& input1_shape, const std::vector& weights_shape, - bool use_contrib_qdq = false) { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto* input1_arg = builder.MakeInput(input1_shape, 0, weights_shape[0] - 1); - auto* output_arg = builder.MakeOutput(); + RunGatherDropQDQTestCase({12, 37}, {24, 12}); + RunGatherDropQDQTestCase({12, 37}, {24, 12}, true); // Use com.microsoft QDQ ops + RunGatherDropQDQTestCase({12, 37}, {24, 12}, true); // Use int16 com.microsoft QDQ ops +} - // add Gather - auto* weight = builder.MakeInitializer(weights_shape, -128, 127); - auto* dq_w_output = builder.MakeIntermediate(); - auto* gather_output = builder.MakeIntermediate(); - builder.AddDequantizeLinearNode(weight, .003f, 1, dq_w_output, use_contrib_qdq); - builder.AddNode("Gather", {dq_w_output, input1_arg}, {gather_output}); +// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> Reshape -> Q. +template +static void RunReshapeDropQDQTestCase(const std::vector& input_shape, + const std::vector& new_shape, + bool use_contrib_qdq = false) { + auto build_test_case = [input_shape, new_shape, use_contrib_qdq](ModelTestBuilder& builder) { + constexpr QuantType qmin = std::numeric_limits::min(); + constexpr QuantType qmax = std::numeric_limits::max(); + + auto* input_arg = builder.MakeInput(input_shape, qmin, qmax); + auto* output_arg = builder.MakeOutput(); + QuantType zero_point = 1 + (qmax + qmin) / 2; + + // Add Reshape node + auto* new_shape_arg = builder.Make1DInitializer(new_shape); + auto* input_arg_dq = builder.MakeIntermediate(); + auto* reshape_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_arg, .003f, zero_point, input_arg_dq, use_contrib_qdq); + builder.AddNode("Reshape", {input_arg_dq, new_shape_arg}, {reshape_output}); + + // add Q + builder.AddQuantizeLinearNode(reshape_output, .003f, zero_point, output_arg, use_contrib_qdq); + }; - // add Q - builder.AddQuantizeLinearNode(gather_output, .003f, 1, output_arg, use_contrib_qdq); - }; + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["Reshape"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; - auto check_graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - EXPECT_EQ(op_to_count["Gather"], 1); - EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); - }; + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); +} - TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); +// Checks that Q/DQ nodes are dropped from DQ -> Reshape -> Q. Uses 8-bit and 16-bit Q/DQ ops. +TEST(QDQTransformerTests, ReshapeDropQDQ) { + RunReshapeDropQDQTestCase({1, 3, 2, 2}, {1, 12}); + RunReshapeDropQDQTestCase({1, 3, 2, 2}, {1, 12}, true); // Use com.microsoft QDQ ops + RunReshapeDropQDQTestCase({1, 3, 2, 2}, {1, 12}, true); // Use int16 com.microsoft QDQ ops + RunReshapeDropQDQTestCase({1, 3, 2, 2}, {1, 12}, true); // Use int16 com.microsoft QDQ ops +} + +// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> (Un)Squeeze -> Q. 
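The drop-QDQ helpers above pick the zero point as the midpoint of the quantized range, 1 + (qmax + qmin) / 2, and rely on QuantizeLinear's round-to-nearest-even plus saturation semantics, which this PR extends from 8-bit to 16-bit ranges. A minimal sketch of that arithmetic (not ORT's implementation):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// QuantizeLinear sketch: y = saturate(round_half_to_even(x / scale) + zero_point),
// where the saturation bounds come from the quantization type
// ([0, 255] for uint8, [-32768, 32767] for int16, [0, 65535] for uint16, ...).
template <typename QuantType>
QuantType QuantizeValue(float x, float scale, QuantType zero_point) {
  const float qmin = static_cast<float>(std::numeric_limits<QuantType>::min());
  const float qmax = static_cast<float>(std::numeric_limits<QuantType>::max());
  const float rounded = std::nearbyint(x / scale);  // ties-to-even under the default rounding mode
  const float shifted = rounded + static_cast<float>(zero_point);
  return static_cast<QuantType>(std::min(qmax, std::max(qmin, shifted)));
}

int main() {
  // Midpoint zero point, as in the helpers above: evaluates to 1 for int8 and 32768 for uint16.
  const uint16_t zp_u16 = static_cast<uint16_t>(
      1 + (std::numeric_limits<uint16_t>::max() + std::numeric_limits<uint16_t>::min()) / 2);
  std::printf("uint16: zp=%d, q(0.3)=%d\n", static_cast<int>(zp_u16),
              static_cast<int>(QuantizeValue<uint16_t>(0.3f, 0.003f, zp_u16)));
  std::printf("int8: q(10.0) saturates to %d\n",
              static_cast<int>(QuantizeValue<int8_t>(10.0f, 0.003f, static_cast<int8_t>(1))));
  return 0;
}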
+template +static void RunSqueezeUnsqueezeDropQDQTestCase(const std::string& squeeze_type, + const std::vector& input_shape, + const std::vector& axes, + bool use_contrib_qdq = false) { + auto build_test_case = [squeeze_type, input_shape, axes, use_contrib_qdq](ModelTestBuilder& builder) { + constexpr QuantType qmin = std::numeric_limits::min(); + constexpr QuantType qmax = std::numeric_limits::max(); + + auto* input_arg = builder.MakeInput(input_shape, qmin, qmax); + auto* output_arg = builder.MakeOutput(); + QuantType zero_point = 1 + (qmax + qmin) / 2; + + // Add Squeeze node + auto* axes_arg = builder.Make1DInitializer(axes); + auto* input_arg_dq = builder.MakeIntermediate(); + auto* xsqueeze_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_arg, .003f, zero_point, input_arg_dq, use_contrib_qdq); + builder.AddNode(squeeze_type, {input_arg_dq, axes_arg}, {xsqueeze_output}); + + // add Q + builder.AddQuantizeLinearNode(xsqueeze_output, .003f, zero_point, output_arg, use_contrib_qdq); }; - test_case({12, 37}, {24, 12}); - test_case({12, 37}, {24, 12}, true); // Use com.microsoft QDQ ops + auto check_graph = [squeeze_type, use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[squeeze_type], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 13 /* opset_version */); +} + +// Checks that Q/DQ nodes are dropped from DQ -> Squeeze -> Q. Uses 8-bit and 16-bit Q/DQ ops. +TEST(QDQTransformerTests, SqueezeDropQDQ) { + RunSqueezeUnsqueezeDropQDQTestCase("Squeeze", {1, 3, 2, 2}, {0}); + RunSqueezeUnsqueezeDropQDQTestCase("Squeeze", {1, 3, 2, 2}, {0}, true); // Use MS domain QDQ ops + RunSqueezeUnsqueezeDropQDQTestCase("Squeeze", {1, 3, 2, 2}, {0}, true); // Use int16 MS domain QDQ ops + RunSqueezeUnsqueezeDropQDQTestCase("Squeeze", {1, 3, 2, 2}, {0}, true); // Use int16 MS domain QDQ ops +} + +// Checks that Q/DQ nodes are dropped from DQ -> Unsqueeze -> Q. Uses 8-bit and 16-bit Q/DQ ops. 
+TEST(QDQTransformerTests, UnsqueezeDropQDQ) { + RunSqueezeUnsqueezeDropQDQTestCase("Unsqueeze", {1, 3, 2, 2}, {0}); + RunSqueezeUnsqueezeDropQDQTestCase("Unsqueeze", {1, 3, 2, 2}, {0}, true); // Use MS domain QDQ ops + RunSqueezeUnsqueezeDropQDQTestCase("Unsqueeze", {1, 3, 2, 2}, {0}, true); // Use int16 MS domain QDQ ops + RunSqueezeUnsqueezeDropQDQTestCase("Unsqueeze", {1, 3, 2, 2}, {0}, true); // Use int16 MS domain QDQ ops } TEST(QDQTransformerTests, DoubleQDQ) { @@ -1066,52 +1168,69 @@ TEST(QDQTransformerTests, DoubleQDQ) { bad_float_point, good_float_point_2, true); // Use com.microsoft QDQ ops } -TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) { - auto test_case = [&](int output_index, int expected_Q_count, int expected_DQ_count, - bool use_contrib_qdq = false) { - auto graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_Q_count); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_DQ_count); - }; - TransformerTester( - BuildDoubleQDQWithoutLastOutput(output_index, use_contrib_qdq), - graph, - TransformerLevel::Default, - TransformerLevel::Level1); +template +static void RunDoubleQDQWithoutLastNodeBeingOutput(int output_index, int expected_Q_count, int expected_DQ_count, + bool use_contrib_qdq = false) { + auto graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_Q_count); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_DQ_count); }; + TransformerTester( + BuildDoubleQDQWithoutLastOutput(output_index, use_contrib_qdq), + graph, + TransformerLevel::Default, + TransformerLevel::Level1); +} + +TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) { constexpr bool use_contrib_qdq = true; // For readability. - test_case(0, 2, 2); - test_case(0, 2, 2, use_contrib_qdq); - test_case(1, 2, 3); // EnsureUniqueDQForNodeUnit will duplicate first DQ, so expect one more (3) - test_case(1, 2, 3, use_contrib_qdq); // EnsureUniqueDQForNodeUnit will duplicate first DQ, so expect one more (3) - test_case(2, 2, 2); - test_case(2, 2, 2, use_contrib_qdq); - test_case(3, 1, 1); - test_case(3, 1, 1, use_contrib_qdq); + RunDoubleQDQWithoutLastNodeBeingOutput(0, 2, 2); + RunDoubleQDQWithoutLastNodeBeingOutput(0, 2, 2, use_contrib_qdq); + RunDoubleQDQWithoutLastNodeBeingOutput(0, 2, 2, use_contrib_qdq); + RunDoubleQDQWithoutLastNodeBeingOutput(0, 2, 2, use_contrib_qdq); + + // EnsureUniqueDQForNodeUnit will duplicate first DQ, so expected one more (3) + RunDoubleQDQWithoutLastNodeBeingOutput(1, 2, 3); + RunDoubleQDQWithoutLastNodeBeingOutput(1, 2, 3, use_contrib_qdq); + RunDoubleQDQWithoutLastNodeBeingOutput(1, 2, 3, use_contrib_qdq); + RunDoubleQDQWithoutLastNodeBeingOutput(1, 2, 3, use_contrib_qdq); + + RunDoubleQDQWithoutLastNodeBeingOutput(2, 2, 2); + RunDoubleQDQWithoutLastNodeBeingOutput(2, 2, 2, use_contrib_qdq); + RunDoubleQDQWithoutLastNodeBeingOutput(2, 2, 2, use_contrib_qdq); + + RunDoubleQDQWithoutLastNodeBeingOutput(3, 1, 1); + RunDoubleQDQWithoutLastNodeBeingOutput(3, 1, 1, use_contrib_qdq); + RunDoubleQDQWithoutLastNodeBeingOutput(3, 1, 1, use_contrib_qdq); +} + +// Runs a test that checks if DQ -> Split -> Q (many) is replaced with just Split. 
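The Gather/Reshape/Squeeze/Unsqueeze tests above sweep each helper over the supported quantization types through its template argument, one explicit call per type. A hypothetical alternative, sketched here only for comparison, would be a GoogleTest typed test suite; the PR itself keeps the explicit per-type calls:

#include <cstdint>
#include "gtest/gtest.h"

// Hypothetical structure: drive the same drop-QDQ check over all quantization types
// with a typed test suite instead of one call per type.
template <typename QuantType>
class DropQDQTypedTest : public ::testing::Test {};

using QuantTypes = ::testing::Types<int8_t, uint8_t, int16_t, uint16_t>;
TYPED_TEST_SUITE(DropQDQTypedTest, QuantTypes);

TYPED_TEST(DropQDQTypedTest, SqueezeAndUnsqueezeAreDropped) {
  // In the real tests this body would call the templated helper, e.g.
  // RunSqueezeUnsqueezeDropQDQTestCase<TypeParam>("Squeeze", {1, 3, 2, 2}, {0}, /*use_contrib_qdq*/ true);
  SUCCEED();  // placeholder body for the sketch
}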
+template +static void RunDropSplitQDQTestCase(const std::vector& input_shape, int64_t axis, + bool use_contrib_qdq = false) { + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["Split"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + TransformerTester(BuildQDQSplitTestCase(input_shape, axis, use_contrib_qdq), + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + {12, 18, 19}); } -// Because split isn't one the supported ops, this will stay the same +// Test that DQ -> Split -> Q (many) is replaced with just Split for various quantization types. TEST(QDQTransformerTests, Split) { - auto test_case = [&](const std::vector& input_shape, const int64_t& axis, - bool use_contrib_qdq = false) { - auto check_graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - EXPECT_EQ(op_to_count["Split"], 1); - EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); - }; - TransformerTester(BuildQDQSplitTestCase(input_shape, axis, use_contrib_qdq), - check_graph, - TransformerLevel::Level1, - TransformerLevel::Level2, - {12, 18, 19}); - }; - test_case({6, 18, 54}, 0); - test_case({6, 18, 54}, 0, true); // Use com.microsoft QDQ ops + RunDropSplitQDQTestCase({6, 18, 54}, 0); + RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft int8 QDQ ops + RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft int16 QDQ ops + RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft uint16 QDQ ops } // Because split isn't one the supported ops, this will stay the same @@ -1174,59 +1293,66 @@ TEST(QDQTransformerTests, Where) { test_case({1}, {1}, {1}, true /*use_contrib_qdq*/); } -TEST(QDQTransformerTests, Transpose) { - auto test_case = [&](const std::vector& input_shape, const std::vector& perms, - bool use_contrib_qdq = false) { - auto check_graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - EXPECT_EQ(op_to_count["Transpose"], 1); - EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); - }; - - TransformerTester(BuildQDQTransposeTestCase(input_shape, perms, use_contrib_qdq), - check_graph, - TransformerLevel::Level1, - TransformerLevel::Level2); +template +static void RunDropQDQTransposeTestCase(const std::vector& input_shape, const std::vector& perms, + bool use_contrib_qdq = false) { + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["Transpose"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); }; - test_case({2, 13, 12, 37}, {0, 3, 1, 2}); - test_case({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); + TransformerTester(BuildQDQTransposeTestCase(input_shape, perms, use_contrib_qdq), + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2); } -TEST(QDQTransformerTests, Transpose_No_Fusion) { - auto test_case = [&](const std::vector& input1_shape, const 
std::vector& perms, - bool use_contrib_qdq = false) { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto* input1_arg = builder.MakeInput(input1_shape, -128, 127); - auto* output_arg = builder.MakeOutput(); - - // add DQ - auto* dq_output = builder.MakeIntermediate(); - builder.AddDequantizeLinearNode(input1_arg, .003f, 1, dq_output, use_contrib_qdq); - - // add Transpose - auto* transpose_output = builder.MakeOutput(); // transpose output is graph output - Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {transpose_output}); - transpose_node.AddAttribute("perm", perms); - - // add Q - builder.AddQuantizeLinearNode(transpose_output, .003f, 1, output_arg, use_contrib_qdq); - }; - - auto check_graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); - }; +TEST(QDQTransformerTests, Transpose) { + RunDropQDQTransposeTestCase({2, 13, 12, 37}, {0, 3, 1, 2}); + RunDropQDQTransposeTestCase({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); + RunDropQDQTransposeTestCase({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); + RunDropQDQTransposeTestCase({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); +} + +template +static void RunQDQTransposeNoFusionTestCase(const std::vector& input1_shape, const std::vector& perms, + bool use_contrib_qdq = false) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, std::numeric_limits::min(), + std::numeric_limits::max()); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input1_arg, .003f, 1, dq_output, use_contrib_qdq); + + // add Transpose + auto* transpose_output = builder.MakeOutput(); // transpose output is graph output + Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {transpose_output}); + transpose_node.AddAttribute("perm", perms); + + // add Q + builder.AddQuantizeLinearNode(transpose_output, .003f, 1, output_arg, use_contrib_qdq); + }; - TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); }; - test_case({2, 13, 12, 37}, {0, 3, 1, 2}); - test_case({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); +} + +TEST(QDQTransformerTests, Transpose_No_Fusion) { + RunQDQTransposeNoFusionTestCase({2, 13, 12, 37}, {0, 3, 1, 2}); + RunQDQTransposeNoFusionTestCase({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); + RunQDQTransposeNoFusionTestCase({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); + RunQDQTransposeNoFusionTestCase({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); } TEST(QDQTransformerTests, Resize) { @@ -1376,50 +1502,59 @@ TEST(QDQTransformerTests, ResizeReshapeSqueezeUnsqueeze) { test_case({1, 2, 26, 42}, {4}, true /*use_contrib_qdq*/); } -TEST(QDQTransformerTests, ArgMax) { - auto test_case = [&](const std::vector& input_shape, - int axis, - int keepdims, - 
int select_last_index, - bool use_contrib_qdq) { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto* input_arg = builder.MakeInput(input_shape, - std::numeric_limits::min(), - std::numeric_limits::max()); - auto* output_arg = builder.MakeOutput(); +// Runs a test case that checks if the DQ node is dropped from DQ -> Op (e.g., ArgMax). +template +static void RunArgMaxDropDQTestCase(const std::vector& input_shape, + int axis, + int keepdims, + int select_last_index, + bool use_contrib_qdq, + bool expect_drop_dq = true) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_arg, .003f, 1, dq_output, use_contrib_qdq); + + // add ArgMax + Node& argmax_node = builder.AddNode("ArgMax", {dq_output}, {output_arg}); + argmax_node.AddAttribute("axis", static_cast(axis)); + argmax_node.AddAttribute("keepdims", static_cast(keepdims)); + argmax_node.AddAttribute("select_last_index", static_cast(select_last_index)); + }; - // add DQ - auto* dq_output = builder.MakeIntermediate(); - builder.AddDequantizeLinearNode(input_arg, .003f, 1, dq_output, use_contrib_qdq); + auto check_graph = [use_contrib_qdq, expect_drop_dq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["ArgMax"], 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expect_drop_dq ? 0 : 1); + }; - // add ArgMax - Node& argmax_node = builder.AddNode("ArgMax", {dq_output}, {output_arg}); - argmax_node.AddAttribute("axis", static_cast(axis)); - argmax_node.AddAttribute("keepdims", static_cast(keepdims)); - argmax_node.AddAttribute("select_last_index", static_cast(select_last_index)); - }; + TransformerTester(build_test_case, check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + /* opset_version */ 13); + TransformerTester(build_test_case, check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + /* opset_version */ 19); +} - auto check_graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - EXPECT_EQ(op_to_count["ArgMax"], 1); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); - }; +// Checks that the DQ node is dropped from DQ -> ArgMax. Uses 8-bit and 16-bit Q/DQ ops. +TEST(QDQTransformerTests, ArgMax) { + RunArgMaxDropDQTestCase({2, 13, 12, 37}, 1, 0, 0, false); + RunArgMaxDropDQTestCase({2, 13, 12, 37}, 1, 0, 0, true /*use_contrib_qdq*/); - TransformerTester(build_test_case, check_graph, - TransformerLevel::Level1, - TransformerLevel::Level2, - /* opset_version */ 13); - TransformerTester(build_test_case, check_graph, - TransformerLevel::Level1, - TransformerLevel::Level2, - /* opset_version */ 19); - }; + // Should *not* drop DQ for 16-bit DQ -> ArgMax (because ORT does not support 16-bit input types for ArgMax). 
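Dropping the DQ in DQ -> ArgMax is valid because dequantization, x = (q - zero_point) * scale, is strictly increasing in q whenever scale > 0, so the index of the maximum is unchanged; the 16-bit cases below keep the DQ only because ORT's ArgMax kernel does not accept 16-bit inputs. A small self-contained check of that invariance (illustrative values):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <vector>

int main() {
  const float scale = 0.003f;
  const uint16_t zero_point = 32768;
  const std::vector<uint16_t> q = {12, 40000, 655, 65535, 901};

  // Dequantize each element: strictly increasing in q because scale > 0.
  std::vector<float> x(q.size());
  for (size_t i = 0; i < q.size(); ++i) {
    x[i] = (static_cast<float>(q[i]) - static_cast<float>(zero_point)) * scale;
  }

  const auto argmax_q = std::distance(q.begin(), std::max_element(q.begin(), q.end()));
  const auto argmax_x = std::distance(x.begin(), std::max_element(x.begin(), x.end()));
  assert(argmax_q == argmax_x);  // ArgMax on quantized data matches ArgMax after DQ
  return 0;
}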
+ RunArgMaxDropDQTestCase({2, 13, 12, 37}, 1, 0, 0, true /*use_contrib_qdq*/, false /*expect_drop_dq*/); + RunArgMaxDropDQTestCase({2, 13, 12, 37}, 1, 0, 0, true /*use_contrib_qdq*/, false /*expect_drop_dq*/); - test_case({2, 13, 12, 37}, 1, 0, 0, false /*use_contrib_qdq*/); - test_case({2, 13, 12, 37}, 1, 0, 0, true /*use_contrib_qdq*/); - test_case({2, 13, 12, 37}, 0, 1, 0, false /*use_contrib_qdq*/); - test_case({2, 13, 12, 37}, 0, 0, 1, false /*use_contrib_qdq*/); + RunArgMaxDropDQTestCase({2, 13, 12, 37}, 0, 1, 0, false); + RunArgMaxDropDQTestCase({2, 13, 12, 37}, 0, 0, 1, false); } TEST(QDQTransformerTests, QLinearMatMul) { diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index e5aa36fc379f4..1f4c499985ad0 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include @@ -8,6 +9,7 @@ #include "gmock/gmock.h" #include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "core/framework/op_node_proto_helper.h" #include "core/framework/utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -3501,150 +3503,116 @@ TEST(TransposeOptimizerTests, TestWhere) { /*opset_version*/ {15, 18}); } -TEST(TransposeOptimizerTests, TestQuantizeLinearScalar) { - auto test_case = [&](const std::string& q_domain = "") { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0.0, 1.0); - auto* input1_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {2.3f}); - auto* input2_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {10}); - auto* transpose_1_out_0 = builder.MakeIntermediate(); - auto* quantizelinear_1_out_0 = builder.MakeIntermediate(); - auto* transpose_2_out_0 = builder.MakeOutput(); - - auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); - transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - builder.AddNode("QuantizeLinear", {transpose_1_out_0, input1_arg, input2_arg}, {quantizelinear_1_out_0}, - q_domain); - auto& transpose_2 = builder.AddNode("Transpose", {quantizelinear_1_out_0}, {transpose_2_out_0}); - transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); - }; +// Utility function that runs TransformerTester for the graph Transpose -> QuantizeLinear -> Transpose. +// Expects the Tranpose nodes to cancel. 
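The transpose-optimizer tests in this hunk place a per-tensor QuantizeLinear between a Transpose with perm {0, 3, 1, 2} and a Transpose with perm {0, 2, 3, 1}. Per-tensor Q/DQ is elementwise and those two permutations are inverses, which is why the optimized graphs are expected to reach a transpose cost of 0. A quick check that the permutations compose to the identity:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// For Transpose, output axis i reads input axis perm[i]; applying perm1 then perm2
// is equivalent to the single permutation composed[i] = perm1[perm2[i]].
int main() {
  const std::vector<int64_t> perm1 = {0, 3, 1, 2};  // first Transpose in the test graphs
  const std::vector<int64_t> perm2 = {0, 2, 3, 1};  // second Transpose in the test graphs
  for (size_t i = 0; i < perm1.size(); ++i) {
    assert(perm1[static_cast<size_t>(perm2[i])] == static_cast<int64_t>(i));  // identity => the pair cancels
  }
  return 0;
}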
+template +static void RunQuantizeLinearTestCase(const std::optional>& zp_input_shape, + const std::vector& zp_value_shape, + std::optional axis, + const std::string& q_domain = "") { + auto build_test_case = [&](ModelTestBuilder& builder) { + constexpr QuantType qmin = std::numeric_limits::min(); + constexpr QuantType qmax = std::numeric_limits::max(); - auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { - int transpose_cost = EstimateTransposeCost(session.GetGraph()); - EXPECT_EQ(transpose_cost, 0); - }; + auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0.0, 1.0); + + NodeArg* scale_arg = nullptr; + NodeArg* zero_point_arg = nullptr; + + if (zp_value_shape.empty()) { // Per-tensor quantization + QuantType zp = (qmax + qmin) / 2; + scale_arg = MakeInput(builder, zp_input_shape, zp_value_shape, {0.05f}); + zero_point_arg = MakeInput(builder, zp_input_shape, zp_value_shape, {zp}); + } else { // Per-axis quantization + scale_arg = MakeInput(builder, zp_input_shape, zp_value_shape, 0.0f, 1.0f); + zero_point_arg = MakeInput(builder, zp_input_shape, zp_value_shape, qmin, qmax); + } + auto* transpose_1_out_0 = builder.MakeIntermediate(); + auto* quantizelinear_1_out_0 = builder.MakeIntermediate(); + auto* transpose_2_out_0 = builder.MakeOutput(); + + auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); + transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); + auto& quantizelinear_1 = builder.AddNode("QuantizeLinear", {transpose_1_out_0, scale_arg, zero_point_arg}, + {quantizelinear_1_out_0}, q_domain); - TransformerTester(build_test_case_1, - check_optimized_graph_1, - TransformerLevel::Default, - TransformerLevel::Level1, - /*opset_version*/ {15, 18}); + if (axis.has_value()) { + quantizelinear_1.AddAttributeProto(*axis); + } + + auto& transpose_2 = builder.AddNode("Transpose", {quantizelinear_1_out_0}, {transpose_2_out_0}); + transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); + }; + + auto check_optimized_graph = [](InferenceSessionWrapper& session) { + int transpose_cost = EstimateTransposeCost(session.GetGraph()); + EXPECT_EQ(transpose_cost, 0); }; - test_case(); + TransformerTester(build_test_case, + check_optimized_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + /*opset_version*/ {15, 18}); +} + +TEST(TransposeOptimizerTests, TestQuantizeLinearScalar) { + std::optional> zp_input_shape = std::vector{}; + std::vector zp_value_shape{}; + std::optional empty_axis; // No axis value. + + RunQuantizeLinearTestCase(zp_input_shape, zp_value_shape, empty_axis, kOnnxDomain); + #if !defined(DISABLE_CONTRIB_OPS) - test_case(kMSDomain); // Use com.microsoft.QuantizeLinear + // Use com.microsoft.QuantizeLinear op. 
+ RunQuantizeLinearTestCase(zp_input_shape, zp_value_shape, empty_axis, kMSDomain); + RunQuantizeLinearTestCase(zp_input_shape, zp_value_shape, empty_axis, kMSDomain); + RunQuantizeLinearTestCase(zp_input_shape, zp_value_shape, empty_axis, kMSDomain); #endif } TEST(TransposeOptimizerTests, TestQuantizeLinearScalarIgnoreAxis) { - auto test_case = [&](const std::string& q_domain = "") { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0.0, 1.0); - auto* input1_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {2.3f}); - auto* input2_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {10}); - auto* transpose_1_out_0 = builder.MakeIntermediate(); - auto* quantizelinear_1_out_0 = builder.MakeIntermediate(); - auto* transpose_2_out_0 = builder.MakeOutput(); - - auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); - transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - auto& quantizelinear_1 = builder.AddNode("QuantizeLinear", {transpose_1_out_0, input1_arg, input2_arg}, - {quantizelinear_1_out_0}, q_domain); - quantizelinear_1.AddAttribute("axis", (int64_t)10); - auto& transpose_2 = builder.AddNode("Transpose", {quantizelinear_1_out_0}, {transpose_2_out_0}); - transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); - }; - - auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { - int transpose_cost = EstimateTransposeCost(session.GetGraph()); - EXPECT_EQ(transpose_cost, 0); - }; + std::optional> zp_input_shape = std::vector{}; + std::vector zp_value_shape{}; + auto ignored_axis = utils::MakeAttribute("axis", static_cast(10)); // Should be ignored for per-tensor Q - TransformerTester(build_test_case_1, - check_optimized_graph_1, - TransformerLevel::Default, - TransformerLevel::Level1, - /*opset_version*/ {15, 18}); - }; + RunQuantizeLinearTestCase(zp_input_shape, zp_value_shape, ignored_axis, kOnnxDomain); - test_case(); #if !defined(DISABLE_CONTRIB_OPS) - test_case(kMSDomain); // Use com.microsoft.QuantizeLinear + // Use com.microsoft.QuantizeLinear op. 
+ RunQuantizeLinearTestCase(zp_input_shape, zp_value_shape, ignored_axis, kMSDomain); + RunQuantizeLinearTestCase(zp_input_shape, zp_value_shape, ignored_axis, kMSDomain); + RunQuantizeLinearTestCase(zp_input_shape, zp_value_shape, ignored_axis, kMSDomain); #endif } TEST(TransposeOptimizerTests, TestQuantizeLinearVector) { - auto test_case = [&](const std::string& q_domain = "") { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0.0, 1.0); - auto* input1_arg = MakeInput(builder, {{-1}}, {2}, {2.3f, 2.4f}); - auto* input2_arg = MakeInput(builder, {{-1}}, {2}, {10, 12}); - auto* transpose_1_out_0 = builder.MakeIntermediate(); - auto* quantizelinear_1_out_0 = builder.MakeIntermediate(); - auto* transpose_2_out_0 = builder.MakeOutput(); - - auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); - transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - auto& quantizelinear_1 = builder.AddNode("QuantizeLinear", {transpose_1_out_0, input1_arg, input2_arg}, - {quantizelinear_1_out_0}, q_domain); - quantizelinear_1.AddAttribute("axis", (int64_t)0); - auto& transpose_2 = builder.AddNode("Transpose", {quantizelinear_1_out_0}, {transpose_2_out_0}); - transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); - }; + std::optional> zp_input_shape = std::vector{-1}; + std::vector zp_value_shape = {2}; + auto axis = utils::MakeAttribute("axis", static_cast(0)); - auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { - int transpose_cost = EstimateTransposeCost(session.GetGraph()); - EXPECT_EQ(transpose_cost, 0); - }; + RunQuantizeLinearTestCase(zp_input_shape, zp_value_shape, axis, kOnnxDomain); - TransformerTester(build_test_case_1, - check_optimized_graph_1, - TransformerLevel::Default, - TransformerLevel::Level1, - /*opset_version*/ {15, 18}); - }; - - test_case(); #if !defined(DISABLE_CONTRIB_OPS) - test_case(kMSDomain); // Use com.microsoft.QuantizeLinear + // Use com.microsoft.QuantizeLinear op. 
+ RunQuantizeLinearTestCase(zp_input_shape, zp_value_shape, axis, kMSDomain); + RunQuantizeLinearTestCase(zp_input_shape, zp_value_shape, axis, kMSDomain); + RunQuantizeLinearTestCase(zp_input_shape, zp_value_shape, axis, kMSDomain); #endif } TEST(TransposeOptimizerTests, TestQuantizeLinearVectorUnknownRank) { - auto test_case = [&](const std::string& q_domain = "") { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0.0, 1.0); - auto* input1_arg = MakeInput(builder, std::nullopt, {3}, {2.3f, 2.4f, 2.5f}); - auto* input2_arg = MakeInput(builder, std::nullopt, {3}, {10, 12, 13}); - auto* transpose_1_out_0 = builder.MakeIntermediate(); - auto* quantizelinear_1_out_0 = builder.MakeIntermediate(); - auto* transpose_2_out_0 = builder.MakeOutput(); - - auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); - transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - auto& quantizelinear_1 = builder.AddNode("QuantizeLinear", {transpose_1_out_0, input1_arg, input2_arg}, - {quantizelinear_1_out_0}, q_domain); - quantizelinear_1.AddAttribute("axis", (int64_t)1); - auto& transpose_2 = builder.AddNode("Transpose", {quantizelinear_1_out_0}, {transpose_2_out_0}); - transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); - }; + std::optional> zp_unknown_shape; // Empty shape + std::vector zp_value_shape = {3}; + auto axis = utils::MakeAttribute("axis", static_cast(1)); - auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { - int transpose_cost = EstimateTransposeCost(session.GetGraph()); - EXPECT_EQ(transpose_cost, 0); - }; + RunQuantizeLinearTestCase(zp_unknown_shape, zp_value_shape, axis, kOnnxDomain); - TransformerTester(build_test_case_1, - check_optimized_graph_1, - TransformerLevel::Default, - TransformerLevel::Level1, - /*opset_version*/ {15, 18}); - }; - - test_case(); #if !defined(DISABLE_CONTRIB_OPS) - test_case(kMSDomain); // Use com.microsoft.QuantizeLinear + // Use com.microsoft.QuantizeLinear op. 
+ RunQuantizeLinearTestCase(zp_unknown_shape, zp_value_shape, axis, kMSDomain); + RunQuantizeLinearTestCase(zp_unknown_shape, zp_value_shape, axis, kMSDomain); + RunQuantizeLinearTestCase(zp_unknown_shape, zp_value_shape, axis, kMSDomain); #endif } @@ -3676,158 +3644,158 @@ TEST(TransposeOptimizerTests, TestQuantizeLinearScalarOpset10) { /*opset_version*/ 10); } -TEST(TransposeOptimizerTests, TestDequantizeLinearScalarIgnoreAxis) { - auto test_case = [&](const std::string& dq_domain = "") { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0, 5); - auto* input1_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {2.3f}); - auto* input2_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {10}); - auto* transpose_1_out_0 = builder.MakeIntermediate(); - auto* dequantizelinear_1_out_0 = builder.MakeIntermediate(); - auto* transpose_2_out_0 = builder.MakeOutput(); - - auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); - transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - auto& dequantizelinear_1 = builder.AddNode("DequantizeLinear", {transpose_1_out_0, input1_arg, input2_arg}, - {dequantizelinear_1_out_0}, dq_domain); - dequantizelinear_1.AddAttribute("axis", (int64_t)10); - auto& transpose_2 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_2_out_0}); - transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); - }; +// Utility function that runs TransformerTester for the graph Transpose -> DequantizeLinear -> Transpose. +// Expects the Tranpose nodes to cancel. +template +static void RunDequantizeLinearTestCase(const std::optional>& zp_input_shape, + const std::vector& zp_value_shape, + std::optional axis, + const std::string& q_domain = "") { + auto build_test_case = [&](ModelTestBuilder& builder) { + constexpr QuantType qmin = std::numeric_limits::min(); + constexpr QuantType qmax = std::numeric_limits::max(); - auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { - int transpose_cost = EstimateTransposeCost(session.GetGraph()); - EXPECT_EQ(transpose_cost, 0); - }; + auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, qmin, qmax); + + NodeArg* scale_arg = nullptr; + NodeArg* zero_point_arg = nullptr; + + if (zp_value_shape.empty()) { // Per-tensor quantization + QuantType zp = (qmax + qmin) / 2; + scale_arg = MakeInput(builder, zp_input_shape, zp_value_shape, {0.05f}); + zero_point_arg = MakeInput(builder, zp_input_shape, zp_value_shape, {zp}); + } else { // Per-axis quantization + scale_arg = MakeInput(builder, zp_input_shape, zp_value_shape, 0.0f, 1.0f); + zero_point_arg = MakeInput(builder, zp_input_shape, zp_value_shape, qmin, qmax); + } + auto* transpose_1_out_0 = builder.MakeIntermediate(); + auto* dequantizelinear_1_out_0 = builder.MakeIntermediate(); + auto* transpose_2_out_0 = builder.MakeOutput(); + + auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); + transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); + auto& dequantizelinear_1 = builder.AddNode("DequantizeLinear", {transpose_1_out_0, scale_arg, zero_point_arg}, + {dequantizelinear_1_out_0}, q_domain); + + if (axis.has_value()) { + dequantizelinear_1.AddAttributeProto(*axis); + } + + auto& transpose_2 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_2_out_0}); + transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); + }; - TransformerTester(build_test_case_1, - 
check_optimized_graph_1, - TransformerLevel::Default, - TransformerLevel::Level1, - /*opset_version*/ {15, 18}); + auto check_optimized_graph = [](InferenceSessionWrapper& session) { + int transpose_cost = EstimateTransposeCost(session.GetGraph()); + EXPECT_EQ(transpose_cost, 0); }; - test_case(); + TransformerTester(build_test_case, + check_optimized_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + /*opset_version*/ {15, 18}); +} + +TEST(TransposeOptimizerTests, TestDequantizeLinearScalarIgnoreAxis) { + std::optional> zp_input_shape = std::vector{}; + std::vector zp_value_shape{}; + auto ignored_axis = utils::MakeAttribute("axis", static_cast(10)); // Should be ignored for per-tensor Q + + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, ignored_axis, kOnnxDomain); #if !defined(DISABLE_CONTRIB_OPS) - test_case(kMSDomain); // Use com.microsoft.DequantizeLinear + // Use com.microsoft.DequantizeLinear ops + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, ignored_axis, kMSDomain); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, ignored_axis, kMSDomain); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, ignored_axis, kMSDomain); #endif } TEST(TransposeOptimizerTests, TestDequantizeLinearVector) { - auto test_case = [&](const std::string& dq_domain = "") { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0, 5); - auto* input1_arg = MakeInput(builder, {{2}}, {2}, {2.3f, 2.4f}); - auto* input2_arg = MakeInput(builder, {{2}}, {2}, {10, 12}); - auto* transpose_1_out_0 = builder.MakeIntermediate(); - auto* dequantizelinear_1_out_0 = builder.MakeIntermediate(); - auto* transpose_2_out_0 = builder.MakeOutput(); - - auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); - transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - auto& dequantizelinear_1 = builder.AddNode("DequantizeLinear", {transpose_1_out_0, input1_arg, input2_arg}, - {dequantizelinear_1_out_0}, dq_domain); - dequantizelinear_1.AddAttribute("axis", (int64_t)-4); - auto& transpose_2 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_2_out_0}); - transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); - }; + std::optional> zp_input_shape = std::vector{2}; + std::vector zp_value_shape = {2}; + auto axis = utils::MakeAttribute("axis", static_cast(-4)); - auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { - int transpose_cost = EstimateTransposeCost(session.GetGraph()); - EXPECT_EQ(transpose_cost, 0); - }; + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, axis, kOnnxDomain); +#if !defined(DISABLE_CONTRIB_OPS) + // Use com.microsoft.DequantizeLinear ops + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, axis, kMSDomain); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, axis, kMSDomain); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, axis, kMSDomain); +#endif +} - TransformerTester(build_test_case_1, - check_optimized_graph_1, - TransformerLevel::Default, - TransformerLevel::Level1, - /*opset_version*/ {15, 18}); - }; +TEST(TransposeOptimizerTests, TestDequantizeLinearNoAxis) { + std::optional> zp_input_shape = std::vector{}; + std::vector zp_value_shape{}; + std::optional no_axis; // Empty axis value will not be set. 
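The scalar, vector, and no-axis DequantizeLinear variants in these tests differ only in whether scale and zero_point are scalars (per-tensor) or 1-D tensors applied along an axis (per-axis). A small sketch of per-axis dequantization on a 2-D tensor, with made-up values:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Per-axis DequantizeLinear sketch: row r of a [rows, cols] tensor uses scale[r] and zero_point[r],
// i.e. x = (q - zero_point[axis_index]) * scale[axis_index].
int main() {
  const size_t rows = 2, cols = 3;
  const std::vector<uint16_t> q = {100, 200, 300, 400, 500, 600};
  const std::vector<float> scale = {0.01f, 0.02f};      // one scale per slice along axis 0
  const std::vector<uint16_t> zero_point = {128, 256};  // one zero point per slice along axis 0

  std::vector<float> x(q.size());
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < cols; ++c) {
      const size_t idx = r * cols + c;
      x[idx] = (static_cast<float>(q[idx]) - static_cast<float>(zero_point[r])) * scale[r];
    }
  }
  std::printf("x[0]=%.2f x[3]=%.2f\n", x[0], x[3]);  // (100-128)*0.01 = -0.28, (400-256)*0.02 = 2.88
  return 0;
}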
- test_case(); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, no_axis, kOnnxDomain); #if !defined(DISABLE_CONTRIB_OPS) - test_case(kMSDomain); // Use com.microsoft.DequantizeLinear + // Use com.microsoft.DequantizeLinear ops + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, no_axis, kMSDomain); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, no_axis, kMSDomain); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, no_axis, kMSDomain); #endif } -TEST(TransposeOptimizerTests, TestDequantizeLinearNoAxis) { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0, 5); - auto* input1_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {2.3f}); - auto* input2_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {10}); - auto* transpose_1_out_0 = builder.MakeIntermediate(); +// Utility function that runs TransformerTester for the graph in which a single DequantizeLinear node is +// the parent of two Transpose nodes. The DQ should be duplicated by EnsureUniqueDQForNodeUnit, and the +// Transposes should be pushed. +template +static void RunDequantizeLinearTransposePropagationTestCase(const std::string& dq_domain = "") { + auto build_test_case = [dq_domain](ModelTestBuilder& builder) { + auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0, 5); + auto* scale_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {2.3f}); + auto* zero_point_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {10}); auto* dequantizelinear_1_out_0 = builder.MakeIntermediate(); + auto* transpose_1_out_0 = builder.MakeOutput(); auto* transpose_2_out_0 = builder.MakeOutput(); - auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); + builder.AddNode("DequantizeLinear", {input0_arg, scale_arg, zero_point_arg}, {dequantizelinear_1_out_0}, + dq_domain); + + auto& transpose_1 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_1_out_0}); transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - builder.AddNode("DequantizeLinear", {transpose_1_out_0, input1_arg, input2_arg}, {dequantizelinear_1_out_0}); + auto& transpose_2 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_2_out_0}); transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); }; - auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { - int transpose_cost = EstimateTransposeCost(session.GetGraph()); - EXPECT_EQ(transpose_cost, 0); + auto check_graph = [dq_domain](InferenceSessionWrapper& session) { + const auto& graph = session.GetGraph(); + + const char* dq_count_key = (dq_domain == kMSDomain) ? 
"com.microsoft.DequantizeLinear" : "DequantizeLinear"; + const auto op_count = CountOpsInGraph(graph); + decltype(op_count) expected_op_count{ + {dq_count_key, 2}, // EnsureUniqueDQForNodeUnit should duplicate the original DQ + {"Transpose", 2}, + }; + ASSERT_EQ(op_count, expected_op_count); + + // Transposes should be pushed, so check for Transpose -> DQ edges + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "Transpose") { + ASSERT_EQ(node.GetOutputEdgesCount(), static_cast(1)); + ASSERT_EQ(node.OutputEdgesBegin()->GetNode().OpType(), "DequantizeLinear"); + } + } }; - TransformerTester(build_test_case_1, - check_optimized_graph_1, + TransformerTester(build_test_case, + check_graph, TransformerLevel::Default, TransformerLevel::Level1, /*opset_version*/ 10); } TEST(TransposeOptimizerTests, TestDequantizeLinearTransposePropagation) { - auto test_case = [&](const std::string& dq_domain = "") { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0, 5); - auto* input1_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {2.3f}); - auto* input2_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {10}); - auto* dequantizelinear_1_out_0 = builder.MakeIntermediate(); - auto* transpose_1_out_0 = builder.MakeOutput(); - auto* transpose_2_out_0 = builder.MakeOutput(); - - builder.AddNode("DequantizeLinear", {input0_arg, input1_arg, input2_arg}, {dequantizelinear_1_out_0}, - dq_domain); - - auto& transpose_1 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_1_out_0}); - transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - - auto& transpose_2 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_2_out_0}); - transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); - }; - - auto check_graph = [&](InferenceSessionWrapper& session) { - const auto& graph = session.GetGraph(); - - const char* dq_count_key = (dq_domain == kMSDomain) ? 
"com.microsoft.DequantizeLinear" : "DequantizeLinear"; - const auto op_count = CountOpsInGraph(graph); - decltype(op_count) expected_op_count{ - {dq_count_key, 2}, // EnsureUniqueDQForNodeUnit should duplicate the original DQ - {"Transpose", 2}, - }; - ASSERT_EQ(op_count, expected_op_count); - - // Transposes should be pushed, so check for Transpose -> DQ edges - for (const auto& node : graph.Nodes()) { - if (node.OpType() == "Transpose") { - ASSERT_EQ(node.GetOutputEdgesCount(), static_cast(1)); - ASSERT_EQ(node.OutputEdgesBegin()->GetNode().OpType(), "DequantizeLinear"); - } - } - }; - - TransformerTester(build_test_case_1, - check_graph, - TransformerLevel::Default, - TransformerLevel::Level1, - /*opset_version*/ 10); - }; - - test_case(); + RunDequantizeLinearTransposePropagationTestCase(); #if !defined(DISABLE_CONTRIB_OPS) - test_case(kMSDomain); // Use com.microsoft.DequantizeLinear + // Use com.microsoft.DequantizeLinear + RunDequantizeLinearTransposePropagationTestCase(kMSDomain); + RunDequantizeLinearTransposePropagationTestCase(kMSDomain); + RunDequantizeLinearTransposePropagationTestCase(kMSDomain); #endif } diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index e9e285411f0a7..0549051bc2387 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -21,7 +21,8 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons const std::vector& pads, const std::vector& dilations, const std::string& auto_pad = "NOTSET") { - return [conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, auto_pad](ModelTestBuilder& builder) { + return [conv_op_type, input_def, weights_def, bias_def, strides, pads, + dilations, auto_pad](ModelTestBuilder& builder) { std::vector conv_inputs = { MakeTestInput(builder, input_def), MakeTestInput(builder, weights_def)}; @@ -77,29 +78,33 @@ static void RunCPUConvOpTest(const std::string& conv_op_type, const TestInputDef } // Creates a graph with a single Q/DQ Conv operator. Used for testing HTP backend. 
-template -static GetTestQDQModelFn BuildQDQConvTestCase(const std::string& conv_op_type, const TestInputDef& input_def, - const TestInputDef& weights_def, - const TestInputDef& bias_def, - const std::vector& strides, - const std::vector& pads, - const std::vector& dilations, - const std::string& auto_pad = "NOTSET") { +template +static GetTestQDQModelFn BuildQDQConvTestCase(const std::string& conv_op_type, + const TestInputDef& input_def, + const TestInputDef& weights_def, + const TestInputDef& bias_def, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + const std::string& auto_pad = "NOTSET", + bool use_contrib_qdq = false) { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, - dilations, auto_pad](ModelTestBuilder& builder, - std::vector>& output_qparams) { + dilations, auto_pad, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { std::vector conv_inputs; // input -> Q/DQ -> auto* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - auto* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + auto* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); conv_inputs.push_back(input_qdq); // weights -> Q/DQ -> auto* weights = MakeTestInput(builder, weights_def); - QuantParams weights_qparams = GetTestInputQuantParams(weights_def); - auto* weights_qdq = AddQDQNodePair(builder, weights, weights_qparams.scale, weights_qparams.zero_point); + QuantParams weights_qparams = GetTestInputQuantParams(weights_def); + auto* weights_qdq = AddQDQNodePair(builder, weights, weights_qparams.scale, + weights_qparams.zero_point, use_contrib_qdq); conv_inputs.push_back(weights_qdq); // bias -> @@ -107,7 +112,7 @@ static GetTestQDQModelFn BuildQDQConvTestCase(const std::string& con // Bias requirement taken from python quantization tool: onnx_quantizer.py::quantize_bias_static() const float bias_scale = input_qparams.scale * weights_qparams.scale; - conv_inputs.push_back(MakeTestQDQBiasInput(builder, bias_def, bias_scale)); + conv_inputs.push_back(MakeTestQDQBiasInput(builder, bias_def, bias_scale, use_contrib_qdq)); } auto* conv_output = builder.MakeIntermediate(); @@ -125,13 +130,14 @@ static GetTestQDQModelFn BuildQDQConvTestCase(const std::string& con conv_node.AddAttribute("dilations", dilations); } - AddQDQNodePairWithOutputAsGraphOutput(builder, conv_output, output_qparams[0].scale, output_qparams[0].zero_point); + AddQDQNodePairWithOutputAsGraphOutput(builder, conv_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); }; } // Runs a Conv model on the QNN HTP backend. Checks the graph node assignment, and that inference // outputs for QNN EP and CPU EP match. 
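BuildQDQConvTestCase above quantizes the bias with bias_scale = input_scale * weight_scale, the requirement it cites from the Python quantization tool, so the int32 bias lands on the same scale as the quantized convolution's int32 accumulator. A minimal sketch of that bias quantization, with made-up scales:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Static bias quantization sketch: bias is stored as int32 with zero_point 0 and
// scale = input_scale * weight_scale, so q_bias * bias_scale reproduces the float bias.
int main() {
  const float input_scale = 0.02f;
  const float weight_scale = 0.004f;
  const float bias = 2.0f;  // same magnitude as the bias used by several tests above

  const float bias_scale = input_scale * weight_scale;
  const int32_t q_bias = static_cast<int32_t>(std::lround(bias / bias_scale));

  std::printf("bias_scale=%g q_bias=%d reconstructed=%g\n",
              bias_scale, q_bias, q_bias * bias_scale);
  return 0;
}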
-template +template static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef& input_def, const TestInputDef& weights_def, const TestInputDef& bias_def, @@ -140,6 +146,7 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef const std::vector& dilations, const std::string& auto_pad, ExpectedEPNodeAssignment expected_ep_assignment, + bool use_contrib_qdq = false, int opset = 13, float fp32_abs_err = 1e-5f) { ProviderOptions provider_options; @@ -150,9 +157,11 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, auto_pad), - BuildQDQConvTestCase(conv_op_type, input_def, weights_def, bias_def, - strides, pads, dilations, auto_pad), + TestQDQModelAccuracy(BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, + auto_pad), + BuildQDQConvTestCase(conv_op_type, input_def, weights_def, + bias_def, strides, pads, dilations, + auto_pad, use_contrib_qdq), provider_options, opset, expected_ep_assignment, @@ -279,52 +288,56 @@ TEST_F(QnnCPUBackendTests, Convf32_large_input2_nopad_bias_initializer) { // Test 1D Conv with static weights (implemented in QNN EP as 2D convolution with height of 1). TEST_F(QnnCPUBackendTests, Conv1Df32_StaticWeights_DefaultBias) { + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; RunCPUConvOpTest("Conv", - TestInputDef({1, 2, 4}, false, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}), // Dynamic input - TestInputDef({1, 2, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights - TestInputDef({1}, true, {1.0f}), // Bias of 1.f - {1}, // Strides - {0, 0}, // Pads - {1}, // Dilations + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights + TestInputDef({1}, true, {1.0f}), // Initializer Bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations "NOTSET", ExpectedEPNodeAssignment::All); } // Test 1D Conv with dynamic weights (implemented in QNN EP as 2D convolution with height of 1). TEST_F(QnnCPUBackendTests, Conv1Df32_DynamicWeights_DefaultBias) { + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; RunCPUConvOpTest("Conv", - TestInputDef({1, 2, 4}, false, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}), // Dynamic input - TestInputDef({1, 2, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), // Dynamic weights - TestInputDef(), // Default bias - {1}, // Strides - {0, 0}, // Pads - {1}, // Dilations + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), // Dynamic weights + TestInputDef(), // Default bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations "NOTSET", ExpectedEPNodeAssignment::All); } // Test 1D ConvTranspose with static weights (implemented in QNN EP as 2D convolution with height of 1). 
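Several of the CPU tests above note that QNN EP implements 1-D Conv/ConvTranspose as a 2-D convolution with height 1. A small numeric check of why that is equivalent, using naive single-channel loops and hypothetical data:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive single-channel "valid" 1-D convolution (cross-correlation, as in ONNX Conv).
static std::vector<float> Conv1D(const std::vector<float>& x, const std::vector<float>& w) {
  std::vector<float> y(x.size() - w.size() + 1, 0.0f);
  for (size_t i = 0; i < y.size(); ++i)
    for (size_t k = 0; k < w.size(); ++k) y[i] += x[i + k] * w[k];
  return y;
}

// The same data viewed as a [1, 1, 1, W] image with a [1, 1, 1, kW] kernel: the height
// loops run exactly once, so a 2-D convolution with height 1 produces identical values.
static std::vector<float> Conv2DHeight1(const std::vector<float>& x, const std::vector<float>& w) {
  const size_t out_w = x.size() - w.size() + 1;
  std::vector<float> y(out_w, 0.0f);
  for (size_t oh = 0; oh < 1; ++oh)      // output height == 1
    for (size_t ow = 0; ow < out_w; ++ow)
      for (size_t kh = 0; kh < 1; ++kh)  // kernel height == 1
        for (size_t kw = 0; kw < w.size(); ++kw) y[ow] += x[ow + kw] * w[kw];
  return y;
}

int main() {
  const std::vector<float> x = {0, 1, 2, 3, 4, 5, 6, 7};  // hypothetical single-channel data
  const std::vector<float> w = {1, 2};
  const auto y1 = Conv1D(x, w);
  const auto y2 = Conv2DHeight1(x, w);
  for (size_t i = 0; i < y1.size(); ++i) assert(std::fabs(y1[i] - y2[i]) < 1e-6f);
  return 0;
}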
TEST_F(QnnCPUBackendTests, ConvTranspose1Df32_StaticWeights_DefaultBias) { + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; RunCPUConvOpTest("ConvTranspose", - TestInputDef({1, 2, 4}, false, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}), // Dynamic input - TestInputDef({2, 1, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights - TestInputDef({1}, true, {0.0f}), // Zero bias - {1}, // Strides - {0, 0}, // Pads - {1}, // Dilations + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights + TestInputDef({1}, true, {0.0f}), // Zero bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations "NOTSET", ExpectedEPNodeAssignment::All); } // Test 1D ConvTranspose with dynamic weights (implemented in QNN EP as 2D convolution with height of 1). TEST_F(QnnCPUBackendTests, ConvTranspose1Df32_DynamicWeights_DefaultBias) { + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; RunCPUConvOpTest("ConvTranspose", - TestInputDef({1, 2, 4}, false, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}), // Dynamic input - TestInputDef({2, 1, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), // Dynamic weights - TestInputDef({1}, true, {0.0f}), // Zero bias - {1}, // Strides - {0, 0}, // Pads - {1}, // Dilations + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), // Dynamic weights + TestInputDef({1}, true, {0.0f}), // Zero bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations "NOTSET", ExpectedEPNodeAssignment::All); } @@ -397,218 +410,448 @@ TEST_F(QnnHTPBackendTests, Test_QDQConvWithDynamicWeightsFromMul) { // Check that QNN compiles DQ -> Conv -> Q as a single unit. // Tests bias as a dynamic input. -TEST_F(QnnHTPBackendTests, ConvU8S32_bias_dynamic_input) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 1, 5, 5}, false, 0.0f, 10.0f), // Random dynamic input - TestInputDef({1, 1, 3, 3}, true, -10.0f, 10.0f), // Random static input - TestInputDef({1}, false, {2.0f}), // Dynamic bias = 2.0f - {1, 1}, // Strides - {0, 0, 0, 0}, // Pads - {1, 1}, // Dilations - "NOTSET", - ExpectedEPNodeAssignment::All); +TEST_F(QnnHTPBackendTests, ConvU8U8S32_bias_dynamic_input) { + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({1, 1, 3, 3}, true, -10.0f, 10.0f), // Random static input + TestInputDef({1}, false, {2.0f}), // Dynamic bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All); +} + +// Tests 16-bit QDQ Conv with dynamic weights and bias (uses QNN's Conv2d) +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.0040235077030956745, zero_point=0. 
+// Expected val: 87.354057312011719 +// QNN QDQ val: 0 (err 87.354057312011719) +// CPU QDQ val: 87.3583984375 (err 0.00434112548828125) +TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S16S32_DynamicBias) { + TestInputDef input_def({1, 2, 5, 5}, false, GetFloatDataInRange(-10.0f, 10.0f, 50)); + TestInputDef weight_def({1, 2, 3, 3}, false, GetFloatDataInRange(-1.0f, 5.0f, 18)); + RunHTPConvOpTest("Conv", + input_def, // Input + weight_def.OverrideValueRange(-5.0f, 5.0f), // Weights (symmetric quant range) + TestInputDef({1}, false, {2.0f}), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft QDQ ops for 16-bit +} + +// Tests 16-bit QDQ Conv with dynamic weights and bias (uses QNN's DepthwiseConv2d) +// TODO(adrianlizarraga): FAIL: Failed to finalize QNN graph. Error code 1002 +TEST_F(QnnHTPBackendTests, DISABLED_DepthwiseConvU16S16S32_DynamicBias) { + TestInputDef input_def({1, 1, 5, 5}, false, GetFloatDataInRange(-10.0f, 10.0f, 25)); + TestInputDef weight_def({1, 1, 3, 3}, false, GetFloatDataInRange(-1.0f, 5.0f, 9)); + RunHTPConvOpTest("Conv", + input_def, // Input + weight_def.OverrideValueRange(-5.0f, 5.0f), // Weights (symmetric quant range) + TestInputDef({1}, false, {2.0f}), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft QDQ ops for 16-bit +} + +// Tests 16-bit QDQ Conv with dynamic weights and no bias. +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.0039929896593093872, zero_point=0. +// Expected val: 85.354057312011719 +// QNN QDQ val: 0 (err 85.354057312011719) +// CPU QDQ val: 85.358139038085938 (err 0.00408172607421875) +TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S16S32_NoBias) { + TestInputDef input_def({1, 2, 5, 5}, false, GetFloatDataInRange(-10.0f, 10.0f, 50)); + TestInputDef weight_def({1, 2, 3, 3}, false, GetFloatDataInRange(-1.0f, 5.0f, 18)); + RunHTPConvOpTest("Conv", + input_def, // Input + weight_def.OverrideValueRange(-5.0f, 5.0f), // Weights (symmetric quant range) + TestInputDef(), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft QDQ ops for 16-bit +} + +// Tests 16-bit QDQ Conv with dynamic weights and no bias (uses QNN's DepthWiseConv2d) +// TODO(adrianlizarraga): FAIL: Failed to finalize QNN graph. Error code 1002 +TEST_F(QnnHTPBackendTests, DISABLED_DepthwiseConvU16S16S32_NoBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 25); + std::vector weight_data = GetFloatDataInRange(-10.0f, 10.0f, 9); + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, input_data), // Input + TestInputDef({1, 1, 3, 3}, false, weight_data), // Weights + TestInputDef(), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft QDQ ops for 16-bit +} + +// Tests 16-bit activations, 8-bit static weights QDQ Conv with static bias. +// Uses QNN's DepthwiseConv2d operator. +// TODO: Inaccuracy detected for output 'output', element 8. +// Output quant params: scale=0.0027466239407658577, zero_point=10194. 
+// Expected val: 152 +// QNN QDQ val: 151.8004150390625 (err 0.1995849609375) +// CPU QDQ val: 151.9981689453125 (err 0.0018310546875) +TEST_F(QnnHTPBackendTests, DepthwiseConvU16U8S32_StaticBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 25); + std::vector weight_data = GetFloatDataInRange(-1.0f, 5.0f, 9); + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, input_data), // Input + TestInputDef({1, 1, 3, 3}, true, weight_data), // Weights + TestInputDef({1}, true, {2.0f}), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // Use com.microsoft QDQ ops for 16-bit + 13, + 0.2f); +} + +// Tests 16-bit activations, 8-bit static weights QDQ Conv with static bias. +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.0040235077030956745, zero_point=0. +// Expected val: 87.354057312011719 +// QNN QDQ val: 87.559577941894531 (err 0.2055206298828125) +// CPU QDQ val: 87.398635864257812 (err 0.04457855224609375) +TEST_F(QnnHTPBackendTests, ConvU16U8S32_StaticBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 50); + std::vector weight_data = GetFloatDataInRange(-1.0f, 5.0f, 18); + RunHTPConvOpTest("Conv", + TestInputDef({1, 2, 5, 5}, false, input_data), // Input + TestInputDef({1, 2, 3, 3}, true, weight_data), // Weights + TestInputDef({1}, true, {2.0f}), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // Use com.microsoft QDQ ops for 16-bit + 13, + 0.6f); +} + +// Tests 16-bit activations, 8-bit static weights QDQ Conv with dynamic bias. +// Uses QNN's DepthwiseConv2d operator. +// TODO: Inaccuracy detected for output 'output', element 1. +// Output quant params: scale=0.0027466239407658577, zero_point=10194. +// Expected val: -13.000001907348633 +// QNN QDQ val: -13.095903396606445 (err 0.0959014892578125) +// CPU QDQ val: -12.999771118164062 (err 0.0002307891845703125) +TEST_F(QnnHTPBackendTests, DepthwiseConvU16U8S32_DynamicBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 25); + std::vector weight_data = GetFloatDataInRange(-1.0f, 5.0f, 9); + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, input_data), // Input + TestInputDef({1, 1, 3, 3}, true, weight_data), // Weights + TestInputDef({1}, false, {2.0f}), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // Use com.microsoft QDQ ops for 16-bit + 13, + 0.2f); +} + +// Tests 16-bit activations, 8-bit static weights QDQ Conv with dynamic bias. +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.0040235077030956745, zero_point=0. 
+// Expected val: 87.354057312011719 +// QNN QDQ val: 87.559577941894531 (err 0.2055206298828125) +// CPU QDQ val: 87.398635864257812 (err 0.04457855224609375) +TEST_F(QnnHTPBackendTests, ConvU16U8S32_DynamicBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 50); + std::vector weight_data = GetFloatDataInRange(-1.0f, 5.0f, 18); + RunHTPConvOpTest("Conv", + TestInputDef({1, 2, 5, 5}, false, input_data), // Input + TestInputDef({1, 2, 3, 3}, true, weight_data), // Weights + TestInputDef({1}, false, {2.0f}), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // Use com.microsoft QDQ ops for 16-bit + 13, + 0.57f); +} + +// Tests 16-bit activations, 8-bit static weights QDQ Conv with no bias +// TODO: Inaccuracy detected for output 'output', element 7. +// Output quant params: scale=0.0039929896593093872, zero_point=0. +// Expected val: 246.98667907714844 +// QNN QDQ val: 247.82090759277344 (err 0.834228515625) +// CPU QDQ val: 247.24192810058594 (err 0.2552490234375) +TEST_F(QnnHTPBackendTests, ConvU16U8S32_NoBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 50); + std::vector weight_data = GetFloatDataInRange(-1.0f, 5.0f, 18); + RunHTPConvOpTest("Conv", + TestInputDef({1, 2, 5, 5}, false, input_data), // Input + TestInputDef({1, 2, 3, 3}, true, weight_data), // Weights + TestInputDef(), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // Use com.microsoft QDQ ops for 16-bit + 13, + 0.58f); +} + +// Tests 16-bit activations, 8-bit static weights QDQ Conv with no bias +// Uses QNN's DepthwiseConv2d operator. +// TODO: Inaccuracy detected for output 'output', element 8. +// Output quant params: scale=0.0027466239407658577, zero_point=10923. +// Expected val: 150 +// QNN QDQ val: 149.80087280273438 (err 0.199127197265625) +// CPU QDQ val: 149.99862670898438 (err 0.001373291015625) +TEST_F(QnnHTPBackendTests, DepthwiseConvU16U8S32_NoBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 25); + std::vector weight_data = GetFloatDataInRange(-1.0f, 5.0f, 9); + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, input_data), // Input + TestInputDef({1, 1, 3, 3}, true, weight_data), // Weights + TestInputDef(), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // Use com.microsoft QDQ ops for 16-bit + 13, + 0.2f); } // Test that dynamic weights with default bias works for Conv. This was previously not working // on older versions of QNN sdk. -TEST_F(QnnHTPBackendTests, ConvU8S32_DynamicWeight_NoBias) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 3, 32, 32}, false, -10.0f, 10.0f), // Random dynamic input - TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), // Random dynamic weights - TestInputDef(), // Default bias - {1, 1}, // Strides - {0, 0, 0, 0}, // Pads - {1, 1}, // Dilations - "NOTSET", - ExpectedEPNodeAssignment::All); +TEST_F(QnnHTPBackendTests, ConvU8U8S32_DynamicWeight_NoBias) { + RunHTPConvOpTest("Conv", + TestInputDef({1, 3, 32, 32}, false, -10.0f, 10.0f), // Input + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), // Weights + TestInputDef(), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All); } // Test that dynamic weights with default bias works for ConvTranspose. 
This was previously not working // on older versions of QNN sdk. -TEST_F(QnnHTPBackendTests, ConvTransposeU8S32_DynamicWeight_NoBias) { - RunHTPConvOpTest("ConvTranspose", - TestInputDef({1, 3, 32, 32}, false, -10.0f, 10.0f), // Random dynamic input - TestInputDef({3, 1, 4, 4}, false, -10.0f, 10.0f), // Random dynamic weights - TestInputDef(), // Default bias - {1, 1}, // Strides - {0, 0, 0, 0}, // Pads - {1, 1}, // Dilations - "NOTSET", - ExpectedEPNodeAssignment::All); +TEST_F(QnnHTPBackendTests, ConvTransposeU8U8S32_DynamicWeight_NoBias) { + RunHTPConvOpTest("ConvTranspose", + TestInputDef({1, 3, 32, 32}, false, -10.0f, 10.0f), // Input + TestInputDef({3, 1, 4, 4}, false, -10.0f, 10.0f), // Weights + TestInputDef(), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All); } // Check that QNN compiles DQ -> Conv -> Q as a single unit. // Tests bias as an initializer. TEST_F(QnnHTPBackendTests, ConvU8U8S32_bias_initializer) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 1, 5, 5}, false, 0.0f, 10.0f), // Random dynamic input - TestInputDef({1, 1, 3, 3}, true, -10.0f, 10.0f), // Random static weight - TestInputDef({1}, true, {2.0f}), // Initializer bias = 2.0f - {1, 1}, // Strides - {0, 0, 0, 0}, // Pads - {1, 1}, // Dilations - "NOTSET", - ExpectedEPNodeAssignment::All); + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({1, 1, 3, 3}, true, -10.0f, 10.0f), // Random static weight + TestInputDef({1}, true, {2.0f}), // Initializer bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All); } // Tests 1D Conv with bias as an initializer. -TEST_F(QnnHTPBackendTests, Conv1DU8S32_bias_initializer) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 2, 4}, false, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}), // Dynamic input - TestInputDef({1, 2, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1}, // strides - {0, 0}, // pads - {1}, // dilations - "NOTSET", - ExpectedEPNodeAssignment::All); +TEST_F(QnnHTPBackendTests, Conv1DU8U8S32_bias_initializer) { + std::vector input_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + RunHTPConvOpTest("Conv", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1}, // strides + {0, 0}, // pads + {1}, // dilations + "NOTSET", + ExpectedEPNodeAssignment::All); } // Tests 1D ConvTranspose with bias as an initializer. 
-TEST_F(QnnHTPBackendTests, ConvTranspose1DU8S32_bias_initializer) { - RunHTPConvOpTest("ConvTranspose", - TestInputDef({1, 2, 4}, false, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}), // Dynamic input - TestInputDef({2, 1, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1}, // strides - {0, 0}, // pads - {1}, // dilations - "NOTSET", - ExpectedEPNodeAssignment::All); +TEST_F(QnnHTPBackendTests, ConvTranspose1DU8U8S32_bias_initializer) { + std::vector input_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + RunHTPConvOpTest("ConvTranspose", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1}, // strides + {0, 0}, // pads + {1}, // dilations + "NOTSET", + ExpectedEPNodeAssignment::All); } // Tests auto_pad value "SAME_UPPER" on HTP backend (compares to CPU EP). -TEST_F(QnnHTPBackendTests, ConvU8S32_AutoPadUpper) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 1, 5, 5}, false, 0.f, 10.f), // Dynamic input - TestInputDef({1, 1, 4, 4}, true, -1.f, 1.f), // Static weights - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1, 1}, // strides - {}, // pads - {1, 1}, // dilations - "SAME_UPPER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); +TEST_F(QnnHTPBackendTests, ConvU8U8S32_AutoPadUpper) { + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, 0.f, 10.f), // Dynamic input + TestInputDef({1, 1, 4, 4}, true, -1.f, 1.f), // Static weights + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1, 1}, // strides + {}, // pads + {1, 1}, // dilations + "SAME_UPPER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } // Tests Conv1d auto_pad value "SAME_UPPER" on HTP backend (compares to CPU EP). TEST_F(QnnHTPBackendTests, Conv1DU8U8S32_AutoPadUpper) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 2, 4}, false, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}), // Dynamic input - TestInputDef({1, 2, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1}, // strides - {0}, // pads - {1}, // dilations - "SAME_UPPER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); + std::vector input_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + RunHTPConvOpTest("Conv", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1}, // strides + {0}, // pads + {1}, // dilations + "SAME_UPPER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } // Tests TransposeConv1d auto_pad value "SAME_UPPER" on HTP backend (compares to CPU EP). 
TEST_F(QnnHTPBackendTests, ConvTranspose1DU8U8S32_AutoPadUpper) { - RunHTPConvOpTest("ConvTranspose", - TestInputDef({1, 2, 4}, false, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}), // Dynamic input - TestInputDef({2, 1, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1}, // strides - {0}, // pads - {1}, // dilations - "SAME_UPPER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); + std::vector input_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + RunHTPConvOpTest("ConvTranspose", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1}, // strides + {0}, // pads + {1}, // dilations + "SAME_UPPER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } // Tests Conv's auto_pad value "SAME_LOWER" on HTP backend (compares to CPU EP). TEST_F(QnnHTPBackendTests, ConvU8U8S32_AutoPadLower) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 1, 5, 5}, false, 0.f, 10.f), // Dynamic input - TestInputDef({1, 1, 4, 4}, true, -1.f, 1.f), // Static weights - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1, 1}, // strides - {}, // pads - {1, 1}, // dilations - "SAME_LOWER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, 0.f, 10.f), // Dynamic input + TestInputDef({1, 1, 4, 4}, true, -1.f, 1.f), // Static weights + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1, 1}, // strides + {}, // pads + {1, 1}, // dilations + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } // Tests ConvTranspose's auto_pad value "SAME_LOWER" on HTP backend (compares to CPU EP). TEST_F(QnnHTPBackendTests, ConvTransposeU8U8S32_AutoPadLower) { - RunHTPConvOpTest("ConvTranspose", - TestInputDef({1, 1, 5, 5}, false, 0.f, 10.f), // Dynamic input - TestInputDef({1, 1, 4, 4}, true, -1.f, 1.f), // Static weights - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1, 1}, // strides - {}, // pads - {1, 1}, // dilations - "SAME_LOWER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); + RunHTPConvOpTest("ConvTranspose", + TestInputDef({1, 1, 5, 5}, false, 0.f, 10.f), // Dynamic input + TestInputDef({1, 1, 4, 4}, true, -1.f, 1.f), // Static weights + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1, 1}, // strides + {}, // pads + {1, 1}, // dilations + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } // Tests Conv1d auto_pad value "SAME_LOWER" on HTP backend (compares to CPU EP). 
TEST_F(QnnHTPBackendTests, Conv1DU8U8S32_AutoPadLower) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 2, 4}, false, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}), // Dynamic input - TestInputDef({1, 2, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1}, // strides - {0}, // pads - {1}, // dilations - "SAME_LOWER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); + std::vector input_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + RunHTPConvOpTest("Conv", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1}, // strides + {0}, // pads + {1}, // dilations + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } // Tests ConvTranspose 1d auto_pad value "SAME_LOWER" on HTP backend (compares to CPU EP). TEST_F(QnnHTPBackendTests, ConvTranspose1DU8U8S32_AutoPadLower) { - RunHTPConvOpTest("ConvTranspose", - TestInputDef({1, 2, 4}, false, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}), // Dynamic input - TestInputDef({2, 1, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1}, // strides - {0}, // pads - {1}, // dilations - "SAME_LOWER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); + std::vector input_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + RunHTPConvOpTest("ConvTranspose", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1}, // strides + {0}, // pads + {1}, // dilations + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } TEST_F(QnnHTPBackendTests, ConvU8U8S32_large_input1_padding_bias_initializer) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 3, 60, 452}, false, 0.f, 10.f), // Dynamic input - TestInputDef({16, 3, 3, 3}, true, -1.f, 1.f), // Static weights - TestInputDef({16}, true, std::vector(16, 1.f)), // Initializer bias = 1.f, 1.f, ... 
- {1, 1}, - {1, 1, 1, 1}, - {1, 1}, - "NOTSET", - ExpectedEPNodeAssignment::All); -} - -TEST_F(QnnHTPBackendTests, ConvU8S32_large_input2_bias_initializer) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 128, 8, 56}, false, 0.f, 10.f), // Dynamic input - TestInputDef({32, 128, 1, 1}, true, -1.f, 1.f), // Random static weights - TestInputDef({32}, true, -1.f, 1.f), // Random initializer bias - {1, 1}, - {0, 0, 0, 0}, - {1, 1}, - "NOTSET", - ExpectedEPNodeAssignment::All); + RunHTPConvOpTest("Conv", + TestInputDef({1, 3, 60, 452}, false, 0.f, 10.f), // Dynamic input + TestInputDef({16, 3, 3, 3}, true, -1.f, 1.f), // Static weights + TestInputDef({16}, true, std::vector(16, 1.f)), // Initializer bias + {1, 1}, + {1, 1, 1, 1}, + {1, 1}, + "NOTSET", + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, ConvU8U8S32_large_input2_bias_initializer) { + RunHTPConvOpTest("Conv", + TestInputDef({1, 128, 8, 56}, false, 0.f, 10.f), // Dynamic input + TestInputDef({32, 128, 1, 1}, true, -1.f, 1.f), // Random static weights + TestInputDef({32}, true, -1.f, 1.f), // Random initializer bias + {1, 1}, + {0, 0, 0, 0}, + {1, 1}, + "NOTSET", + ExpectedEPNodeAssignment::All); } TEST_F(QnnHTPBackendTests, ConvU8U8S32_LargeInput_Dilations_Pads) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 3, 768, 1152}, false, 0.f, 10.f), // Dynamic input - TestInputDef({64, 3, 7, 7}, true, -1.f, 1.f), // Random static weights - TestInputDef({64}, true, -1.f, 1.f), // Random initializer bias - {2, 2}, // strides - {3, 3, 3, 3}, // pads - {1, 1}, // dilations - "NOTSET", // auto_pad - ExpectedEPNodeAssignment::All); + RunHTPConvOpTest("Conv", + TestInputDef({1, 3, 768, 1152}, false, 0.f, 10.f), // Dynamic input + TestInputDef({64, 3, 7, 7}, true, -1.f, 1.f), // Static weights + TestInputDef({64}, true, -1.f, 1.f), // Initializer bias + {2, 2}, // strides + {3, 3, 3, 3}, // pads + {1, 1}, // dilations + "NOTSET", // auto_pad + ExpectedEPNodeAssignment::All); } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index 6edb6ecdcfb1a..e721ccbcb45a9 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -27,28 +27,31 @@ static GetTestModelFn BuildMatMulOpTestCase(const TestInputDef& input1_de } // Returns a function that creates a graph with a QDQ MatMul operator. 
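The MatMul builder below threads a use_contrib_qdq flag through AddQDQNodePair; in ONNX terms that flag amounts to emitting the Q/DQ nodes in the com.microsoft domain, which is what makes 16-bit zero points legal. A minimal onnx.helper sketch of the idea — names are hypothetical, this is not the test helper itself:

from onnx import TensorProto, helper

def make_qdq_pair(name: str, scale: float, zero_point: int,
                  zp_type: int = TensorProto.UINT8, use_contrib_qdq: bool = False):
    # use_contrib_qdq selects the com.microsoft Q/DQ ops; "" means the default onnx domain.
    domain = "com.microsoft" if use_contrib_qdq else ""
    scale_init = helper.make_tensor(f"{name}_scale", TensorProto.FLOAT, [], [scale])
    zp_init = helper.make_tensor(f"{name}_zp", zp_type, [], [zero_point])
    q = helper.make_node("QuantizeLinear", [name, f"{name}_scale", f"{name}_zp"],
                         [f"{name}_q"], domain=domain)
    dq = helper.make_node("DequantizeLinear", [f"{name}_q", f"{name}_scale", f"{name}_zp"],
                          [f"{name}_dq"], domain=domain)
    return [scale_init, zp_init], [q, dq]

# 16-bit zero points require the contrib domain at the opsets used in these tests,
# where the standard-domain Q/DQ ops only accept 8-bit (and int32 for DQ) types.
inits, nodes = make_qdq_pair("input0", 20.0 / 65535.0, 32768,
                             zp_type=TensorProto.UINT16, use_contrib_qdq=True)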
-template -static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDef& input1_def, - const TestInputDef& input2_def) { - return [input1_def, input2_def](ModelTestBuilder& builder, - std::vector>& output_qparams) { +template +static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDef& input1_def, + const TestInputDef& input2_def, + bool use_contrib_qdq) { + return [input1_def, input2_def, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { // input1 -> Q -> DQ -> NodeArg* input1 = MakeTestInput(builder, input1_def); - QuantParams input1_qparams = GetTestInputQuantParams(input1_def); - auto* input1_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, input1_qparams.zero_point); + QuantParams input1_qparams = GetTestInputQuantParams(input1_def); + auto* input1_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, input1_qparams.zero_point, + use_contrib_qdq); // input2 -> Q -> DQ -> NodeArg* input2 = MakeTestInput(builder, input2_def); - QuantParams input2_qparams = GetTestInputQuantParams(input2_def); - auto* input2_qdq = AddQDQNodePair(builder, input2, input2_qparams.scale, input2_qparams.zero_point); + QuantParams input2_qparams = GetTestInputQuantParams(input2_def); + auto* input2_qdq = AddQDQNodePair(builder, input2, input2_qparams.scale, input2_qparams.zero_point, + use_contrib_qdq); // MatMul auto* op_output = builder.MakeIntermediate(); builder.AddNode("MatMul", {input1_qdq, input2_qdq}, {op_output}); // op_output -> Q -> DQ -> output - AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, - output_qparams[0].zero_point); + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); }; } @@ -75,11 +78,13 @@ static void RunMatMulOpOpTest(const TestInputDef& input1_def, // Runs a QDQ MatMul model on the QNN HTP backend. Checks the graph node assignment, and that the // QDQ model is accurate on QNN EP (compared to CPU EP). -template +template static void RunQDQMatMulOpOpTest(const TestInputDef& input1_def, const TestInputDef& input2_def, ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 18) { + int opset = 18, + bool use_contrib_qdq = false, + float fp32_abs_err = 1e-4f) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -88,11 +93,12 @@ static void RunQDQMatMulOpOpTest(const TestInputDef& input1_def, #endif TestQDQModelAccuracy(BuildMatMulOpTestCase(input1_def, input2_def), - BuildMatMulOpQDQTestCase(input1_def, input2_def), + BuildMatMulOpQDQTestCase(input1_def, input2_def, + use_contrib_qdq), provider_options, opset, expected_ep_assignment, - 1e-5f); + fp32_abs_err); } // @@ -127,16 +133,68 @@ TEST_F(QnnCPUBackendTests, MatMulOp_Broadcast) { // TEST_F(QnnHTPBackendTests, MatMulOp_HTP_u8) { - RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}), - TestInputDef({3, 2}, false, {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}), - ExpectedEPNodeAssignment::All, 18); + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), + TestInputDef({3, 2}, false, input1_data), + ExpectedEPNodeAssignment::All, 18); } -// Test MatMul broadcasting +// Test QDQ MatMul with 16-bit act, 8-bit weights (static) +// TODO: (SLIGHT) Inaccuracy detected for output 'output', element 0. 
+// Output quant params: scale=0.0015259021893143654, zero_point=0. +// Expected val: 98 +// QNN QDQ val: 97.720298767089844 (err 0.27970123291015625) +// CPU QDQ val: 97.726402282714844 (err 0.27359771728515625) +TEST_F(QnnHTPBackendTests, MatMulOp_HTP_A16_W8Static) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), + TestInputDef({3, 2}, true, input1_data), + ExpectedEPNodeAssignment::All, + 18, + true, // Use com.microsoft Q/DQ ops + 7e-3f); +} + +// Test 16-bit QDQ MatMul with static weights +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.0015259021893143654, zero_point=0. +// Expected val: 98 +// QNN QDQ val: 0.65461206436157227 (err 97.345390319824219) +// CPU QDQ val: 98.002593994140625 (err 0.002593994140625) +TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_HTP_A16_W16) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), + TestInputDef({3, 2}, true, input1_data), + ExpectedEPNodeAssignment::All, + 18, + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ MatMul broadcasting TEST_F(QnnHTPBackendTests, MatMulOp_Broadcast) { - RunQDQMatMulOpOpTest(TestInputDef({28, 1, 64}, false, -10.0f, 10.0f), - TestInputDef({64, 32}, false, -10.0f, 10.0f), - ExpectedEPNodeAssignment::All, 18); + RunQDQMatMulOpOpTest(TestInputDef({28, 1, 64}, false, -10.0f, 10.0f), + TestInputDef({64, 32}, false, -10.0f, 10.0f), + ExpectedEPNodeAssignment::All, 18); +} + +// Test 16-bit QDQ MatMul broadcasting +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.0028538699261844158, zero_point=6050. 
+// Expected val: 169.76341247558594 +// QNN QDQ val: -16.675161361694336 (err 186.43856811523438) +// CPU QDQ val: 169.762451171875 (err 0.0009613037109375) +TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_Broadcast_A16_W16) { + std::vector input_a = GetFloatDataInRange(-10.0f, 10.0f, 28 * 64); + std::vector input_b = GetFloatDataInRange(-10.0f, 10.0f, 64 * 32); + + RunQDQMatMulOpOpTest(TestInputDef({28, 1, 64}, false, input_a), + TestInputDef({64, 32}, true, input_b), + ExpectedEPNodeAssignment::All, + 18, + true); // Use com.microsoft Q/DQ ops } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 80b929e9dafbe..a441e828c0cc6 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -260,8 +260,8 @@ TEST_F(QnnCPUBackendTests, TestNHWCResizeShapeInference_sizes_opset18) { TEST_F(QnnHTPBackendTests, TestNHWCResizeShapeInference_qdq_sizes_opset18) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", true); } -#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 548f80675a622..724e9a11cd781 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -116,7 +116,8 @@ void InferenceModel(const std::string& model_data, const char* log_id, ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &output_vals)); } -NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef& bias_def, float bias_scale) { +NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef& bias_def, float bias_scale, + bool use_contrib_qdq) { NodeArg* bias_int32 = nullptr; // Bias must be int32 to be detected as a QDQ node unit. @@ -124,7 +125,8 @@ NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef bias_int32_def(bias_def.GetShape(), bias_def.IsInitializer(), static_cast(rand_info.min / bias_scale), + TestInputDef bias_int32_def(bias_def.GetShape(), bias_def.IsInitializer(), + static_cast(rand_info.min / bias_scale), static_cast(rand_info.max / bias_scale)); bias_int32 = MakeTestInput(builder, bias_int32_def); } else { @@ -143,7 +145,7 @@ NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef(bias_int32, bias_scale, 0, bias); + builder.AddDequantizeLinearNode(bias_int32, bias_scale, 0, bias, use_contrib_qdq); return bias; } diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 1b0b85319918f..fd572fa17f2b1 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -266,6 +266,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe std::vector output_names; InferenceModel(f32_model_data, "f32_model_logger", nullptr, ExpectedEPNodeAssignment::All, f32_helper.feeds_, output_names, cpu_f32_outputs); + ASSERT_FALSE(cpu_f32_outputs.empty()); + const size_t num_outputs = cpu_f32_outputs.size(); // Compute output range(s) and quantization params. 
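For reference, the usual asymmetric derivation of (scale, zero_point) from a float output range looks like the sketch below; it also shows why the 16-bit output scales quoted in the test comments are roughly 256x smaller than their 8-bit counterparts (illustrative only — the helper's exact logic may differ):

def get_quant_params(rmin: float, rmax: float, qmin: int, qmax: int):
    # Standard asymmetric quantization: the float range (extended to include 0)
    # is mapped linearly onto the integer range [qmin, qmax].
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    scale = (rmax - rmin) / float(qmax - qmin)
    zero_point = int(round(qmin - rmin / scale))
    return scale, zero_point

print(get_quant_params(0.0, 263.7, 0, 255))    # uint8:  scale ~1.034,   zero_point 0
print(get_quant_params(0.0, 263.7, 0, 65535))  # uint16: scale ~0.00402, zero_point 0
                                               # (cf. "scale=0.0040235..." in the comments above)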
@@ -432,7 +434,8 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef manual quantization (int32) => DQ => final float bias -NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef& bias_def, float bias_scale); +NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef& bias_def, float bias_scale, + bool use_contrib_qdq = false); /** * Returns a function that builds a model with a single operator with N inputs of the same element type. @@ -479,9 +482,10 @@ template inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_type, const std::vector>& input_defs, const std::vector& attrs, - const std::string& op_domain = kOnnxDomain) { - return [op_type, input_defs, attrs, op_domain](ModelTestBuilder& builder, - std::vector>& output_qparams) { + const std::string& op_domain = kOnnxDomain, + bool use_contrib_qdq = false) { + return [op_type, input_defs, attrs, op_domain, + use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { std::vector op_inputs; op_inputs.reserve(input_defs.size()); @@ -489,7 +493,7 @@ inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_ty NodeArg* input = MakeTestInput(builder, input_def); QuantParams input_qparams = GetTestInputQuantParams(input_def); NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, - input_qparams.zero_point); + input_qparams.zero_point, use_contrib_qdq); op_inputs.push_back(input_after_qdq); } @@ -503,7 +507,7 @@ inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_ty // op_output -> Q -> DQ -> output AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, - output_qparams[0].zero_point); + output_qparams[0].zero_point, use_contrib_qdq); }; } @@ -563,4 +567,4 @@ bool ReduceOpHasAxesInput(const std::string& op_type, int opset_version); } // namespace test } // namespace onnxruntime -#endif // !defined(ORT_MINIMAL_BUILD) \ No newline at end of file +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/reduce_op_test.cc b/onnxruntime/test/providers/qnn/reduce_op_test.cc index c3c2b578a1bd0..57252f93492e5 100644 --- a/onnxruntime/test/providers/qnn/reduce_op_test.cc +++ b/onnxruntime/test/providers/qnn/reduce_op_test.cc @@ -648,4 +648,4 @@ TEST_F(QnnHTPBackendTests, ReduceMeanS8Opset18) { } // namespace test } // namespace onnxruntime -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index eed12af3c703c..63498982930f5 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -104,6 +104,7 @@ static void RunQDQOpTest(const std::string& op_type, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, const std::string& op_domain = kOnnxDomain, + bool use_contrib_qdq = false, float fp32_abs_err = 1e-4f) { ProviderOptions provider_options; #if defined(_WIN32) @@ -113,7 +114,7 @@ static void RunQDQOpTest(const std::string& op_type, #endif TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, attrs, op_domain), - BuildQDQOpTestCase(op_type, input_defs, attrs, op_domain), + BuildQDQOpTestCase(op_type, input_defs, attrs, op_domain, use_contrib_qdq), provider_options, opset_version, expected_ep_assignment, @@ -151,6 +152,17 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Sigmoid) { ExpectedEPNodeAssignment::All); } +// Tests accuracy of 16-bit QDQ Sigmoid. 
+TEST_F(QnnHTPBackendTests, UnaryOp_Sigmoid_U16) { + RunQDQOpTest("Sigmoid", + {TestInputDef({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))}, + {}, + 13, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); // Use MS domain Q/DQ ops +} + // Test the accuracy of QDQ Tanh. TEST_F(QnnHTPBackendTests, UnaryOp_Tanh) { RunQDQOpTest("Tanh", @@ -160,6 +172,17 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Tanh) { ExpectedEPNodeAssignment::All); } +// Tests accuracy of 16-bit QDQ Tanh. +TEST_F(QnnHTPBackendTests, UnaryOp_Tanh_U16) { + RunQDQOpTest("Tanh", + {TestInputDef({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))}, + {}, + 13, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); // Use MS domain Q/DQ ops +} + // Check that QNN compiles DQ -> Gelu -> Q as a single unit. // Use an input of rank 3. TEST_F(QnnHTPBackendTests, UnaryOp_Gelu) { @@ -171,6 +194,24 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Gelu) { kMSDomain); // GeLu is a contrib op. } +// Tests accuracy of 16-bit QDQ GeLu. +// TODO(adrianlizarraga): Inaccuracy detected for output 'output', element 5. +// Output quant params: scale=0.00015259021893143654, zero_point=0. +// Expected val: 10 +// QNN QDQ val: 9.997406005859375 (err 0.002593994140625) +// CPU QDQ val: 9.999847412109375 (err 0.000152587890625) +TEST_F(QnnHTPBackendTests, UnaryOp_Gelu_U16) { + const std::vector input_data = {-10.0f, -8.4f, 0.0f, 4.3f, 7.1f, 10.0f}; + RunQDQOpTest("Gelu", + {TestInputDef({1, 2, 3}, false, input_data)}, + {}, + 11, + ExpectedEPNodeAssignment::All, + kMSDomain, // GeLu is a contrib op. + true, // Use MS domain Q/DQ ops. + 0.0025f); // TODO(adrianlizarraga): Accuracy +} + // Check that QNN compiles DQ -> Elu -> Q as a single unit. // Use an input of rank 3. TEST_F(QnnHTPBackendTests, UnaryOp_Elu) { @@ -181,6 +222,23 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Elu) { ExpectedEPNodeAssignment::All); } +// Tests accuracy of 16-bit QDQ Elu. +// TODO(adrianlizarraga): Re-enable. This works on QNN SDK 2.14.1! +// Inaccuracy detected for output 'output', element 1. +// Output quant params: scale=0.00011093531065853313, zero_point=8992. +// Expected val: -0.99751651287078857 +// QNN QDQ val: 6.2726154327392578 (err 7.2701320648193359) +// CPU QDQ val: -0.99753034114837646 (err 1.3828277587890625e-05) +TEST_F(QnnHTPBackendTests, DISABLE_UnaryOp_Elu_U16) { + RunQDQOpTest("Elu", + {TestInputDef({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))}, + {}, + 11, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); +} + // Tests accuracy of QDQ Relu // TODO: Relu does not set negative values to zero! // Could be due to ORT's ReluQuantFusion! @@ -208,6 +266,24 @@ TEST_F(QnnHTPBackendTests, UnaryOp_HardSwish) { ExpectedEPNodeAssignment::All); } +// Tests accuracy of 16-bit QDQ HardSwish +// TODO(adrianlizarraga): Inaccuracy detected for output 'output', element 5. +// Output quant params: scale=0.00015259021893143654, zero_point=0. +// Expected val: 10 +// QNN QDQ val: 9.999237060546875 (err 0.000762939453125) +// CPU QDQ val: 9.999847412109375 (err 0.000152587890625) +TEST_F(QnnHTPBackendTests, UnaryOp_HardSwish_U16) { + const std::vector input_data = {-10.0f, -8.4f, 0.0f, 4.3f, 7.1f, 10.0f}; + RunQDQOpTest("HardSwish", + {TestInputDef({1, 2, 3}, false, input_data)}, + {}, + 14, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true, + 0.001f); // TODO(adrianlizarraga): Remove additional tolerance needed for inaccuracy +} + // Check that QNN compiles DQ -> Atan -> Q as a single unit. // Use an input of rank 3. 
TEST_F(QnnHTPBackendTests, UnaryOp_Atan) { @@ -218,6 +294,24 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Atan) { ExpectedEPNodeAssignment::All); } +// Tests accuracy of 16-bit QDQ Atan +// TODO(adrianlizarraga): Inaccuracy detected for output 'output', element 1. +// Output quant params: scale=4.4895936298416927e-05, zero_point=32768. +// Expected val: -1.4219063520431519 +// QNN QDQ val: -1.4220787286758423 (err 0.00017237663269042969) +// CPU QDQ val: -1.4218991994857788 (err 7.152557373046875e-06) +TEST_F(QnnHTPBackendTests, UnaryOp_Atan_U16) { + const std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + RunQDQOpTest("Atan", + {TestInputDef({1, 2, 3}, false, input_data)}, + {}, + 14, + ExpectedEPNodeAssignment::All, + kOnnxDomain, // Atan domain + true, // Q/DQ op domain is com.microsoft + 1.8e-4f); +} + // Check that QNN compiles DQ -> Asin -> Q as a single unit. // Use an input of rank 3. TEST_F(QnnHTPBackendTests, UnaryOp_Asin) { @@ -238,6 +332,18 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Sign) { ExpectedEPNodeAssignment::All); } +// Tests accuracy of 16-bit QDQ Sign +TEST_F(QnnHTPBackendTests, UnaryOp_Sign_U16) { + const std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + RunQDQOpTest("Sign", + {TestInputDef({1, 2, 3}, false, input_data)}, + {}, + 13, + ExpectedEPNodeAssignment::All, + kOnnxDomain, // Sign op domain + true); // Use com.microsoft Q/DQ op domains +} + // Check that QNN compiles DQ -> Sin -> Q as a single unit. // Use an input of rank 3. TEST_F(QnnHTPBackendTests, UnaryOp_Sin) { @@ -260,7 +366,7 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Cos) { // Check that QNN compiles DQ -> Cos -> Q as a single unit. // Use an input of rank 3. -TEST_F(QnnHTPBackendTests, UnaryOp_Cos_Inaccurate) { +TEST_F(QnnHTPBackendTests, UnaryOp_Cos_InaccurateFixed) { RunQDQOpTest("Cos", {TestInputDef({1, 2, 3}, false, {-3.14159f, -1.88436f, -0.542863f, 0.0f, 1.05622f, 3.14159f})}, {}, @@ -326,6 +432,18 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Round) { ExpectedEPNodeAssignment::All); } +// Tests accuracy of 16-bit QDQ Log +TEST_F(QnnHTPBackendTests, UnaryOp_Log_U16) { + const std::vector input_data = GetFloatDataInRange(1.0f, 128.0f, 6); + RunQDQOpTest("Log", + {TestInputDef({1, 2, 3}, false, input_data)}, + {}, + 11, + ExpectedEPNodeAssignment::All, + kOnnxDomain, // Log op domain + true); // Use com.microsoft domain for Q/DQ ops +} + // Check that QNN compiles DQ -> Softmax -> Q as a single unit. // Test that the default axis (-1) for SoftMax opset 13 works. TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_DefaultAxis) { @@ -336,6 +454,18 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_DefaultAxis) { ExpectedEPNodeAssignment::All); } +// Tests accuracy of 16-bit QDQ Softmax (opset 13) with default axis +TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_U16_DefaultAxis) { + const std::vector input_data = GetFloatDataInRange(-5.0f, 5.0f, 6); + RunQDQOpTest("Softmax", + {TestInputDef({1, 2, 3}, false, input_data)}, + {}, // Uses default axis of -1 for opset 13 + 13, + ExpectedEPNodeAssignment::All, + kOnnxDomain, // Sofmax's domain + true); // Use com.microsoft domain for Q/DQ ops +} + // Check that QNN compiles DQ -> Softmax -> Q as a single unit. // Test that an axis != -1 is not supported. TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_UnsupportedAxis) { @@ -410,7 +540,7 @@ TEST_F(QnnHTPBackendTests, UnaryOp_LogSoftmax11_SetValidAxis) { ExpectedEPNodeAssignment::All); } -// Test QDQ Abs op. +// Test accuracy of QDQ Abs op. 
TEST_F(QnnHTPBackendTests, UnaryOp_Abs) { RunQDQOpTest("Abs", {TestInputDef({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))}, @@ -419,7 +549,19 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Abs) { ExpectedEPNodeAssignment::All); } -// Test QDQ Ceil op. +// Test accuracy of 16-bit QDQ Abs op. +TEST_F(QnnHTPBackendTests, UnaryOp_Abs_U16) { + const std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + RunQDQOpTest("Abs", + {TestInputDef({1, 2, 3}, false, input_data)}, + {}, + 13, + ExpectedEPNodeAssignment::All, + kOnnxDomain, // Abs op's domain + true); // Use com.microsoft domain for Q/DQ ops +} + +// Test accuracy of QDQ Ceil op. TEST_F(QnnHTPBackendTests, UnaryOp_Ceil) { const std::vector input_data = GetFloatDataInRange(-12.0f, 12.0f, 6); RunQDQOpTest("Ceil", @@ -429,6 +571,18 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Ceil) { ExpectedEPNodeAssignment::All); } +// Test accuracy of 16-bit QDQ Ceil op. +TEST_F(QnnHTPBackendTests, UnaryOp_Ceil_U16) { + const std::vector input_data = GetFloatDataInRange(-12.0f, 12.0f, 6); + RunQDQOpTest("Ceil", + {TestInputDef({1, 2, 3}, false, input_data)}, + {}, + 13, + ExpectedEPNodeAssignment::All, + kOnnxDomain, // Ceil op's domain + true); // Use com.microsoft domain for Q/DQ ops +} + // Test QDQ Floor op. TEST_F(QnnHTPBackendTests, UnaryOp_Floor) { const std::vector input_data = GetFloatDataInRange(-12.0f, 12.0f, 6); @@ -457,6 +611,26 @@ TEST_F(QnnHTPBackendTests, DepthToSpaceOp_CRD) { ExpectedEPNodeAssignment::All); } +// Test 16-bit QDQ DepthToSpace. +TEST_F(QnnHTPBackendTests, DepthToSpaceOp_U16_CRD) { + const std::vector X = {0., 1., 2., + 3., 4., 5., + 9., 10., 11., + 12., 13., 14., + 18., 19., 20., + 21., 22., 23., + 27., 28., 29., + 30., 31., 32.}; + RunQDQOpTest("DepthToSpace", + {TestInputDef({1, 4, 2, 3}, false, X)}, + {utils::MakeAttribute("blocksize", static_cast(2)), + utils::MakeAttribute("mode", "CRD")}, + 11, + ExpectedEPNodeAssignment::All, + kOnnxDomain, // Op's domain + true); // Use com.microsoft domain for Q/DQ ops +} + // Test QDQ DepthToSpace. TEST_F(QnnHTPBackendTests, DepthToSpaceOp_DCR) { const std::vector X = {0., 1., 2., @@ -489,6 +663,22 @@ TEST_F(QnnHTPBackendTests, SpaceToDepthOp) { ExpectedEPNodeAssignment::All); } +// Test 16-bit QDQ SpaceToDepth. 
+TEST_F(QnnHTPBackendTests, SpaceToDepthOp_U16) { + const std::vector X = {0.0f, 0.1f, 0.2f, 0.3f, + 1.0f, 1.1f, 1.2f, 1.3f, + + 2.0f, 2.1f, 2.2f, 2.3f, + 3.0f, 3.1f, 3.2f, 3.3f}; + RunQDQOpTest("SpaceToDepth", + {TestInputDef({1, 2, 2, 4}, false, X)}, + {utils::MakeAttribute("blocksize", static_cast(2))}, + 11, + ExpectedEPNodeAssignment::All, + kOnnxDomain, // Op's domain + true); // Use com.microsoft domain for Q/DQ ops +} + // Run QDQ model on HTP twice // 1st run will generate the Qnn context cache binary file // 2nd run will load and run from Qnn context cache binary file @@ -561,7 +751,7 @@ TEST_F(QnnHTPBackendTests, QuantAccuracyTest) { ExpectedEPNodeAssignment::All); } -// Test QDQ Add +// Test 8-bit QDQ Add TEST_F(QnnHTPBackendTests, BinaryOp_Add4D) { RunQDQOpTest("Add", {TestInputDef({1, 2, 2, 2}, false, -10.0f, 10.0f), @@ -571,7 +761,20 @@ TEST_F(QnnHTPBackendTests, BinaryOp_Add4D) { ExpectedEPNodeAssignment::All); } -// Test QDQ Sub +// Test 16-bit QDQ Add +TEST_F(QnnHTPBackendTests, BinaryOp_Add4D_U16) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunQDQOpTest("Add", + {TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({1, 2, 2, 2}, false, input_data)}, + {}, + 17, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Sub TEST_F(QnnHTPBackendTests, BinaryOp_Sub4D) { RunQDQOpTest("Sub", {TestInputDef({1, 3, 8, 8}, false, -10.0f, 10.0f), @@ -581,6 +784,20 @@ TEST_F(QnnHTPBackendTests, BinaryOp_Sub4D) { ExpectedEPNodeAssignment::All); } +// Test 16-bit QDQ Sub +TEST_F(QnnHTPBackendTests, BinaryOp_Sub4D_U16) { + std::vector input0_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + std::vector input1_data = GetFloatDataInRange(0.0f, 20.0f, 8); + RunQDQOpTest("Sub", + {TestInputDef({1, 2, 2, 2}, false, input0_data), + TestInputDef({1, 2, 2, 2}, false, input1_data)}, + {}, + 17, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); // Use com.microsoft Q/DQ ops +} + TEST_F(QnnHTPBackendTests, BinaryOp_Sub4D_LargeInputs) { RunQDQOpTest("Sub", {TestInputDef({1, 3, 768, 1152}, false, -1.0f, 1.0f), @@ -656,6 +873,20 @@ TEST_F(QnnHTPBackendTests, BinaryOp_Div4D_SmallInputs) { ExpectedEPNodeAssignment::All); } +// Test 16-bit QDQ Sub with small input values. +TEST_F(QnnHTPBackendTests, BinaryOp_Div4D_U16_SmallInputs) { + std::vector input0_data = {-10.0f, -8.0f, -1.0f, 0.0f, 1.0f, 2.1f, 8.0f, 10.0f}; + std::vector input1_data = {5.0f, 4.0f, 1.0f, 1.0f, 1.0f, 4.0f, 4.0f, 5.0f}; + RunQDQOpTest("Div", + {TestInputDef({1, 2, 2, 2}, false, input0_data), + TestInputDef({1, 2, 2, 2}, false, input1_data)}, + {}, + 17, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); // Use com.microsoft Q/DQ ops +} + // TODO: Enable when this is fixed. // QNN v2.13: Inaccuracy detected for output 'output', element 2551923. // Output quant params: scale=4100.92626953125, zero_point=126. 
@@ -680,7 +911,7 @@ TEST_F(QnnHTPBackendTests, BinaryOp_Div4D_Broadcast) { ExpectedEPNodeAssignment::All); } -// Test QDQ Mul +// Test 8-bit QDQ Mul TEST_F(QnnHTPBackendTests, BinaryOp_Mul4D) { std::vector input_data = GetFloatDataInRange(-10.0, 10.0f, 8); RunQDQOpTest("Mul", @@ -691,6 +922,19 @@ TEST_F(QnnHTPBackendTests, BinaryOp_Mul4D) { ExpectedEPNodeAssignment::All); } +// Test 16-bit QDQ Mul +TEST_F(QnnHTPBackendTests, BinaryOp_Mul4D_U16) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunQDQOpTest("Mul", + {TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({1, 2, 2, 2}, false, input_data)}, + {}, + 17, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); // Use com.microsoft Q/DQ ops +} + // Test And TEST_F(QnnHTPBackendTests, BinaryOp_And4D) { RunOpTest("And", @@ -711,7 +955,7 @@ TEST_F(QnnHTPBackendTests, BinaryOp_HTP_Or_Unsupported) { ExpectedEPNodeAssignment::None); } -// Test QDQ GridSample with bilinear +// Test 8-bit QDQ GridSample with bilinear TEST_F(QnnHTPBackendTests, GridSample_Bilinear) { RunQDQOpTest("GridSample", {TestInputDef({1, 1, 3, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 6)), @@ -723,7 +967,21 @@ TEST_F(QnnHTPBackendTests, GridSample_Bilinear) { ExpectedEPNodeAssignment::All); } -// Test QDQ GridSample with align corners +// Test 16-bit QDQ GridSample with bilinear +TEST_F(QnnHTPBackendTests, GridSample_U16_Bilinear) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 1, 3, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 6)), + TestInputDef({1, 2, 4, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 16))}, + {utils::MakeAttribute("align_corners", static_cast(0)), + utils::MakeAttribute("mode", "bilinear"), + utils::MakeAttribute("padding_mode", "zeros")}, + 17, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ GridSample with align corners TEST_F(QnnHTPBackendTests, GridSample_AlignCorners) { RunQDQOpTest("GridSample", {TestInputDef({1, 1, 3, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 6)), @@ -735,6 +993,20 @@ TEST_F(QnnHTPBackendTests, GridSample_AlignCorners) { ExpectedEPNodeAssignment::All); } +// Test 16-bit QDQ GridSample with align corners +TEST_F(QnnHTPBackendTests, GridSample_U16_AlignCorners) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 1, 3, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 6)), + TestInputDef({1, 2, 4, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 16))}, + {utils::MakeAttribute("align_corners", static_cast(1)), + utils::MakeAttribute("mode", "bilinear"), + utils::MakeAttribute("padding_mode", "zeros")}, + 17, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); // Use com.microsoft Q/DQ ops +} + // Test QDQ GridSample with padding mode: border // Inaccuracy detected for output 'output', element 0. // Output quant params: scale=0.046370312571525574, zero_point=129. 
@@ -751,7 +1023,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_GridSample_BorderPadding) { ExpectedEPNodeAssignment::All); } -// Test QDQ GridSample with nearest mode +// Test 8-bit QDQ GridSample with nearest mode TEST_F(QnnHTPBackendTests, GridSample_Nearest) { RunQDQOpTest("GridSample", {TestInputDef({1, 1, 3, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 6)), @@ -761,6 +1033,18 @@ TEST_F(QnnHTPBackendTests, GridSample_Nearest) { ExpectedEPNodeAssignment::All); } +// Test 16-bit QDQ GridSample with nearest mode +TEST_F(QnnHTPBackendTests, GridSample_U16_Nearest) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 1, 3, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 6)), + TestInputDef({1, 2, 4, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 16))}, + {utils::MakeAttribute("mode", "nearest")}, + 17, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); +} + // Test QDQ GridSample with reflection padding mode // Inaccuracy detected for output 'output', element 2. // Output quant params: scale=0.024269860237836838, zero_point=0. @@ -801,4 +1085,4 @@ TEST_F(QnnHTPBackendTests, VariadicOp_Concat_2Inputs_2ndAxis) { } // namespace test } // namespace onnxruntime -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index 3c5f516af4846..5c2db435d7fb5 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -566,28 +566,30 @@ def construct_model_conv_relu(self, output_model_path, input_shape, weight_shape onnx.save(model, output_model_path) - def verify(self, per_channel, is_quant_type_int8): + def verify_qdq(self, per_channel, activation_type, weight_type, extra_options=None): np.random.seed(1) model_fp32_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_fp32.{per_channel}.onnx") - model_int8_qdq_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_quant_qdq.{per_channel}.onnx") - model_int8_qop_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_quant_qop.{per_channel}.onnx") + model_qdq_path = str( + Path(self._tmp_model_dir.name) / f"conv_relu_quant_qdq.{activation_type}.{weight_type}.{per_channel}.onnx" + ) data_reader = self.input_feeds(1, {"input": [1, 8, 33, 33]}) self.construct_model_conv_relu(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31]) quantize_static( model_fp32_path, - model_int8_qdq_path, + model_qdq_path, data_reader, quant_format=QuantFormat.QDQ, per_channel=per_channel, reduce_range=per_channel, - activation_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, - weight_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, ) data_reader.rewind() # topo sort check check_op_type_order( self, - model_int8_qdq_path, + model_qdq_path, [ "DequantizeLinear", "QuantizeLinear", @@ -597,9 +599,15 @@ def verify(self, per_channel, is_quant_type_int8): "DequantizeLinear", ], ) - check_model_correctness(self, model_fp32_path, model_int8_qdq_path, data_reader.get_next()) + check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) + + def verify_qop(self, per_channel, is_quant_type_int8): + np.random.seed(1) + model_fp32_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_fp32.{per_channel}.onnx") + model_int8_qop_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_quant_qop.{per_channel}.onnx") + data_reader = self.input_feeds(1, {"input": [1, 8, 33, 
33]}) + self.construct_model_conv_relu(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31]) - data_reader.rewind() quantize_static( model_fp32_path, model_int8_qop_path, @@ -617,10 +625,25 @@ def verify(self, per_channel, is_quant_type_int8): def test_quantize_conv_without_bias(self): # only test cases per_channel=True and reduce_range=True to avoid saturation on avx2 and avx512 for weight type int8 - self.verify(True, True) # per_channel:False, is_quant_type_int8:True + self.verify_qdq(True, QuantType.QInt8, QuantType.QInt8) # per_channel:True + self.verify_qop(True, True) # per_channel:True, is_quant_type_int8:True - self.verify(False, False) # per_channel:False, is_quant_type_int8:False - self.verify(True, False) # per_channel:True, is_quant_type_int8:False + self.verify_qdq(False, QuantType.QUInt8, QuantType.QUInt8) # per_channel:False + self.verify_qop(False, False) # per_channel:False, is_quant_type_int8:False + + self.verify_qdq(True, QuantType.QUInt8, QuantType.QUInt8) # per_channel:True + self.verify_qop(True, False) # per_channel:True, is_quant_type_int8:False + + # 16-bit QDQ via contrib ops + self.verify_qdq(False, QuantType.QUInt16, QuantType.QUInt16, {"UseQDQContribOps": True}) + self.verify_qdq(False, QuantType.QInt16, QuantType.QInt16, {"UseQDQContribOps": True}) + self.verify_qdq(False, QuantType.QUInt16, QuantType.QUInt8, {"UseQDQContribOps": True}) + self.verify_qdq(False, QuantType.QInt16, QuantType.QInt8, {"UseQDQContribOps": True}) + + self.verify_qdq(True, QuantType.QUInt16, QuantType.QUInt16, {"UseQDQContribOps": True}) + self.verify_qdq(True, QuantType.QInt16, QuantType.QInt16, {"UseQDQContribOps": True}) + self.verify_qdq(True, QuantType.QUInt16, QuantType.QUInt8, {"UseQDQContribOps": True}) + self.verify_qdq(True, QuantType.QInt16, QuantType.QInt8, {"UseQDQContribOps": True}) def test_quantize_relu_conv(self): float_model_path = str(Path(self._tmp_model_dir.name) / "float_relu_convs_model.onnx") diff --git a/onnxruntime/test/python/test_pytorch_export_contrib_ops.py b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py index a378721932a35..5e20d6b4e692a 100644 --- a/onnxruntime/test/python/test_pytorch_export_contrib_ops.py +++ b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py @@ -49,9 +49,7 @@ def to_numpy(tensor): # These set of tests verify ONNX model export and compares outputs between # PyTorch and ORT. class ONNXExporterTest(unittest.TestCase): - from torch.onnx.symbolic_helper import _export_onnx_opset_version - - opset_version = _export_onnx_opset_version + opset_version = 17 keep_initializers_as_inputs = True # For IR version 3 type export. def setUp(self): diff --git a/onnxruntime/test/testdata/transform/convert_qdq_ops_to_ms_domain.py b/onnxruntime/test/testdata/transform/convert_qdq_ops_to_ms_domain.py index 3df127f5d356d..f74342403f4c3 100644 --- a/onnxruntime/test/testdata/transform/convert_qdq_ops_to_ms_domain.py +++ b/onnxruntime/test/testdata/transform/convert_qdq_ops_to_ms_domain.py @@ -1,59 +1,154 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- """ Loads a model and updates the domain of QuantizeLinear and DequantizeLinear nodes to 'com.microsoft'. +Optionally updates zero-points to 16bit data types. + This is used to create models for testing QDQ transformations with the contrib QDQ ops. 
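Returning to the test_qdq.py changes above: from the quantization tool's side, 16-bit QDQ is requested through quantize_static with 16-bit QuantType values plus the UseQDQContribOps extra option. A usage sketch with placeholder model paths and the same input name as the test model:

import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

class RandomDataReader(CalibrationDataReader):
    """Feeds a few random NCHW tensors named 'input', matching the Conv+Relu test model shape."""
    def __init__(self, num_samples: int = 3):
        self._data = iter([{"input": np.random.rand(1, 8, 33, 33).astype(np.float32)}
                           for _ in range(num_samples)])

    def get_next(self):
        return next(self._data, None)

quantize_static(
    "conv_relu_fp32.onnx",             # placeholder input model path
    "conv_relu_quant_qdq16.onnx",      # placeholder output model path
    RandomDataReader(),
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QUInt8,      # mixed 16-bit activations / 8-bit weights, as in verify_qdq
    extra_options={"UseQDQContribOps": True},
)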
-Usage: python3 convert_qdq_ops_to_ms_domain.py <onnx model path>
+Usage:
+python3 convert_qdq_ops_to_ms_domain.py --input_model <input model path> --output_model <output model path> --use_16bit_qdq
 
 Models created with this script:
 - qdq_with_multi_consumer_dq_nodes.fixed.qdq_contrib.onnx
+- qdq_with_multi_consumer_dq_nodes.fixed.qdq16_contrib.onnx
 - fusion/constant_folding_dequantizelinear.qdq_contrib.onnx
+- fusion/constant_folding_dequantizelinear.qdq16_contrib.onnx
 - fusion/constant_folding_qdq_node_unit.qdq_contrib.onnx
+- fusion/constant_folding_qdq_node_unit.qdq16_contrib.onnx
 - fusion/constant_folding_qdq_node_unit.graph_output.qdq_contrib.onnx
+- fusion/constant_folding_qdq_node_unit.graph_output.qdq16_contrib.onnx
 """
+from __future__ import annotations
+
+import argparse
 import os
+import struct
 import sys
 
 import onnx
+from onnx import shape_inference
 
 QDQ_OPS = ("QuantizeLinear", "DequantizeLinear")
-
-
-def print_usage(prog_name: str):
-    """
-    Prints the program's command-line arguments and usage.
-    """
-
-    print(f"Usage: {prog_name} <onnx model path>")
-
-
-def update_qdq_node_domains(graph):
-    """
-    Updates the domain of all QuantizeLinear and DequantizeLinear nodes
-    in a graph to 'com.microsoft'.
-    """
+QDQ_CONVERT_TYPES = {onnx.TensorProto.UINT8: onnx.TensorProto.UINT16, onnx.TensorProto.INT8: onnx.TensorProto.INT16}
+TYPE_TO_STRUCT_LABEL = {
+    onnx.TensorProto.UINT8: "B",
+    onnx.TensorProto.INT8: "b",
+    onnx.TensorProto.UINT16: "H",
+    onnx.TensorProto.INT16: "h",
+}
+
+
+def convert_initializer_to_16bits(initializer: onnx.TensorProto, target_type: onnx.TensorProto.DataType):
+    byte_order = ">" if sys.byteorder == "big" else "<"
+    byte_label = TYPE_TO_STRUCT_LABEL[initializer.data_type]
+    short_label = TYPE_TO_STRUCT_LABEL[target_type]
+
+    # Do not support external data
+    if initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL:
+        raise Exception("Do not support initializers with external data")
+
+    # Need to convert raw_data bytes to 16-bit values.
+    # NOTE: For tensors that use .int32_data instead of .raw_data, we don't need any special handling
+    # other than updating the data type. This is because the upper 24 bits are already cleared to zero.
+ if initializer.HasField("raw_data"): + num_byte_vals = len(initializer.raw_data) + + # Extract 8-bit values as int32s + int32_vals = struct.unpack(f"{byte_order}{num_byte_vals}{byte_label}", initializer.raw_data) + + # Repack int32 values as 16-bit values + initializer.raw_data = struct.pack(f"{byte_order}{num_byte_vals}{short_label}", *int32_vals) + + initializer.data_type = target_type + + +def convert_qdq_op_to_16bit( + name_to_initializer: dict[str, onnx.TensorProto], + name_to_values: dict[str, onnx.ValueInfoProto], + name_to_inputs: dict[str, onnx.ValueInfoProto], + name_to_outputs: dict[str, onnx.ValueInfoProto], + node: onnx.NodeProto, +): + zp_input = node.input[2] if len(node.input) > 2 else None + + if zp_input in name_to_initializer: + zp_initializer = name_to_initializer[zp_input] + + zp_target_type = QDQ_CONVERT_TYPES.get(zp_initializer.data_type) + if zp_target_type: + convert_initializer_to_16bits(zp_initializer, zp_target_type) + + if node.op_type == "DequantizeLinear": + input0 = node.input[0] + + if input0 in name_to_initializer: + input_initializer = name_to_initializer[input0] + input_target_type = QDQ_CONVERT_TYPES.get(input_initializer.data_type) + if input_target_type: + convert_initializer_to_16bits(input_initializer, input_target_type) + elif input0 in name_to_values: + input_val = name_to_values[input0] + input_target_type = QDQ_CONVERT_TYPES.get(input_val.type.tensor_type.elem_type) + if input_target_type: + input_val.type.tensor_type.elem_type = input_target_type + elif input0 in name_to_inputs: + input_val = name_to_inputs[input0] + input_target_type = QDQ_CONVERT_TYPES.get(input_val.type.tensor_type.elem_type) + if input_target_type: + input_val.type.tensor_type.elem_type = input_target_type + else: + # QuantizeLinear + output0 = node.output[0] + + if output0 in name_to_values: + output_val = name_to_values[output0] + output_target_type = QDQ_CONVERT_TYPES.get(output_val.type.tensor_type.elem_type) + if output_target_type: + output_val.type.tensor_type.elem_type = output_target_type + elif output0 in name_to_outputs: + output_val = name_to_outputs[output0] + output_target_type = QDQ_CONVERT_TYPES.get(output_val.type.tensor_type.elem_type) + if output_target_type: + output_val.type.tensor_type.elem_type = output_target_type + else: + raise Exception("Only support Q/DQ ops with explicit zero-point inputs") + + +def update_qdq_node_domains(graph: onnx.GraphProto, use_16bit_qdq: bool): + name_to_initializer = {initializer.name: initializer for initializer in graph.initializer} + name_to_values = {value.name: value for value in graph.value_info} + name_to_inputs = {g_input.name: g_input for g_input in graph.input} + name_to_outputs = {g_output.name: g_output for g_output in graph.output} for node in graph.node: # Handle subgraphs: for attr in node.attribute: if attr.type == onnx.AttributeProto.GRAPH: - update_qdq_node_domains(attr.g) + update_qdq_node_domains(attr.g, use_16bit_qdq) elif attr.type == onnx.AttributeProto.GRAPHS: for subgraph in attr.graphs: - update_qdq_node_domains(subgraph) + update_qdq_node_domains(subgraph, use_16bit_qdq) # Update Q/DQ domains if node.op_type in QDQ_OPS: node.domain = "com.microsoft" + if use_16bit_qdq: + convert_qdq_op_to_16bit(name_to_initializer, name_to_values, name_to_inputs, name_to_outputs, node) + def main(): - prog_name, *argv = sys.argv + parser = argparse.ArgumentParser(description="Convert Q/DQ ops to com.microsoft domain (or 16-bit)") + parser.add_argument("--input_model", type=str, required=True, help="Input onnx 
model path") + parser.add_argument("--output_model", type=str, required=False, help="Output onnx model path") + parser.add_argument("--use_16bit_qdq", required=False, action="store_true", help="Convert to 16-bit QDQ") - if len(argv) != 1: - print_usage(prog_name) - sys.exit(1) + args = parser.parse_args() - model = onnx.load(argv[0]) + model = onnx.load(args.input_model) has_ms_domain = False for opset in model.opset_import: @@ -64,10 +159,18 @@ def main(): if not has_ms_domain: model.opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) - update_qdq_node_domains(model.graph) + update_qdq_node_domains(model.graph, args.use_16bit_qdq) + model = shape_inference.infer_shapes(model) onnx.checker.check_model(model, True) - base_model_name = os.path.splitext(argv[0])[0] - onnx.save_model(model, base_model_name + ".qdq_contrib.onnx") + + output_model_path = args.output_model + if not output_model_path: + base_model_name = os.path.splitext(args.input_model)[0] + suffix = ".qdq16_contrib" if args.use_16bit_qdq else ".qdq_contrib" + output_model_path = base_model_name + suffix + ".onnx" + + onnx.save_model(model, output_model_path) + print(f"[INFO] Saved model: {output_model_path}") if __name__ == "__main__": diff --git a/onnxruntime/test/testdata/transform/fusion/constant_folding_dequantizelinear.qdq16_contrib.onnx b/onnxruntime/test/testdata/transform/fusion/constant_folding_dequantizelinear.qdq16_contrib.onnx new file mode 100644 index 0000000000000..8fc884024b00f Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/constant_folding_dequantizelinear.qdq16_contrib.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.graph_output.qdq16_contrib.onnx b/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.graph_output.qdq16_contrib.onnx new file mode 100644 index 0000000000000..b9cae7f59f8e8 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.graph_output.qdq16_contrib.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.qdq16_contrib.onnx b/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.qdq16_contrib.onnx new file mode 100644 index 0000000000000..8e12e10e90531 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.qdq16_contrib.onnx differ diff --git a/onnxruntime/test/testdata/transform/qdq_with_multi_consumer_dq_nodes.fixed.qdq16_contrib.onnx b/onnxruntime/test/testdata/transform/qdq_with_multi_consumer_dq_nodes.fixed.qdq16_contrib.onnx new file mode 100644 index 0000000000000..f71114cf31bf9 Binary files /dev/null and b/onnxruntime/test/testdata/transform/qdq_with_multi_consumer_dq_nodes.fixed.qdq16_contrib.onnx differ diff --git a/orttraining/orttraining/test/training_ops/function_op_test_utils.cc b/orttraining/orttraining/test/training_ops/function_op_test_utils.cc index 5eed4765abfd7..9504ba2c1e69a 100644 --- a/orttraining/orttraining/test/training_ops/function_op_test_utils.cc +++ b/orttraining/orttraining/test/training_ops/function_op_test_utils.cc @@ -25,8 +25,8 @@ void OpFunctionTester::RunFunctionBodyGraphOnCPU(TwoDArray& results) { auto& node = *graph.Nodes().begin(); ASSERT_EQ(node.OpType(), op); - // Inline function will call Resolve itself ASSERT_STATUS_OK(graph.InlineFunction(node)); + ASSERT_STATUS_OK(graph.Resolve()); // Hookup the inputs and outputs std::unordered_map feeds; diff --git 
a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml index ec85002503c0e..2e7ac9508a41e 100644 --- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml @@ -73,7 +73,8 @@ stages: project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' definition: '940' specificBuildWithTriggering: true - buildVersionToDownload: 'latest' + buildVersionToDownload: 'latestFromBranch' + branchName: 'refs/heads/main' artifactName: 'NPM_packages' targetPath: '$(Pipeline.Workspace)' displayName: 'Download onnxruntime-node Pipeline Artifact' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-build-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-build-steps.yml index a81dd1e9cf240..2f67398908d5d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-build-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-build-steps.yml @@ -12,6 +12,10 @@ parameters: type: string default: "$(Agent.TempDirectory)/ort_ccache" +- name: DebugCache + type: boolean + default: false + - name: AdditionalKey type: string default: "" @@ -45,6 +49,18 @@ parameters: steps: - ${{ if eq(parameters.WithCache, true) }}: + - powershell: | + if ([string]::IsNullOrEmpty((Get-Command ccache -errorAction SilentlyContinue))) + { + choco install ccache -y --version 4.7.4 + $ccache_path = (Get-Command ccache).Source + $ccache_parent_dir = (Split-Path -parent $ccache_path) + Copy-Item "C:\ProgramData\chocolatey\lib\ccache\tools\ccache-4.7.4-windows-x86_64\ccache.exe" -Destination "C:\ProgramData\chocolatey\bin\cl.exe" + Get-ChildItem $ccache_parent_dir + ccache --version + } + displayName: Install ccache + - task: Cache@2 inputs: ${{if eq(variables['Build.SourceBranchName'], 'merge')}}: @@ -83,6 +99,11 @@ steps: createLogFile: true env: CCACHE_DIR: ${{parameters.CacheDir}} + CCACHE_SLOPPINESS: file_macro,time_macros,include_file_mtime,include_file_ctime + CCACHE_COMPILERCHECK: content + ${{if eq(parameters.DebugCache, true)}}: + CCACHE_DEBUG: 1 + CCACHE_DEBUGDIR: $(Agent.TempDirectory)/cache_debug - ${{ if eq(parameters.WithCache, true) }}: - powershell: | @@ -91,3 +112,10 @@ steps: displayName: cache stat env: CCACHE_DIR: ${{parameters.CacheDir}} + + - ${{if eq(parameters.DebugCache, true)}}: + - task: PublishPipelineArtifact@0 + displayName: 'publish cache log' + inputs: + artifactName: 'cache-log' + targetPath: $(Agent.TempDirectory)/cache_debug diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml index e29d9d2cab91c..8868e671a5fa5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml @@ -1,6 +1,7 @@ parameters: - name: EnvSetupScript type: string + default: setup_env.bat - name: BuildConfig type: string diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 7db0c7302cd6f..68e0d51480a63 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -46,7 +46,8 @@ jobs: BuildConfig: 'RelWithDebInfo' ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' QNN_SDK_ROOT: 'C:\data\qnnsdk\${{parameters.QnnSdk}}' - 
timeoutInMinutes: 150 + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + timeoutInMinutes: 120 workspace: clean: all steps: @@ -70,29 +71,19 @@ jobs: # '_Ret &std::_Visit_strategy<1>::_Visit2<_Ret,_ListOfIndexVectors,_Callable,const # std::variant&>(size_t,_Callable &&, # const std::variant &)' - - task: PythonScript@0 - displayName: 'Generate cmake config' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --compile_no_warning_as_error --update --cmake_generator "Visual Studio 17 2022" --use_qnn --qnn_home $(QNN_SDK_ROOT) --parallel' - workingDirectory: '$(Build.BinariesDirectory)' - - - task: VSBuild@1 - displayName: 'Build' - inputs: - solution: '$(Build.BinariesDirectory)\$(BuildConfig)\onnxruntime.sln' - platform: 'x64' - configuration: $(BuildConfig) - msbuildArgs: $(MsbuildArguments) - msbuildArchitecture: $(buildArch) - maximumCpuCount: true - logProjectEvents: false - workingFolder: '$(Build.BinariesDirectory)\$(BuildConfig)' - createLogFile: true + - template: templates/jobs/win-ci-build-steps.yml + parameters: + WithCache: True + Today: $(TODAY) + AdditionalKey: "win-qnn | $(BuildConfig)" + BuildPyArguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --compile_no_warning_as_error --update --cmake_generator "Visual Studio 17 2022" --use_qnn --qnn_home $(QNN_SDK_ROOT) --parallel' + MsbuildArguments: $(MsbuildArguments) + BuildArch: $(buildArch) + Platform: 'x64' + BuildConfig: $(BuildConfig) - powershell: | python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --test --cmake_generator "Visual Studio 17 2022" --enable_onnx_tests - workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' displayName: 'Run unit tests' diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh index b2e575f38e425..3bca6413100a2 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh @@ -20,6 +20,10 @@ export ONNX_ML=1 export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" /opt/python/cp39-cp39/bin/python3.9 -m pip install transformers +# beartype is installed here so that onnxscript installation step won't +# install a version PyTorch doesn't like. Once beartype fixes this problem. +# We can remove this line. +/opt/python/cp39-cp39/bin/python3.9 -m pip install beartype==0.15.0 cd /usr/local/ echo "Cloning ONNX Script" diff --git a/tools/python/util/logger.py b/tools/python/util/logger.py index 9deb4475721ee..15e04528ac7ac 100644 --- a/tools/python/util/logger.py +++ b/tools/python/util/logger.py @@ -5,6 +5,7 @@ def get_logger(name): - logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG) - - return logging.getLogger(name) + logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s") + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + return logger
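
For reference, the 16-bit QDQ path exercised by the new verify_qdq cases above can also be driven directly through the Python quantization API. The sketch below is illustrative only and not part of this patch: the model paths and the random calibration reader are placeholders, and it assumes the QuantType.QUInt16/QInt16 members and the UseQDQContribOps extra option introduced by this change.

# Illustrative sketch (not part of this patch): produce a 16-bit QDQ model using the
# contrib Q/DQ ops, mirroring what verify_qdq() does in test_qdq.py above.
# File names and the random calibration reader are placeholders.
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static


class RandomDataReader(CalibrationDataReader):
    """Feeds a few random batches for calibration; replace with representative data."""

    def __init__(self, input_name, shape, num_batches=4):
        self._batches = iter(
            [{input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(num_batches)]
        )

    def get_next(self):
        return next(self._batches, None)


quantize_static(
    "model_fp32.onnx",                          # placeholder input model
    "model_qdq16.onnx",                         # placeholder output model
    RandomDataReader("input", [1, 8, 33, 33]),  # same input shape as the Conv+Relu test model
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt16,          # 16-bit activations
    weight_type=QuantType.QUInt8,               # weights may remain 8-bit (mixed precision)
    extra_options={"UseQDQContribOps": True},   # emit com.microsoft QuantizeLinear/DequantizeLinear
)

The UseQDQContribOps option is what carries the 16-bit types: the standard ONNX QuantizeLinear/DequantizeLinear ops at the opsets used here only accept 8-bit (and, for DequantizeLinear, 32-bit signed) quantized tensors, so the quantizer must emit the com.microsoft variants instead.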