
Commit

Merge branch 'main' into chi/trt_plugin_python
chilo-ms authored Sep 19, 2023
2 parents 7afe24e + d522cc7 commit fa49c53
Showing 85 changed files with 5,036 additions and 2,911 deletions.
3 changes: 3 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -10,6 +10,9 @@ set(contrib_ops_excluded_files
  "bert/attention_impl.cu"
  "bert/attention_softmax.h"
  "bert/attention_softmax.cu"
+ "bert/attention_prepare_qkv.cu"
+ "bert/decoder_attention_impl.h"
+ "bert/decoder_attention_impl.cu"
  "bert/decoder_masked_multihead_attention.h"
  "bert/decoder_masked_multihead_attention.cc"
  "bert/decoder_masked_self_attention.h"
13 changes: 7 additions & 6 deletions docs/ContribOperators.md
@@ -1351,8 +1351,8 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
-<dt><tt>T1</tt> : tensor(int8), tensor(uint8), tensor(int32)</dt>
-<dd>Constrain 'x' and 'x_zero_point' to 8-bit integer tensors or 32-bit signed integer tensors.</dd>
+<dt><tt>T1</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32)</dt>
+<dd>Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, 16-bit integer tensors, or 32-bit signed integer tensors.</dd>
<dt><tt>T2</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain 'y', 'x_scale' to float tensors.</dd>
</dl>
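For reference, the arithmetic these constraints govern is the standard ONNX dequantization, y = (x - x_zero_point) * x_scale; below is a minimal per-tensor sketch of the newly supported 16-bit case (illustrative only, not the onnxruntime kernel):

```cpp
#include <cstdint>
#include <iostream>

// Standard ONNX DequantizeLinear arithmetic: y = (x - x_zero_point) * x_scale.
// int16 is one of the newly added input types; the subtraction promotes to
// int, so it cannot overflow for 16-bit values.
float DequantizeLinear(int16_t x, float x_scale, int16_t x_zero_point) {
  return static_cast<float>(x - x_zero_point) * x_scale;
}

int main() {
  std::cout << DequantizeLinear(150, 0.5f, 0) << "\n";  // prints 75
}
```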
@@ -4194,8 +4194,9 @@ This version of the operator has been available since version 1 of the 'com.micr
### <a name="com.microsoft.QuantizeLinear"></a><a name="com.microsoft.quantizelinear">**com.microsoft.QuantizeLinear**</a>

The linear quantization operator. It consumes full-precision data, a scale, and a zero point to compute the low-precision / quantized tensor.
-The quantization formula is y = saturate ((x / y_scale) + y_zero_point).For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8.
-For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
+The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, [-128, 127] if it's int8,
+[0, 65,535] if it's uint16, and [-32,768, 32,767] if it's int16. For (x / y_scale), it's rounding to nearest ties to even.
+Refer to https://en.wikipedia.org/wiki/Rounding for details.
Scale and zero point must have the same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').
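A minimal sketch of that formula for the newly added int16 case (illustrative only; std::nearbyint rounds to nearest with ties to even under the default FE_TONEAREST mode, which matches the rounding described above):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// y = saturate(round(x / y_scale) + y_zero_point), saturating to [-32768, 32767].
int16_t QuantizeLinearInt16(float x, float y_scale, int16_t y_zero_point) {
  const float rounded = std::nearbyint(x / y_scale) + y_zero_point;
  return static_cast<int16_t>(std::clamp(rounded, -32768.0f, 32767.0f));
}

int main() {
  std::cout << QuantizeLinearInt16(1.25f, 0.5f, 0) << "\n";  // 2.5 ties to even: prints 2
  std::cout << QuantizeLinearInt16(1e9f, 0.5f, 0) << "\n";   // saturates: prints 32767
}
```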

#### Version
@@ -4232,8 +4233,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>T1</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain 'x', 'y_scale' to float tensors.</dd>
-<dt><tt>T2</tt> : tensor(int8), tensor(uint8)</dt>
-<dd>Constrain 'y_zero_point' and 'y' to 8-bit integer tensors.</dd>
+<dt><tt>T2</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16)</dt>
+<dd>Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.</dd>
</dl>


4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -439,7 +439,7 @@ Do not modify directly.*
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|CropAndResize|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *in* crop_size:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int32)|
-|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int32), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float)|
+|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)<br/> **T2** = tensor(float)|
|DynamicQuantizeLSTM|*in* X:**T**<br> *in* W:**T2**<br> *in* R:**T2**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* W_scale:**T**<br> *in* W_zero_point:**T2**<br> *in* R_scale:**T**<br> *in* R_zero_point:**T2**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(float)<br/> **T1** = tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|DynamicQuantizeMatMul|*in* A:**T1**<br> *in* B:**T2**<br> *in* b_scale:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float)|
@@ -472,7 +472,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearWhere|*in* condition:**B**<br> *in* X:**T**<br> *in* x_scale:**TF**<br> *in* x_zero_point:**T**<br> *in* Y:**T**<br> *in* y_scale:**TF**<br> *in* y_zero_point:**T**<br> *in* z_scale:**TF**<br> *in* z_zero_point:**T**<br> *out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
-|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
+|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
1 change: 1 addition & 0 deletions include/onnxruntime/core/graph/graph.h
@@ -1135,6 +1135,7 @@ class Graph {

/**
Directly insert the nodes in the function Node provided into this Graph.
+The Graph needs to be Resolve()d after this call.
@param node Node with Node::Type of Node::Type::Fused
@returns Status indicating success or providing an error message.
*/
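A sketch of the call pattern the added sentence implies, assuming this comment documents Graph::InlineFunction (hypothetical caller, not library code):

```cpp
#include "core/common/common.h"  // ORT_RETURN_IF_ERROR
#include "core/graph/graph.h"

// Hypothetical helper: inline the fused node's function body, then re-resolve
// the graph, since the doc above says it must be Resolve()d after the call.
onnxruntime::common::Status InlineAndResolve(onnxruntime::Graph& graph,
                                             onnxruntime::Node& fused_node) {
  ORT_RETURN_IF_ERROR(graph.InlineFunction(fused_node));
  return graph.Resolve();
}
```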
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
@@ -46,6 +46,7 @@ Do not modify directly.*
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
+| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
| LeakyRelu | ai.onnx(6-15,16+) | |
32 changes: 19 additions & 13 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';

-import {inputVariable, outputVariable, ShaderHelper} from './common';
+import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';

type BuiltinFunctionName = string;
type ElementwiseCustomExpression = (expression: string) => string;
@@ -101,6 +101,9 @@ export const parseCastAttributes = (attributes: Record<string, unknown>): CastAt
export const cast = (context: ComputeContext, attributes: CastAttributes): void => {
let func: ElementwiseFunctionCall;
switch (attributes.to) {
+    case DataType.float16:
+      func = 'vec4<f16>';
+      break;
case DataType.float:
func = 'vec4<f32>';
break;
@@ -126,11 +129,12 @@ export interface ClipAttributes extends AttributeWithCacheKey {
}

export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
+  const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfoLoader(
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
-          const clip_min_: vec4<f32> = vec4(f32(${attributes.min}));
-          const clip_max_: vec4<f32> = vec4(f32(${attributes.max}));
+          const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min}));
+          const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max}));
`,
attributes.cacheKey),
{inputs: [0]});
@@ -180,13 +184,13 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
attributes.cacheKey));
};

-export const erfImpl = (dataType: string) => `
-const r0: f32 = 0.3275911;
-const r1: f32 = 0.254829592;
-const r2: f32 = -0.284496736;
-const r3: f32 = 1.421413741;
-const r4: f32 = -1.453152027;
-const r5: f32 = 1.061405429;
+export const erfImpl = (dataType: string, varType = 'f32') => `
+const r0: ${varType} = 0.3275911;
+const r1: ${varType} = 0.254829592;
+const r2: ${varType} = -0.284496736;
+const r3: ${varType} = 1.421413741;
+const r4: ${varType} = -1.453152027;
+const r5: ${varType} = 1.061405429;
fn erf_vf32(v: ${dataType}) -> ${dataType} {
let absv = abs(v);
@@ -195,8 +199,9 @@
}`;

export const erf = (context: ComputeContext): void => {
-  context.compute(
-      createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4<f32>')));
+  const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
+  context.compute(createElementwiseProgramInfoLoader(
+      context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
};

export const exp = (context: ComputeContext): void => {
@@ -208,9 +213,10 @@ export const floor = (context: ComputeContext): void => {
};

export const gelu = (context: ComputeContext): void => {
+  const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
-      erfImpl('vec4<f32>')));
+      erfImpl(`vec4<${dataType}>`, dataType)));
};
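The r0..r5 constants above are the classic Abramowitz and Stegun erf approximation; here is a scalar C++ sketch of what the generated WGSL computes, useful for spot-checking values (not part of the library):

```cpp
#include <cmath>
#include <iostream>

// erf(v) is approximated as sign(v) * (1 - (r1*t + r2*t^2 + r3*t^3 + r4*t^4 + r5*t^5) * exp(-v^2))
// with t = 1 / (1 + r0*|v|), the same polynomial erfImpl emits as WGSL.
float ErfApprox(float v) {
  const float r0 = 0.3275911f, r1 = 0.254829592f, r2 = -0.284496736f,
              r3 = 1.421413741f, r4 = -1.453152027f, r5 = 1.061405429f;
  const float a = std::fabs(v);
  const float t = 1.0f / (1.0f + r0 * a);
  const float poly = t * (r1 + t * (r2 + t * (r3 + t * (r4 + t * r5))));
  return std::copysign(1.0f - poly * std::exp(-a * a), v);
}

// Gelu as generated above: 0.5 * x * (1 + erf(x / sqrt(2))).
float Gelu(float x) { return 0.5f * x * (1.0f + ErfApprox(x * 0.7071067811865475f)); }

int main() {
  std::cout << ErfApprox(1.0f) << "\n";  // ~0.8427 (true erf(1) is ~0.84270)
  std::cout << Gelu(1.0f) << "\n";       // ~0.8413
}
```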

export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
3 changes: 1 addition & 2 deletions js/web/script/test-runner-cli-args.ts
@@ -382,8 +382,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
const globalEnvFlags = parseGlobalEnvFlags(args);

if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) {
-    // Backend webnn is restricted in the dedicated worker.
-    globalEnvFlags.wasm!.proxy = true;
+    throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.');
}

// Options:
6 changes: 4 additions & 2 deletions js/web/script/test-runner-cli.ts
@@ -84,8 +84,10 @@ async function main() {
.flat();

for (const backend of DEFAULT_BACKENDS) {
-      nodeTests.set(backend, loadNodeTests(backend, allNodeTestsFolders));
-      opTests.set(backend, loadOpTests(backend));
+      if (args.backends.indexOf(backend) !== -1) {
+        nodeTests.set(backend, loadNodeTests(backend, allNodeTestsFolders));
+        opTests.set(backend, loadOpTests(backend));
+      }
}
}

5 changes: 5 additions & 0 deletions js/web/test/suite-test-list.jsonc
@@ -602,6 +602,11 @@
      // // "test_hardsigmoid",
      // // "test_hardswish_expanded",
      // // "test_hardswish",
+     "test_if",
+     // TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra
+     // supports Sequence and Optional types
+     // "test_if_seq",
+     // "test_if_opt",
"test_instancenorm_epsilon",
"test_instancenorm_example",
// "test_isinf_negative",
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -56,9 +56,13 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLine
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid);
@@ -191,9 +195,13 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid)>,
56 changes: 0 additions & 56 deletions onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc

This file was deleted.

34 changes: 16 additions & 18 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -249,30 +249,28 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
AttentionData<CudaT> data;
data.gemm_buffer = reinterpret_cast<CudaT*>(gemm_buffer.get());
-  data.bias = nullptr == bias ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>());
-  data.query = nullptr;
-  data.key = nullptr;
-  data.value = nullptr;
-  data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data<int>();
-  data.mask_index_dims = (nullptr == mask_index) ? gsl::span<const int64_t>() : mask_index->Shape().GetDims();
-  data.past = (nullptr == past) ? nullptr : reinterpret_cast<const CudaT*>(past->Data<T>());
-  data.past_key = nullptr;
-  data.past_value = nullptr;
-  data.relative_position_bias = (nullptr == relative_position_bias)
-                                    ? nullptr
-                                    : reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
+  if (nullptr != bias) {
+    data.bias = reinterpret_cast<const CudaT*>(bias->Data<T>());
+  }
+  if (nullptr != mask_index) {
+    data.mask_index = mask_index->Data<int>();
+    data.mask_index_dims = mask_index->Shape().GetDims();
+  }
+  if (nullptr != past) {
+    data.past = reinterpret_cast<const CudaT*>(past->Data<T>());
+  }
+  if (nullptr != relative_position_bias) {
+    data.relative_position_bias = reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
+  }
data.has_qkv_workspace = true;
data.workspace = reinterpret_cast<CudaT*>(work_space.get());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
-  data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
-  data.present_key = nullptr;
-  data.present_value = nullptr;
+  if (nullptr != present) {
+    data.present = reinterpret_cast<CudaT*>(present->MutableData<T>());
+  }
data.fused_runner = reinterpret_cast<void*>(fused_runner);
-  data.fused_cross_attention_kernel = nullptr;
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
-  data.cumulated_sequence_length_q_cache = nullptr;
-  data.cumulated_sequence_length_kv_cache = nullptr;

return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
}
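The rewrite reads as relying on AttentionData's pointer members defaulting to nullptr, so only inputs that actually exist get assigned; a minimal sketch of that pattern with illustrative types (a stand-in, not the real struct):

```cpp
// Illustrative stand-in: default member initializers replace the removed
// explicit `data.x = nullptr;` assignments, so call sites only set what exists.
struct AttentionDataSketch {
  const float* bias = nullptr;
  const int* mask_index = nullptr;
  const float* past = nullptr;
  float* present = nullptr;
};

void Populate(AttentionDataSketch& data, const float* bias, float* present) {
  if (bias != nullptr) data.bias = bias;  // untouched members keep nullptr
  if (present != nullptr) data.present = present;
}
```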
(Diff listing truncated; the remaining changed files of the 85 total are not shown.)
