diff --git a/.gitmodules b/.gitmodules
index 036a248070855..7bb49e98bfec1 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -8,6 +8,3 @@
path = cmake/external/emsdk
url = https://github.com/emscripten-core/emsdk.git
branch = 3.1.44
-[submodule "cmake/external/onnxruntime-extensions"]
- path = cmake/external/onnxruntime-extensions
- url = https://github.com/microsoft/onnxruntime-extensions.git
diff --git a/VERSION_NUMBER b/VERSION_NUMBER
index 15b989e398fc7..092afa15df4df 100644
--- a/VERSION_NUMBER
+++ b/VERSION_NUMBER
@@ -1 +1 @@
-1.16.0
+1.17.0
diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 25202f82f468d..cf71b6bcf7c7d 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -10,6 +10,9 @@ set(contrib_ops_excluded_files
"bert/attention_impl.cu"
"bert/attention_softmax.h"
"bert/attention_softmax.cu"
+ "bert/attention_prepare_qkv.cu"
+ "bert/decoder_attention_impl.h"
+ "bert/decoder_attention_impl.cu"
"bert/decoder_masked_multihead_attention.h"
"bert/decoder_masked_multihead_attention.cc"
"bert/decoder_masked_self_attention.h"
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
index d6341b90f28ff..68a399f8b9671 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
@@ -65,10 +65,10 @@ static NativeTrainingMethods()
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi));
// TODO: Make this save the pointer, and not copy the whole structure across
- api_ = (OrtApi)OrtGetApi(16 /*ORT_API_VERSION*/);
+ api_ = (OrtApi)OrtGetApi(17 /*ORT_API_VERSION*/);
OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi));
- trainingApiPtr = OrtGetTrainingApi(16 /*ORT_API_VERSION*/);
+ trainingApiPtr = OrtGetTrainingApi(17 /*ORT_API_VERSION*/);
if (trainingApiPtr != IntPtr.Zero)
{
trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi));
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 5bd1a89c0dea1..95dc8c3cde46c 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1351,8 +1351,8 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-<dt><tt>T1</tt> : tensor(int8), tensor(uint8), tensor(int32)</dt>
-<dd>Constrain 'x' and 'x_zero_point' to 8-bit integer tensors or 32-bit signed integer tensors.</dd>
+<dt><tt>T1</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32)</dt>
+<dd>Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, 16-bit integer tensors, or 32-bit signed integer tensors.</dd>
<dt><tt>T2</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain 'y', 'x_scale' to float tensors.</dd>
@@ -4194,8 +4194,9 @@ This version of the operator has been available since version 1 of the 'com.micr
### **com.microsoft.QuantizeLinear**
The linear quantization operator. It consumes a full precision data, a scale, a zero point to compute the low precision / quantized tensor.
- The quantization formula is y = saturate ((x / y_scale) + y_zero_point).For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8.
- For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
+ The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, [-128, 127] if it's int8,
+ [0, 65,535] if it's uint16, and [-32,768, 32,767] if it's int16. For (x / y_scale), it's rounding to nearest ties to even.
+ Refer to https://en.wikipedia.org/wiki/Rounding for details.
Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').
#### Version
@@ -4232,8 +4233,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>T1</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain 'x', 'y_scale' to float tensors.</dd>
-<dt><tt>T2</tt> : tensor(int8), tensor(uint8)</dt>
-<dd>Constrain 'y_zero_point' and 'y' to 8-bit integer tensors.</dd>
+<dt><tt>T2</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16)</dt>
+<dd>Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.</dd>
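For reference, the behaviour documented above can be written out directly. A minimal TypeScript sketch of the stated formula, using the saturation ranges quoted in the description (the helper names here are illustrative and not part of ONNX Runtime):

```typescript
// Illustration of the documented contrib QuantizeLinear formula:
// y = saturate(round(x / y_scale) + y_zero_point), rounding half to even.
type QuantType = 'int8' | 'uint8' | 'int16' | 'uint16';

const RANGES: Record<QuantType, [number, number]> = {
  int8: [-128, 127],
  uint8: [0, 255],
  int16: [-32768, 32767],
  uint16: [0, 65535],
};

// Round-half-to-even ("banker's rounding"), as referenced in the description.
const roundHalfToEven = (v: number): number => {
  const floor = Math.floor(v);
  const diff = v - floor;
  if (diff < 0.5) return floor;
  if (diff > 0.5) return floor + 1;
  return floor % 2 === 0 ? floor : floor + 1;
};

const quantizeLinear = (x: number, scale: number, zeroPoint: number, type: QuantType): number => {
  const [lo, hi] = RANGES[type];
  return Math.min(hi, Math.max(lo, roundHalfToEven(x / scale) + zeroPoint));
};

// e.g. quantizeLinear(1.5, 0.01, 0, 'int16') === 150
```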
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index d46f3ed9bd262..33c187a28b62e 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -439,7 +439,7 @@ Do not modify directly.*
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|CropAndResize|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *in* crop_size:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(int32)|
-|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int32), tensor(int8), tensor(uint8)<br> **T2** = tensor(float)|
+|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)<br> **T2** = tensor(float)|
|DynamicQuantizeLSTM|*in* X:**T**<br> *in* W:**T2**<br> *in* R:**T2**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* W_scale:**T**<br> *in* W_zero_point:**T2**<br> *in* R_scale:**T**<br> *in* R_zero_point:**T2**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(float)<br> **T1** = tensor(int32)<br> **T2** = tensor(int8), tensor(uint8)|
|DynamicQuantizeMatMul|*in* A:**T1**<br> *in* B:**T2**<br> *in* b_scale:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(int8), tensor(uint8)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float)|
@@ -472,7 +472,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearWhere|*in* condition:**B**<br> *in* X:**T**<br> *in* x_scale:**TF**<br> *in* x_zero_point:**T**<br> *in* Y:**T**<br> *in* y_scale:**TF**<br> *in* y_zero_point:**T**<br> *in* z_scale:**TF**<br> *in* z_zero_point:**T**<br> *out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
-|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(int8), tensor(uint8)|
+|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
diff --git a/docs/c_cxx/doxygen-header.html b/docs/c_cxx/doxygen-header.html
index 364f76f7f0580..6d95bf57ff98f 100644
--- a/docs/c_cxx/doxygen-header.html
+++ b/docs/c_cxx/doxygen-header.html
@@ -16,7 +16,7 @@
-
+
$treeview
$search
$mathjax
diff --git a/docs/python/README.rst b/docs/python/README.rst
index 7d978b0941235..32bb3729e01d0 100644
--- a/docs/python/README.rst
+++ b/docs/python/README.rst
@@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime
+ new (tensorTypeToTypedArrayConstructor(type))(dataBuffer);
+
+/**
+ * a TensorView does not own the data.
+ */
+export interface TensorView {
+ readonly data: number;
+ readonly dataType: number;
+ readonly dims: readonly number[];
+
+ /**
+ * get a Float32Array data view of the tensor data. tensor data must be on CPU.
+ */
+ getFloat32Array(): Float32Array;
+
+ /**
+ * get a BigInt64Array data view of the tensor data. tensor data must be on CPU.
+ */
+ getBigInt64Array(): BigInt64Array;
+
+ /**
+ * get a Int32Array data view of the tensor data. tensor data must be on CPU.
+ */
+ getInt32Array(): Int32Array;
+
+ /**
+ * create a new tensor view with the same data but different dimensions.
+ */
+ reshape(newDims: readonly number[]): TensorView;
+}
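The interface above is the same `TensorView` contract that the deleted `tensor.ts` below carried; only the module moved, and `createView` now reuses `tensorTypeToTypedArrayConstructor` from `wasm-common`. A minimal, purely illustrative CPU-backed sketch of that contract (hypothetical class, not part of the codebase):

```typescript
// Hypothetical in-memory TensorView used only to illustrate the contract above.
class CpuTensorView /* implements TensorView */ {
  constructor(
      readonly data: number,                 // byte offset into some backing buffer
      readonly dataType: number,             // numeric DataType enum value
      readonly dims: readonly number[],
      private readonly buffer: Float32Array  // backing storage for this sketch
  ) {}

  // data view of the tensor data; only meaningful when the data lives on CPU
  getFloat32Array(): Float32Array {
    return this.buffer;
  }

  // reshape() only changes the reported dims; the underlying data is shared
  reshape(newDims: readonly number[]): CpuTensorView {
    return new CpuTensorView(this.data, this.dataType, newDims, this.buffer);
  }
}
```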
diff --git a/js/web/lib/wasm/jsep/tensor.ts b/js/web/lib/wasm/jsep/tensor.ts
deleted file mode 100644
index abe61e07fc0a8..0000000000000
--- a/js/web/lib/wasm/jsep/tensor.ts
+++ /dev/null
@@ -1,115 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-export declare namespace Tensor {
- export interface DataTypeMap {
- bool: Uint8Array;
- float32: Float32Array;
- float64: Float64Array;
- string: string[];
- int8: Int8Array;
- uint8: Uint8Array;
- int16: Int16Array;
- uint16: Uint16Array;
- int32: Int32Array;
- uint32: Uint32Array;
- int64: BigInt64Array;
- uint64: BigUint64Array;
- }
-
- export type DataType = keyof DataTypeMap;
-
- export type StringType = Tensor.DataTypeMap['string'];
- export type BooleanType = Tensor.DataTypeMap['bool'];
- export type IntegerType = Tensor.DataTypeMap['int8']|Tensor.DataTypeMap['uint8']|Tensor.DataTypeMap['int16']|
- Tensor.DataTypeMap['uint16']|Tensor.DataTypeMap['int32']|Tensor.DataTypeMap['uint32']|
- Tensor.DataTypeMap['int64']|Tensor.DataTypeMap['uint64'];
- export type FloatType = Tensor.DataTypeMap['float32']|Tensor.DataTypeMap['float64'];
- export type NumberType = BooleanType|IntegerType|FloatType;
-
- export type Id = number;
-}
-
-export const sizeof = (type: Tensor.DataType): number => {
- switch (type) {
- case 'bool':
- case 'int8':
- case 'uint8':
- return 1;
- case 'int16':
- case 'uint16':
- return 2;
- case 'int32':
- case 'uint32':
- case 'float32':
- return 4;
- case 'int64':
- case 'uint64':
- case 'float64':
- return 8;
- default:
- throw new Error(`cannot calculate sizeof() on type ${type}`);
- }
-};
-
-const dataviewConstructor = (type: Tensor.DataType) => {
- switch (type) {
- case 'bool':
- case 'uint8':
- return Uint8Array;
- case 'int8':
- return Int8Array;
- case 'int16':
- return Int16Array;
- case 'uint16':
- return Uint16Array;
- case 'int32':
- return Int32Array;
- case 'uint32':
- return Uint32Array;
- case 'int64':
- return BigInt64Array;
- case 'uint64':
- return BigUint64Array;
- case 'float32':
- return Float32Array;
- case 'float64':
- return Float64Array;
- default:
- // should never run to here
- throw new Error('unspecified error');
- }
-};
-
-export const createView = (dataBuffer: ArrayBuffer, type: Tensor.DataType): Int32Array|Uint32Array|BigInt64Array|
- BigUint64Array|Uint8Array|Float32Array|Float64Array|Int8Array|Int16Array|Uint16Array =>
- new (dataviewConstructor(type))(dataBuffer);
-
-/**
- * a TensorView does not own the data.
- */
-export interface TensorView {
- readonly data: number;
- readonly dataType: number;
- readonly dims: readonly number[];
-
- /**
- * get a Float32Array data view of the tensor data. tensor data must be on CPU.
- */
- getFloat32Array(): Float32Array;
-
- /**
- * get a BigInt64Array data view of the tensor data. tensor data must be on CPU.
- */
- getBigInt64Array(): BigInt64Array;
-
- /**
- * get a Int32Array data view of the tensor data. tensor data must be on CPU.
- */
- getInt32Array(): Int32Array;
-
- /**
- * create a new tensor view with the same data but different dimensions.
- */
- reshape(newDims: readonly number[]): TensorView;
-}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
index 02507ad802b36..08b1d1f30b233 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
@@ -20,7 +20,7 @@
// modified to fit the needs of the project
import {LOG_DEBUG} from '../../../log';
-import {TensorView} from '../../../tensor';
+import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types';
import {ConvAttributes} from '../conv';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
index 82fe3d5b6af43..ec6df438129fb 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
@@ -18,7 +18,7 @@
// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts
import {LOG_DEBUG} from '../../../log';
-import {TensorView} from '../../../tensor';
+import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types';
import {inputVariable, outputVariable, ShaderHelper} from '../common';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
index ab4f608451101..8d43dbb378a69 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
@@ -19,7 +19,7 @@
//
// modified to fit the needs of the project
-import {TensorView} from '../../../tensor';
+import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types';
import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from '../common';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts
index 12a13d9d8e0a0..412e61a3cc0f9 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts
@@ -6,7 +6,7 @@
// a optimized codepath for this.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
index b004ca37a2ea8..13d3a91bb339e 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {BroadcastUtil, ShapeUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
index f3845e3110905..c054da51a3098 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -592,7 +592,8 @@ class ShaderHelperImpl implements ShaderHelper {
const workgroupSizeZ = typeof workgroupSize === 'number' ? 1 : workgroupSize[2];
const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1;
- const paramList = is1DimensionDispatch ? '@builtin(global_invocation_id) global_id : vec3<u32>' :
+ const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3<u32>,
+ @builtin(local_invocation_id) local_id : vec3<u32>` :
`@builtin(local_invocation_index) local_index : u32,
@builtin(workgroup_id) workgroup_id : vec3<u32>`;
const globalIdxDefinition = is1DimensionDispatch ?
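With this change the 1-D dispatch path declares both builtins, so kernels such as the rewritten instance-norm shader can read `local_id.x`. Roughly the WGSL entry point the helper now produces for a 1-D dispatch, assuming a 64-thread workgroup (illustrative string only, not a transcription of the helper's exact output):

```typescript
// Sketch of the generated entry-point signature for the 1-D dispatch case.
const workgroupSizeX = 64;
const mainDecl = `
  @compute @workgroup_size(${workgroupSizeX}, 1, 1)
  fn main(@builtin(global_invocation_id) global_id : vec3<u32>,
          @builtin(local_invocation_id) local_id : vec3<u32>) {
    let global_idx = global_id.x;
    // ... kernel body can use both global_idx and local_id.x ...
  }`;
console.log(mainDecl);
```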
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts
index 9b294803d3787..279632c190ded 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
index 8a794ce16a0b5..1b7b7e0b29a25 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
index acdfd7e40f258..e7d1ddf771650 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
index 3a83b1c2de6c1..95a64e5787841 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {PoolConvUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts
index 0abece9559630..21c0b97042fbb 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types';
import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
index f0196f37c3153..fc9ebf004ad25 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts
index 2d845775f1c62..824ce682c0c4b 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts
index 57c5fccfd8c26..a7d355bc13704 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
index a915a4bbd969c..0db060dbec54a 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
index 3ce963b54f3ee..1a36d4a7545d6 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {GemmUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
index f62c766aa9ed0..5a148bda0a9f7 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
@@ -1,83 +1,97 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
-import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
+import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
export interface InstanceNormAttributes extends AttributeWithCacheKey {
epsilon: number;
format: 'NHWC'|'NCHW';
}
-const validateInputs = (inputs: readonly TensorView[]): void => {
- if (!inputs || inputs.length !== 3) {
- throw new Error('instanceNorm requires 3 inputs.');
- }
-
- if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) {
- throw new Error('inputs should be float type');
- }
-};
-
const createInstanceNormProgramInfo =
(metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => {
const xShape = inputs[0].dims;
- const scale = inputs[1];
- const bias = inputs[2];
const outputShape = xShape;
- const outputSize = ShapeUtil.size(outputShape);
const axis = 2;
const normCount = ShapeUtil.sizeToDimension(xShape, axis);
const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
const C = xShape[1];
-
- const scaleSize = ShapeUtil.size(scale.dims);
- const biasSize = bias ? ShapeUtil.size(bias.dims) : 0;
- if (scaleSize !== normSize || (bias && biasSize !== normSize)) {
- throw new Error(`Size of X.shape()[axis:] == ${normSize}.
- Size of scale and bias (if provided) must match this.
- Got scale size of ${scaleSize} and bias size of ${biasSize}`);
- }
-
- const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
-
+ const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normSize]);
+ const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims);
+ const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
+ const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normSize]);
+ const variables = [x, scale, bias, output];
+ const dataType = x.type.value;
+ const workgroupSize = 64;
const getShaderSource = (shaderHelper: ShaderHelper) => `
+
const C: u32 = ${C};
const normSize: u32 = ${normSize};
- const normSizeTyped: ${dataType} = ${normSize};
const epsilon: f32 = ${attributes.epsilon};
+ var<workgroup> meanShared : ${dataType};
+ var<workgroup> squaredNormShared : ${dataType};
+ var<workgroup> workgroupShared : array<${dataType}, ${workgroupSize}>;
+ const workgroupSize = ${workgroupSize}u;
+ ${shaderHelper.declareVariables(...variables)}
+ ${shaderHelper.mainStart(workgroupSize)}
+ let norm = global_idx / workgroupSize;
+ let batch = norm / C;
+ let channel = norm % C;
+ let localIndex = local_id.x;
+
+ // initialize workgroup memory
+ var initial: ${dataType} = 0;
+ for (var h = localIndex; h < normSize; h += workgroupSize) {
+ initial = initial + ${x.get('batch', 'channel', 'h')};
+ }
+ workgroupShared[localIndex] = initial;
+ workgroupBarrier();
- @group(0) @binding(0) var<storage, read> x : array<${dataType}>;
- @group(0) @binding(1) var<storage, read> scale : array<${dataType}>;
- @group(0) @binding(2) var<storage, read> bias : array<${dataType}>;
- @group(0) @binding(3) var<storage, read_write> output : array<${dataType}>;
-
- ${shaderHelper.mainStart()}
- let offset = global_idx * normSize;
- if (offset + normSize >= ${outputSize}) { return; }
- var mean: ${dataType} = 0;
+ // Calculate the mean of current channel data.
+ for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) {
+ if (localIndex < currSize) {
+ workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
+ }
+ workgroupBarrier();
+ }
+ if (localIndex == 0) {
+ meanShared = workgroupShared[0] / ${dataType}(normSize);
+ }
+ workgroupBarrier();
- for (var h: u32 = 0u; h < normSize; h++) {
- mean = mean + x[h + offset];
+ // reinitialize workgroup memory.
+ initial = 0;
+ for (var h = localIndex; h < normSize; h += workgroupSize) {
+ let deviation = ${x.get('batch', 'channel', 'h')} - meanShared;
+ initial = initial + deviation * deviation;
}
- mean = mean / normSizeTyped;
+ workgroupShared[localIndex] = initial;
+ workgroupBarrier();
- var squaredNorm: ${dataType} = 0;
- for (var h: u32 = 0u; h < normSize; h++) {
- let deviation: f32 = x[h + offset] - mean;
- squaredNorm = squaredNorm + deviation * deviation;
+ // Calculate the sum of square of deviation of current channel data.
+ for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) {
+ if (localIndex < currSize) {
+ workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
+ }
+ workgroupBarrier();
}
- let invStdDev = 1 / sqrt(squaredNorm / normSizeTyped + epsilon);
- let channelScale = invStdDev * scale[global_idx % C];
- let channelShift = bias[global_idx % C] - mean * channelScale;
- for (var j: u32 = 0; j < normSize; j++) {
- output[j + offset] = x[j + offset] * channelScale + channelShift;
+ if (localIndex == 0) {
+ squaredNormShared = workgroupShared[0];
+ }
+ workgroupBarrier();
+
+ let invStdDev = 1 / sqrt(squaredNormShared / ${dataType}(normSize) + epsilon);
+ let channelScale = invStdDev * ${scale.getByOffset('channel')};
+ let channelShift = ${bias.getByOffset('channel')} - meanShared * channelScale;
+ for (var h = localIndex; h < normSize; h += workgroupSize) {
+ let value = ${x.get('batch', 'channel', 'h')} * channelScale + channelShift;
+ ${output.set('batch', 'channel', 'h', 'value')};
}
}`;
return {
@@ -86,7 +100,7 @@ const createInstanceNormProgramInfo =
{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
],
getShaderSource,
- dispatchGroup: () => ({x: Math.ceil(normCount / 64 /* workgroup size */)})
+ dispatchGroup: () => ({x: normCount})
};
};
@@ -118,7 +132,7 @@ const createInstanceNormNHWCProgramInfo =
${shaderHelper.mainStart()}
let currentImageNumber = global_idx / C;
let currentChannelNumber = global_idx % C;
-
+
// offset is channel num * N
let offset = currentImageNumber * imageSize;
if (offset >= ${outputSize}) { return; }
@@ -156,8 +170,6 @@ export const parseInstanceNormAttributes = (attributes: InstanceNormAttributes):
createAttributeWithCacheKey({epsilon: attributes.epsilon, format: attributes.format});
export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => {
- validateInputs(context.inputs);
-
const metadata = {
name: 'InstanceNormalization',
inputTypes: [GpuDataType.default, GpuDataType.default, GpuDataType.default],
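The rewritten NCHW program assigns one workgroup of 64 threads per (batch, channel) pair and computes the mean and the sum of squared deviations with a shared-memory tree reduction instead of a single-threaded loop. A CPU-side TypeScript sketch of that reduction pattern, for illustration only:

```typescript
// CPU illustration of the reduction the new instance-norm shader performs per
// (batch, channel): each of the 64 "threads" accumulates a strided partial sum,
// then partial sums are folded by halving the active range each step.
const workgroupSize = 64;

const workgroupReduceSum = (values: Float32Array): number => {
  const partial = new Float32Array(workgroupSize);
  // phase 1: strided per-thread accumulation
  for (let localIndex = 0; localIndex < workgroupSize; localIndex++) {
    let acc = 0;
    for (let h = localIndex; h < values.length; h += workgroupSize) {
      acc += values[h];
    }
    partial[localIndex] = acc;
  }
  // phase 2: tree reduction, mirroring
  // `for (currSize = workgroupSize >> 1; currSize > 0; currSize >>= 1)`
  for (let currSize = workgroupSize >> 1; currSize > 0; currSize >>= 1) {
    for (let localIndex = 0; localIndex < currSize; localIndex++) {
      partial[localIndex] += partial[localIndex + currSize];
    }
  }
  return partial[0];
};

// per-channel statistics, as the shader computes them
const channel = Float32Array.from([1, 2, 3, 4]);
const mean = workgroupReduceSum(channel) / channel.length;
const squaredNorm = workgroupReduceSum(channel.map(v => (v - mean) * (v - mean)));
const invStdDev = 1 / Math.sqrt(squaredNorm / channel.length + 1e-5);
```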
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
index 8a9927b25a52e..d6a79e9460c3f 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
index e4dae00db6305..837ac8410f291 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {BroadcastUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts
index d90296b5c5a46..c2f89fd2845df 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
index 79071d32443d6..8c8c12fc54ddb 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {PoolConvUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
index cb592c838dd97..0b8d03ea73b6b 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
index 1d0b8229a76f7..8b9dbbf57ac75 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
index 4b845bcf2121b..7bfdd73b8af18 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
index 4211e526898e6..257b9ebc1fdac 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata, TensorInfo} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
index e2443b24410a5..495a4bcea4f47 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
@@ -5,7 +5,7 @@
// performance limitations when the reduced axis is long. Need to add
// a optimized codepath for this.
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
index 9a150d21ea02e..3367091bbac23 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata, TensorInfo} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
index 99d9668757caa..109c29bfc8a80 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
index 9243b0e4af6b6..38dcaeab54c54 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
index ef63d1177768c..f08d7a77d1099 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -2,12 +2,12 @@
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
-import {inputVariable, outputVariable, ShaderHelper} from './common';
+import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
type BuiltinFunctionName = string;
type ElementwiseCustomExpression = (expression: string) => string;
@@ -101,6 +101,9 @@ export const parseCastAttributes = (attributes: Record): CastAt
export const cast = (context: ComputeContext, attributes: CastAttributes): void => {
let func: ElementwiseFunctionCall;
switch (attributes.to) {
+ case DataType.float16:
+ func = 'vec4<f16>';
+ break;
case DataType.float:
func = 'vec4<f32>';
break;
@@ -126,11 +129,12 @@ export interface ClipAttributes extends AttributeWithCacheKey {
}
export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
+ const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfoLoader(
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
- const clip_min_: vec4<f32> = vec4(f32(${attributes.min}));
- const clip_max_: vec4<f32> = vec4(f32(${attributes.max}));
+ const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min}));
+ const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max}));
`,
attributes.cacheKey),
{inputs: [0]});
@@ -180,13 +184,13 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
attributes.cacheKey));
};
-export const erfImpl = (dataType: string) => `
-const r0: f32 = 0.3275911;
-const r1: f32 = 0.254829592;
-const r2: f32 = -0.284496736;
-const r3: f32 = 1.421413741;
-const r4: f32 = -1.453152027;
-const r5: f32 = 1.061405429;
+export const erfImpl = (dataType: string, varType = 'f32') => `
+const r0: ${varType} = 0.3275911;
+const r1: ${varType} = 0.254829592;
+const r2: ${varType} = -0.284496736;
+const r3: ${varType} = 1.421413741;
+const r4: ${varType} = -1.453152027;
+const r5: ${varType} = 1.061405429;
fn erf_vf32(v: ${dataType}) -> ${dataType} {
let absv = abs(v);
@@ -195,8 +199,9 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} {
}`;
export const erf = (context: ComputeContext): void => {
- context.compute(
- createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4<f32>')));
+ const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
+ context.compute(createElementwiseProgramInfoLoader(
+ context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
};
export const exp = (context: ComputeContext): void => {
@@ -208,9 +213,10 @@ export const floor = (context: ComputeContext): void => {
};
export const gelu = (context: ComputeContext): void => {
+ const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
- erfImpl('vec4<f32>')));
+ erfImpl(`vec4<${dataType}>`, dataType)));
};
export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
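The `r0`–`r5` constants above are the coefficients of the Abramowitz–Stegun 7.1.26 erf approximation; this change only parameterises their declared type (`f32` by default, the input's storage type otherwise). A scalar TypeScript reference for that approximation (the WGSL body of `erf_vf32` is not shown in this hunk, so this illustrates the math rather than transcribing the shader):

```typescript
// Scalar reference for the erf approximation whose coefficients appear above
// (Abramowitz & Stegun 7.1.26).
const erfApprox = (v: number): number => {
  const [r0, r1, r2, r3, r4, r5] =
      [0.3275911, 0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429];
  const sign = Math.sign(v);
  const x = Math.abs(v);
  const t = 1 / (1 + r0 * x);
  // nested (Horner) evaluation of the degree-5 polynomial in t
  const poly = ((((r5 * t + r4) * t + r3) * t + r2) * t + r1) * t;
  return sign * (1 - poly * Math.exp(-x * x));
};

// Gelu uses it as 0.5 * x * (1 + erf(x / sqrt(2))), matching the expression in the diff.
const gelu = (x: number): number => 0.5 * x * (1 + erfApprox(x * 0.7071067811865475));
```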
diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
index cce61be3448cd..cf2687e4c7382 100644
--- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts
+++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
@@ -4,7 +4,7 @@
import {tensorDataTypeEnumToString} from '../../wasm-common';
import {WebGpuBackend} from '../backend-webgpu';
import {LOG_DEBUG} from '../log';
-import {TensorView} from '../tensor';
+import {TensorView} from '../tensor-view';
import {createShaderHelper} from './ops/common';
import {Artifact, GpuData, ProgramInfo} from './types';
diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts
index ddbb9afc275f2..78f80b89774e2 100644
--- a/js/web/lib/wasm/jsep/webgpu/types.ts
+++ b/js/web/lib/wasm/jsep/webgpu/types.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {Tensor, TensorView} from '../tensor';
+import {TensorView} from '../tensor-view';
import {ShaderHelper} from './ops/common';
@@ -19,7 +19,6 @@ export interface GpuData {
}
export interface TensorInfo {
- id?: Tensor.Id;
dims: readonly number[];
dataType: number;
gpuDataType: GpuDataType;
diff --git a/js/web/package-lock.json b/js/web/package-lock.json
index eabd641914170..9567bc172c9ed 100644
--- a/js/web/package-lock.json
+++ b/js/web/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-web",
- "version": "1.16.0",
+ "version": "1.17.0",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-web",
- "version": "1.16.0",
+ "version": "1.17.0",
"license": "MIT",
"dependencies": {
"flatbuffers": "^1.12.0",
@@ -49,7 +49,7 @@
},
"../common": {
"name": "onnxruntime-common",
- "version": "1.16.0",
+ "version": "1.17.0",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.23.22"
diff --git a/js/web/package.json b/js/web/package.json
index ce06475f672fd..8ae5b733e5f21 100644
--- a/js/web/package.json
+++ b/js/web/package.json
@@ -8,7 +8,7 @@
"type": "git"
},
"author": "fs-eire",
- "version": "1.16.0",
+ "version": "1.17.0",
"jsdelivr": "dist/ort.min.js",
"dependencies": {
"flatbuffers": "^1.12.0",
diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts
index 7b41850948149..f90f568879146 100644
--- a/js/web/script/test-runner-cli-args.ts
+++ b/js/web/script/test-runner-cli-args.ts
@@ -382,8 +382,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
const globalEnvFlags = parseGlobalEnvFlags(args);
if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) {
- // Backend webnn is restricted in the dedicated worker.
- globalEnvFlags.wasm!.proxy = true;
+ throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.');
}
// Options:
diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts
index 520ef62b2c719..a75321d45f1ef 100644
--- a/js/web/script/test-runner-cli.ts
+++ b/js/web/script/test-runner-cli.ts
@@ -84,8 +84,10 @@ async function main() {
.flat();
for (const backend of DEFAULT_BACKENDS) {
- nodeTests.set(backend, loadNodeTests(backend, allNodeTestsFolders));
- opTests.set(backend, loadOpTests(backend));
+ if (args.backends.indexOf(backend) !== -1) {
+ nodeTests.set(backend, loadNodeTests(backend, allNodeTestsFolders));
+ opTests.set(backend, loadOpTests(backend));
+ }
}
}
diff --git a/js/web/test/data/ops/instance-norm.jsonc b/js/web/test/data/ops/instance-norm.jsonc
new file mode 100644
index 0000000000000..6a4e6912405ee
--- /dev/null
+++ b/js/web/test/data/ops/instance-norm.jsonc
@@ -0,0 +1,79 @@
+[
+ {
+ "name": "Simple test with NHWC",
+ "operator": "InstanceNormalization",
+ "inputShapeDefinitions": "rankOnly",
+ "opset": { "domain": "", "version": 17 },
+ "cases": [
+ {
+ "name": "Simple test",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4],
+ "dims": [1, 4, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [4],
+ "type": "float32"
+ },
+ {
+ "data": [4, 5, 6, 7],
+ "dims": [4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 2.6583645343780518, 3.552788257598877, 4.447211742401123, 5.341635704040527, 2.3167295455932617,
+ 4.105576515197754, 5.8944244384765625, 7.683271408081055, 6, 10.242595672607422, 6, 1.7574005126953125,
+ 12.36654281616211, 8.788846969604492, 5.211153030395508, 1.633458137512207
+ ],
+ "dims": [1, 4, 2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "Simple test with NCHW",
+ "operator": "InstanceNormalization",
+ "opset": { "domain": "", "version": 17 },
+ "cases": [
+ {
+ "name": "Simple test",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4],
+ "dims": [1, 4, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [4],
+ "type": "float32"
+ },
+ {
+ "data": [4, 5, 6, 7],
+ "dims": [4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 2.6583645343780518, 3.552788257598877, 4.447211742401123, 5.341635704040527, 2.3167295455932617,
+ 4.105576515197754, 5.8944244384765625, 7.683271408081055, 6, 10.242595672607422, 6, 1.7574005126953125,
+ 12.36654281616211, 8.788846969604492, 5.211153030395508, 1.633458137512207
+ ],
+ "dims": [1, 4, 2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ }
+]
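The expected values in both cases can be reproduced with a plain per-channel normalization, assuming the default epsilon of 1e-5 and NCHW layout; a small TypeScript check, illustrative only:

```typescript
// Reference computation for the expected test values above (NCHW, epsilon = 1e-5).
const instanceNorm =
    (x: number[], dims: number[], scale: number[], bias: number[], epsilon = 1e-5): number[] => {
  const [n, c] = [dims[0], dims[1]];
  const spatial = x.length / (n * c);  // product of the remaining dims
  const out = new Array<number>(x.length);
  for (let i = 0; i < n * c; i++) {
    const channel = x.slice(i * spatial, (i + 1) * spatial);
    const mean = channel.reduce((a, b) => a + b, 0) / spatial;
    const variance = channel.reduce((a, b) => a + (b - mean) * (b - mean), 0) / spatial;
    const invStd = 1 / Math.sqrt(variance + epsilon);
    channel.forEach((v, j) => {
      out[i * spatial + j] = (v - mean) * invStd * scale[i % c] + bias[i % c];
    });
  }
  return out;
};

// instanceNorm([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4], [1, 4, 2, 2], [1, 2, 3, 4], [4, 5, 6, 7])
// yields [2.658..., 3.552..., 4.447..., 5.341..., ...] as in the expected outputs.
```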
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index e580259071968..6e65645ef4756 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -257,6 +257,7 @@
"greater.jsonc",
//"identity.jsonc",
"image-scaler.jsonc",
+ "instance-norm.jsonc",
"less.jsonc",
"log.jsonc",
"matmul.jsonc",
@@ -601,6 +602,11 @@
// // "test_hardsigmoid",
// // "test_hardswish_expanded",
// // "test_hardswish",
+ "test_if",
+ // TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra
+ // supports Sequence and Optional types
+ // "test_if_seq",
+ // "test_if_opt",
"test_instancenorm_epsilon",
"test_instancenorm_example",
// "test_isinf_negative",
@@ -1347,6 +1353,7 @@
"gemm.jsonc",
"global-average-pool.jsonc",
"greater.jsonc",
+ "instance-norm.jsonc",
"less.jsonc",
"log.jsonc",
"matmul.jsonc",
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index d39d8edf0b73a..fd147eaa11f3f 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -7,7 +7,7 @@
For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://aka.ms/onnxruntime>`_
or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
"""
-__version__ = "1.16.0"
+__version__ = "1.17.0"
__author__ = "Microsoft"
# we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 660c8bd9e0624..0ec5088808656 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -56,9 +56,13 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLine
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid);
@@ -191,9 +195,13 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid)>,
diff --git a/onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc b/onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc
deleted file mode 100644
index 28a304bfc7f0e..0000000000000
--- a/onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc
+++ /dev/null
@@ -1,56 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "core/providers/cpu/quantization/quantize_linear.h"
-#include "core/providers/common.h"
-
-namespace onnxruntime {
-namespace contrib {
-
-ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
- DequantizeLinear,
- 1,
- uint8_t,
- KernelDefBuilder()
- .TypeConstraint("T1", DataTypeImpl::GetTensorType())
- .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
- DequantizeLinear);
-
-ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
- DequantizeLinear,
- 1,
- int8_t,
- KernelDefBuilder()
- .TypeConstraint("T1", DataTypeImpl::GetTensorType())
- .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
- DequantizeLinear);
-
-ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
- DequantizeLinear,
- 1,
- int32_t,
- KernelDefBuilder()
- .TypeConstraint("T1", DataTypeImpl::GetTensorType())
- .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
- DequantizeLinear);
-
-ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
- QuantizeLinear,
- 1,
- uint8_t,
- KernelDefBuilder()
- .TypeConstraint("T1", DataTypeImpl::GetTensorType())
- .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
- QuantizeLinear);
-
-ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
- QuantizeLinear,
- 1,
- int8_t,
- KernelDefBuilder()
- .TypeConstraint("T1", DataTypeImpl::GetTensorType())
- .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
- QuantizeLinear);
-
-} // namespace contrib
-} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index a79ad96b94d91..f0385ea5abdfb 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -249,30 +249,28 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
AttentionData<CudaT> data;
data.gemm_buffer = reinterpret_cast<CudaT*>(gemm_buffer.get());
- data.bias = nullptr == bias ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>());
- data.query = nullptr;
- data.key = nullptr;
- data.value = nullptr;
- data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data<int>();
- data.mask_index_dims = (nullptr == mask_index) ? gsl::span<const int64_t>() : mask_index->Shape().GetDims();
- data.past = (nullptr == past) ? nullptr : reinterpret_cast<const CudaT*>(past->Data<T>());
- data.past_key = nullptr;
- data.past_value = nullptr;
- data.relative_position_bias = (nullptr == relative_position_bias)
-                                   ? nullptr
-                                   : reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
+ if (nullptr != bias) {
+   data.bias = reinterpret_cast<const CudaT*>(bias->Data<T>());
+ }
+ if (nullptr != mask_index) {
+   data.mask_index = mask_index->Data<int>();
+   data.mask_index_dims = mask_index->Shape().GetDims();
+ }
+ if (nullptr != past) {
+   data.past = reinterpret_cast<const CudaT*>(past->Data<T>());
+ }
+ if (nullptr != relative_position_bias) {
+   data.relative_position_bias = reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
+ }
data.has_qkv_workspace = true;
data.workspace = reinterpret_cast<CudaT*>(work_space.get());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
- data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
- data.present_key = nullptr;
- data.present_value = nullptr;
+ if (nullptr != present) {
+   data.present = reinterpret_cast<CudaT*>(present->MutableData<T>());
+ }
data.fused_runner = reinterpret_cast<void*>(fused_runner);
- data.fused_cross_attention_kernel = nullptr;
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
- data.cumulated_sequence_length_q_cache = nullptr;
- data.cumulated_sequence_length_kv_cache = nullptr;
return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
}
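The rewrite drops the explicit nullptr assignments and only writes the members that are actually present, which presumably relies on `AttentionData` providing default member initializers. A loose TypeScript analogue of that pattern, purely for illustration and not a transcription of the C++ code:

```typescript
// Illustrative analogue: give the struct-like object safe defaults so call
// sites only assign the members that are actually present.
interface AttentionDataLike {
  bias: Float32Array | null;
  maskIndex: Int32Array | null;
  past: Float32Array | null;
  present: Float32Array | null;
}

const makeAttentionData = (inputs: Partial<AttentionDataLike>): AttentionDataLike => {
  // defaults first (the role the C++ default member initializers play here) ...
  const data: AttentionDataLike = {bias: null, maskIndex: null, past: null, present: null};
  // ... then plain guarded assignments instead of per-field ternaries
  if (inputs.bias) { data.bias = inputs.bias; }
  if (inputs.maskIndex) { data.maskIndex = inputs.maskIndex; }
  if (inputs.past) { data.past = inputs.past; }
  if (inputs.present) { data.present = inputs.present; }
  return data;
};
```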
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu b/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu
deleted file mode 100644
index 5d9cfcc69773a..0000000000000
--- a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu
+++ /dev/null
@@ -1,249 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "core/providers/cuda/cuda_common.h"
-#include "contrib_ops/cuda/bert/attention_impl.h"
-
-using namespace onnxruntime::cuda;
-
-namespace onnxruntime {
-namespace contrib {
-namespace cuda {
-
-template <typename T>
-__global__ void ConcatTensorToTensor(const int tensor_add_sequence_length,
- const T* tensor_in,
- const T* tensor_add,
- T* tensor_out) {
- const int h = threadIdx.x;
- const int n = threadIdx.y;
- const int s = blockIdx.x;
- const int b = blockIdx.y;
- const int chunk_id = blockIdx.z;
-
- const int all_sequence_length = gridDim.x;
- const int batch_size = gridDim.y;
- const int num_heads = blockDim.y;
- const int H = blockDim.x;
-
- // K: number of identical tensors
- // tensor_in: K x BxNxPxH
- // tensor_add: K x BxNxLxH
- // tensor_out: K x BxNxTxH, where T = P + L
- const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length;
-
- const int present_SH = all_sequence_length * H;
- const int present_NSH = num_heads * present_SH;
- int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size);
- if (s < tensor_in_sequence_length) {
- const int past_SH = tensor_in_sequence_length * H;
- const int past_NSH = num_heads * past_SH;
- const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size);
- tensor_out[out_offset] = tensor_in[in_offset];
- } else if (s < all_sequence_length) {
- const int SH = tensor_add_sequence_length * H;
- const int NSH = num_heads * SH;
- const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size);
- tensor_out[out_offset] = tensor_add[in_offset];
- }
-}
-
-template <typename T>
-__global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length,
- const int H,
- const T* tensor_in,
- const T* tensor_add,
- T* tensor_out) {
- // Use when (H*)*num_heads > 1024
- int h = threadIdx.x;
- const int n = threadIdx.y;
- const int s = blockIdx.x;
- const int b = blockIdx.y;
- const int chunk_id = blockIdx.z;
-
- const int all_sequence_length = gridDim.x;
- const int batch_size = gridDim.y;
- const int num_heads = blockDim.y;
- const int stride = blockDim.x;
-
- // K: number of identical tensor
- // tensor_in: K x BxNxPxH
- // tensor_add: K x BxNxLxH
- // tensor_out: K x BxNxTxH
- const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length;
-
- const int present_SH = all_sequence_length * H;
- const int present_NSH = num_heads * present_SH;
- while (h < H) {
- int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size);
- if (s < tensor_in_sequence_length) {
- const int past_SH = tensor_in_sequence_length * H;
- const int past_NSH = num_heads * past_SH;
- const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size);
- tensor_out[out_offset] = tensor_in[in_offset];
- } else if (s < all_sequence_length) {
- const int SH = tensor_add_sequence_length * H;
- const int NSH = num_heads * SH;
- const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size);
- tensor_out[out_offset] = tensor_add[in_offset];
- }
-
- h += stride;
- }
-}
-
-Status LaunchConcatTensorToTensor(cudaStream_t stream,
- const int all_sequence_length,
- const int sequence_length,
- const int batch_size,
- const int head_size,
- const int num_heads,
- const int max_threads_per_block,
- const int matrix_num,
- const float* tensor_in,
- const float* tensor_add,
- float* tensor_out) {
- const dim3 grid(all_sequence_length, batch_size, matrix_num);
- if (0 == (head_size & 1)) {
- const int H = head_size / 2;
- if (H * num_heads <= max_threads_per_block) {
- const dim3 block(H, num_heads, 1);
- ConcatTensorToTensor<<>>(sequence_length,
- reinterpret_cast(tensor_in),
- reinterpret_cast(tensor_add),
- reinterpret_cast(tensor_out));
- } else {
- const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
- ConcatTensorToTensorLarge<<>>(sequence_length,
- H,
- reinterpret_cast(tensor_in),
- reinterpret_cast(tensor_add),
- reinterpret_cast(tensor_out));
- }
- } else {
- if (head_size * num_heads <= max_threads_per_block) {
- const dim3 block(head_size, num_heads, 1);
- ConcatTensorToTensor<<>>(sequence_length, tensor_in, tensor_add, tensor_out);
- } else {
- const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
- ConcatTensorToTensorLarge<<>>(sequence_length,
- head_size,
- tensor_in,
- tensor_add,
- tensor_out);
- }
- }
- return CUDA_CALL(cudaGetLastError());
-}
-
-Status LaunchConcatTensorToTensor(cudaStream_t stream,
- const int all_sequence_length,
- const int sequence_length,
- const int batch_size,
- const int head_size,
- const int num_heads,
- const int max_threads_per_block,
- const int matrix_num,
- const half* tensor_in,
- const half* tensor_add,
- half* tensor_out) {
- const dim3 grid(all_sequence_length, batch_size, matrix_num);
- if (0 == (head_size % 4)) {
- const int H = head_size / 4;
- if (H * num_heads <= max_threads_per_block) {
- const dim3 block(H, num_heads, 1);
- ConcatTensorToTensor<<>>(sequence_length,
- reinterpret_cast(tensor_in),
- reinterpret_cast(tensor_add),
- reinterpret_cast(tensor_out));
- } else {
- const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
- ConcatTensorToTensorLarge<<>>(sequence_length,
- H,
- reinterpret_cast(tensor_in),
- reinterpret_cast(tensor_add),
- reinterpret_cast(tensor_out));
- }
- } else if (0 == (head_size & 1)) {
- const int H = head_size / 2;
- if (H * num_heads <= max_threads_per_block) {
- const dim3 block(H, num_heads, 1);
- ConcatTensorToTensor<<>>(sequence_length,
- reinterpret_cast(tensor_in),
- reinterpret_cast(tensor_add),
- reinterpret_cast(tensor_out));
- } else {
- const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
- ConcatTensorToTensorLarge<<>>(sequence_length,
- H,
- reinterpret_cast(tensor_in),
- reinterpret_cast(tensor_add),
- reinterpret_cast(tensor_out));
- }
- } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.
- if (head_size * num_heads <= max_threads_per_block) {
- const dim3 block(head_size, num_heads, 1);
- ConcatTensorToTensor<<>>(sequence_length, tensor_in, tensor_add, tensor_out);
- } else {
- const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
- ConcatTensorToTensorLarge<<>>(sequence_length,
- head_size,
- tensor_in,
- tensor_add,
- tensor_out);
- }
- }
- return CUDA_CALL(cudaGetLastError());
-}
-
-Status LaunchConcatPastToPresent(cudaStream_t stream,
- const int all_sequence_length,
- const int sequence_length,
- const int batch_size,
- const int head_size,
- const int num_heads,
- const int max_threads_per_block,
- const float* past,
- const float* k_v,
- float* present) {
- return LaunchConcatTensorToTensor(
- stream,
- all_sequence_length,
- sequence_length,
- batch_size,
- head_size,
- num_heads,
- max_threads_per_block,
- 2,
- past,
- k_v,
- present);
-}
-
-Status LaunchConcatPastToPresent(cudaStream_t stream,
- const int all_sequence_length,
- const int sequence_length,
- const int batch_size,
- const int head_size,
- const int num_heads,
- const int max_threads_per_block,
- const half* past,
- const half* k_v,
- half* present) {
- return LaunchConcatTensorToTensor(
- stream,
- all_sequence_length,
- sequence_length,
- batch_size,
- head_size,
- num_heads,
- max_threads_per_block,
- 2,
- past,
- k_v,
- present);
-}
-
-} // namespace cuda
-} // namespace contrib
-} // namespace onnxruntime
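The kernels deleted above append a past K/V tensor of shape BxNxPxH and newly computed K/V of shape BxNxLxH into a present tensor of shape BxNxTxH with T = P + L, once per matrix chunk. A minimal host-side sketch of that indexing, assuming a single chunk and row-major std::vector storage (the helper name is hypothetical and not part of this diff):

#include <cstddef>
#include <vector>

// Reference-only: mirrors the indexing of ConcatTensorToTensor for one chunk.
// past: B x N x P x H, added: B x N x L x H, present: B x N x (P + L) x H.
template <typename T>
void ConcatPastToPresentReference(int B, int N, int P, int L, int H,
                                  const std::vector<T>& past,
                                  const std::vector<T>& added,
                                  std::vector<T>& present) {
  const int total = P + L;
  present.resize(static_cast<size_t>(B) * N * total * H);
  for (int b = 0; b < B; ++b)
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < total; ++s)
        for (int h = 0; h < H; ++h) {
          const size_t out = ((static_cast<size_t>(b) * N + n) * total + s) * H + h;
          present[out] = (s < P)
                             ? past[((static_cast<size_t>(b) * N + n) * P + s) * H + h]
                             : added[((static_cast<size_t>(b) * N + n) * L + (s - P)) * H + h];
        }
}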
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index ae7696eb9fe0f..b4a4ae208ceb1 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -26,16 +26,11 @@ limitations under the License.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include
-#include
-#include
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
-#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/attention_softmax.h"
#include "contrib_ops/cuda/bert/transformer_common.h"
-#include "contrib_ops/cuda/bert/add_bias_transpose.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h"
#include "contrib_ops/cpu/bert/attention_base.h"
@@ -43,6 +38,7 @@ limitations under the License.
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+#include "contrib_ops/cuda/bert/attention_impl.h"
using namespace onnxruntime::cuda;
using namespace onnxruntime::contrib::attention_softmax_cuda;
@@ -157,918 +153,285 @@ size_t GetAttentionWorkspaceSize(
}
template <typename T>
-__global__ void AddBiasTransAppendKvToPresentSmall(
- const T* qkv, const T* biases, T* present,
- const int head_size, const int past_sequence_length, const int max_sequence_length) {
- // Input: BxSxMxNxH (Format 1)
- // Output: (2, B, N, [P..P+S) of MaxS, H),
- // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
- const int n = threadIdx.y;
- const int s = blockIdx.x;
- const int b = blockIdx.y;
- const int N = blockDim.y;
- const int S = gridDim.x;
- const int B = gridDim.y;
-
- constexpr int M = 3; // Matrix count in qkv
- const int m = blockIdx.z + 1; // k = 1, v = 2
-
- const int NH = N * head_size;
- const int NHS = NH * S;
-
- qkv += (n * head_size + (s * M + m) * NH + b * M * NHS);
- if (biases) {
- biases += (m * NH + n * head_size);
- }
+Status FusedTrtCrossAttention(
+ cudaStream_t stream,
+ contrib::AttentionParameters& parameters,
+ AttentionData<T>& data) {
+ assert(data.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H);
- const int MsH = max_sequence_length * head_size;
- const int NMsH = N * MsH;
- const int BNMsH = B * NMsH;
- present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH);
+ // We only enable fused cross attention when there is no key padding mask.
+ // Otherwise, the key has an effective batch size of 2 * batch_size, which differs from the batch size of the query.
+ assert(data.mask_index == nullptr);
- for (int h = threadIdx.x; h < head_size; h += blockDim.x) {
- T bias = (biases ? biases[h] : (T)0.0f);
- present[h] = qkv[h] + bias;
- }
-}
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache,
+ data.mask_index, batch_size,
+ sequence_length, stream,
+ data.scratch);
-template
-__global__ void AddBiasTransAppendKvToPresent(
- const T* qkv, const T* biases, T* present,
- const int head_size, const int past_sequence_length, const int max_sequence_length) {
- // Input: BxSxMxNxH (Format 1)
- // Output: (2, B, N, [P..P+S) of MaxS, H),
- // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
- const int n = blockIdx.x;
- const int s = blockIdx.y;
- const int b = (blockIdx.z >> 1);
- const int N = gridDim.x;
- const int S = gridDim.y;
- const int B = (gridDim.z >> 1);
-
- constexpr int M = 3; // Matrix count in qkv
- const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2
-
- const int NH = N * head_size;
- const int NHS = NH * S;
-
- qkv += (n * head_size + (s * M + m) * NH + b * M * NHS);
- if (biases) {
- biases += (m * NH + n * head_size);
- }
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1);
- const int MsH = max_sequence_length * head_size;
- const int NMsH = N * MsH;
- const int BNMsH = B * NMsH;
- present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH);
+ int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int));
+ kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache,
+ data.mask_index, batch_size, parameters.kv_sequence_length, stream,
+ kv_sequence_offset);
+ CUDA_RETURN_IF_ERROR(cudaGetLastError());
- for (int h = threadIdx.x; h < head_size; h += blockDim.x) {
- T bias = (biases ? biases[h] : (T)0.0f);
- present[h] = qkv[h] + bias;
- }
+ DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1);
+
+ FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel =
+ reinterpret_cast<FusedMultiHeadCrossAttentionKernel const*>(data.fused_cross_attention_kernel);
+
+ // When there is no bias, we can directly use q and packed kv from inputs.
+ void const* query = data.q;
+ void const* packed_kv = data.k;
+ if (data.value == nullptr && data.bias == nullptr) {
+ query = data.query;
+ packed_kv = data.key;
+ }
+
+ run_fused_cross_attention(
+ query, // Q
+ packed_kv, // packed KV
+ q_sequence_offset, // cumulated sequence length of Q
+ kv_sequence_offset, // cumulated sequence length of KV
+ data.output, // output
+ cross_attention_kernel, // kernels
+ batch_size, // batch size
+ parameters.num_heads, // number of heads
+ parameters.head_size, // head size of Q/K/V
+ sequence_length, // sequence length of Q
+ parameters.kv_sequence_length, // sequence length of KV
+ stream);
+
+ DUMP_TENSOR("trt cross output", data.output,
+ batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
+ return Status::OK();
}
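Both q_sequence_offset and kv_sequence_offset above are exclusive prefix sums of per-batch token counts, batch_size + 1 integers each; GetCumulatedSequenceLength builds them on the GPU (optionally from a cached buffer). A host-side illustration of that shape, with the padded/unpadded dual layout omitted (helper name is illustrative only):

#include <vector>

// Illustration only: {3, 5, 2} -> {0, 3, 8, 10}; offsets.back() is the total token count.
std::vector<int> BuildSequenceOffsets(const std::vector<int>& sequence_lengths) {
  std::vector<int> offsets(sequence_lengths.size() + 1, 0);
  for (size_t i = 0; i < sequence_lengths.size(); ++i) {
    offsets[i + 1] = offsets[i] + sequence_lengths[i];
  }
  return offsets;
}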
-// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3.
-// bias is of shape (3, NxH) or nullptr
-// append to present of (2, B, N, (P..T) of M, H),
-template <typename T>
-Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
- const int max_sequence_length,
- const int past_sequence_length,
- const int sequence_length,
- const int batch_size,
- const int head_size,
- const int num_heads,
- const int max_threads_per_block,
- const T* biases,
- const T* qkv_buffer,
- T* present) {
- assert(head_size <= (1 << 30));
-
- int64_t nh = (int64_t)head_size * num_heads;
- if (nh <= max_threads_per_block) {
- const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v
- const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
-
- AddBiasTransAppendKvToPresentSmall<<>>(
- qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length);
- } else {
- const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v
- const dim3 block(std::min(head_size, max_threads_per_block), 1, 1);
- AddBiasTransAppendKvToPresent<<>>(
- qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length);
- }
-
- return CUDA_CALL(cudaGetLastError());
+template <>
+Status FusedTrtCrossAttention<float>(
+ cudaStream_t stream,
+ contrib::AttentionParameters& parameters,
+ AttentionData<float>& data) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
+ "Trt fused cross attention does not support float tensor");
}
-template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
- const int max_sequence_length,
- const int total_sequence_length,
- const int sequence_length,
- const int batch_size,
- const int head_size,
- const int num_heads,
- const int max_threads_per_block,
- const float* bias,
- const float* qkv_buffer,
- float* present);
-
-template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
- const int max_sequence_length,
- const int total_sequence_length,
- const int sequence_length,
- const int batch_size,
- const int head_size,
- const int num_heads,
- const int max_threads_per_block,
- const half* bias,
- const half* qkv_buffer,
- half* present);
-
template <typename T>
-Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
- AttentionData& data,
- cudaStream_t stream,
- int max_threads_per_block,
- AttentionQkvFormat& qkv_format) {
+Status FusedTrtSelfAttention(
+ cudaStream_t stream,
+ contrib::AttentionParameters& parameters,
+ AttentionData<T>& data) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
- const int num_heads = parameters.num_heads;
- const int qk_head_size = parameters.head_size;
- const int v_head_size = parameters.v_head_size;
- const bool past_present_share_buffer = parameters.past_present_share_buffer;
- void* fused_runner = data.fused_runner;
- bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention;
+ const bool causal = parameters.is_unidirectional;
- T* qkv = data.workspace;
+ int* sequence_offset = reinterpret_cast<int*>(data.scratch);
- bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
- bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
-
- if (data.bias == nullptr) {
- assert(nullptr == fused_runner);
- // For quantized attention, bias has been added so only need transpose here.
- // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
- assert(qk_head_size == v_head_size);
- int matrix_to_trans = (past_present_share_buffer ? 1 : 3);
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.gemm_buffer, qkv, 3));
- qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ DUMP_TENSOR_INIT();
+ if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) {
+ DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length);
+ LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
} else {
- // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
- // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
- // For unfused kernel, transpose to 3xBxNxSxH (format 1)
- // For fused causal kernel, use format 1 since we need have K and V to update present state,
- // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
- const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
- qkv_format = use_fused_kernel
- ? AttentionQkvFormat::QKV_BSN3H
- : (use_flash_or_efficient_attention
- ? AttentionQkvFormat::Q_K_V_BSNH
- : (use_fused_causal
- ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
- : AttentionQkvFormat::Q_K_V_BNSH));
-
- // For fused causal, we will update gemm_buffer with bias directly.
- T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
-
- int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
- // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
- // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
- LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
- batch_size, sequence_length, num_heads, qk_head_size,
- data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
- 3, parameters.do_rotary, parameters.past_sequence_length);
+ sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache,
+ data.mask_index, batch_size, sequence_length, stream,
+ sequence_offset);
}
- return Status::OK();
-}
-
-// For MultiHeadAttention with past state
-template <typename T>
-Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters,
- AttentionData& data,
- cudaStream_t stream,
- int max_threads_per_block,
- T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
- const int batch_size = parameters.batch_size;
- const int sequence_length = parameters.sequence_length;
- const int kv_sequence_length = parameters.kv_sequence_length;
- const int num_heads = parameters.num_heads;
- const int qk_head_size = parameters.head_size;
- const int v_head_size = parameters.v_head_size;
-
- DUMP_TENSOR_INIT();
+ DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1);
+ CUDA_RETURN_IF_ERROR(cudaGetLastError());
- if (data.bias == nullptr) {
- // Below logic does not support fused attention with past without bias
- // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past.
-
- // cross attention with past state
- if (data.past_key != nullptr && data.present_key == nullptr) {
- assert(data.past_value != nullptr);
- assert(data.query != nullptr);
- assert(data.key == nullptr);
- assert(data.value == nullptr);
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.query, q));
- }
- // cross attention with present state or self attention with present state
- else if (data.past_key == nullptr && data.present_key != nullptr) {
- assert(data.past_value == nullptr);
- assert(data.present_value != nullptr);
- assert(data.query != nullptr);
- assert(data.key != nullptr);
- assert(data.value != nullptr);
-
- // TODO: supporting packed qkv for self attention may benefit performance
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.query, q));
-
- // TODO: supporting packed kv for cross attention may benefit performance
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.key, data.present_key));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.value, data.present_value));
- }
- // self attention with past and present state
- else {
- assert(data.past_key != nullptr);
- assert(data.past_value != nullptr);
- assert(data.present_key != nullptr);
- assert(data.present_value != nullptr);
- assert(data.query != nullptr);
- assert(data.key != nullptr);
- assert(data.value != nullptr);
- // TODO: supporting packed qkv for self attention may benefit performance
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.query, q));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.key, k));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.value, v));
- }
- qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
- }
-#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
- // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value
- else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
- data.past_key != nullptr &&
- data.past_value != nullptr &&
- parameters.pass_past_in_kv) {
- // Transpose past_key and past_value to use memory efficient attention
-
- // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH)
- ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.past_key, data.temp_k_workspace));
- // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
- ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.past_value, data.temp_v_workspace));
-
- // query => q, temp_k_workspace => k, temp_v_workspace => v
- LaunchAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length,
- num_heads, qk_head_size, v_head_size,
- data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v);
-
- DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
-
- data.past_key = nullptr;
- data.past_value = nullptr;
- }
- // When there is no past_key/past_value and there is present_key/present_value
- // (e.g. get initial kv to use as past_kv in the next iteration)
- else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
- data.present_key != nullptr &&
- data.present_value != nullptr) {
- // Use memory efficient attention kernel
- LaunchAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length,
- num_heads, qk_head_size, v_head_size,
- data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace);
-
- // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH)
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.temp_k_workspace, data.present_key));
-
- // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v)
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.temp_v_workspace, data.present_value));
-
- DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
- }
-#endif
- else {
- // Use unfused kernel for Q, use unfused kernel for K and V if needed
- constexpr int format = 0;
- // Query (BxSxNxH) => Q (BxNxSxH)
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, sequence_length, num_heads, qk_head_size,
- data.query, data.bias, q,
- true, -1);
-
- if (!parameters.pass_past_in_kv) {
- T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k;
- T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v;
-
- // Key (BxLxNxH) => K (BxNxLxH)
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, qk_head_size,
- data.key, data.bias + num_heads * qk_head_size, k_dest,
- true, -1);
-
- // Value (BxLxNxH_v) => V (BxNxLxH_v)
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, v_head_size,
- data.value, data.bias + 2 * num_heads * qk_head_size, v_dest,
- true, -1);
-
- DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
- DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size);
- DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size);
- }
- qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
- }
- return Status::OK();
-}
+ FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(data.fused_runner);
-// For MultiHeadAttention without past state, with packed QKV inputs
-template <typename T>
-Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
- AttentionData& data,
- cudaStream_t stream,
- int max_threads_per_block,
- T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
- const int batch_size = parameters.batch_size;
- const int sequence_length = parameters.sequence_length;
- const int num_heads = parameters.num_heads;
- const int qk_head_size = parameters.head_size;
- const int v_head_size = parameters.v_head_size;
- void* fused_runner = data.fused_runner;
+ const int S = causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
- T* qkv = data.workspace;
+ // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed.
+ const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size);
- bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
+ fused_fp16_runner->setup(S, B);
- assert(data.bias == nullptr);
- assert(qk_head_size == v_head_size);
+ if (!causal) {
+ assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H);
- DUMP_TENSOR_INIT();
- DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size);
-
- if (data.use_memory_efficient_attention || data.use_flash_attention) {
- // unpack qkv to BSNH. Note that there is no bias so we need not output query to q.
- constexpr int format = 4;
- T* qkv_add_bias = nullptr;
- LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block,
- batch_size, sequence_length, num_heads, qk_head_size,
- data.query, data.bias, qkv,
- true, v_head_size, qkv_add_bias, 3);
- DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
- } else {
- if (!use_fused_kernel) {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, NOT_IMPLEMENTED,
- "packed QKV format is not implemented for current GPU. Please disable it in fusion options.");
+ // When there is no bias, we can directly use packed qkv from inputs.
+ void const* packed_qkv = data.q;
+ if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) {
+ packed_qkv = data.query;
}
- qkv_format = AttentionQkvFormat::QKV_BSN3H;
+ fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream);
+ DUMP_TENSOR("fused output", data.output,
+ batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
+ } else {
+ assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
+ fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream);
+ DUMP_TENSOR("fused causal output", data.output,
+ batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
}
return Status::OK();
}
-// For MultiHeadAttention without past state, with packed KV inputs
+// Template Specialization for float type
+template <>
+Status FusedTrtSelfAttention<float>(
+ cudaStream_t stream,
+ contrib::AttentionParameters& parameters,
+ AttentionData<float>& data) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
+ "Trt fused attention does not support float tensor");
+}
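When a 2D key-padding mask is present, FusedTrtSelfAttention derives sequence offsets from the mask via LaunchTrtSequenceOffset2d. The essence is counting valid tokens per row and prefix-summing them; the sketch below shows only that core idea (the device kernel's buffer additionally carries padded offsets, which is why the dump above covers 2 * batch_size + 1 entries when a mask exists):

#include <vector>

// Host-side simplification: mask is B x S row-major, 1 = valid token, 0 = padding.
std::vector<int> SequenceOffsetsFromMask2D(const std::vector<int>& mask,
                                           int batch_size, int sequence_length) {
  std::vector<int> offsets(batch_size + 1, 0);
  for (int b = 0; b < batch_size; ++b) {
    int valid = 0;
    for (int s = 0; s < sequence_length; ++s) {
      valid += (mask[static_cast<size_t>(b) * sequence_length + s] != 0) ? 1 : 0;
    }
    offsets[b + 1] = offsets[b] + valid;
  }
  return offsets;
}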
+
+#if USE_FLASH_ATTENTION
template <typename T>
-Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
- AttentionData& data,
- cudaStream_t stream,
- int max_threads_per_block,
- T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
- const int batch_size = parameters.batch_size;
- const int kv_sequence_length = parameters.kv_sequence_length;
- const int num_heads = parameters.num_heads;
- const int qk_head_size = parameters.head_size;
- const int v_head_size = parameters.v_head_size;
+Status FlashAttention(
+ const cudaDeviceProp& device_prop,
+ cudaStream_t stream,
+ contrib::AttentionParameters& parameters,
+ AttentionData<T>& data,
+ float scale) {
+ assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
+ assert(nullptr == data.mask_index);
+ assert(nullptr == data.relative_position_bias);
+ assert(parameters.head_size == parameters.v_head_size);
- // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint.
- // CheckInputs verified this constraint.
- assert(data.bias == nullptr);
- assert(qk_head_size == v_head_size);
+ void* query = reinterpret_cast<void*>(data.q);
+ void* key = reinterpret_cast<void*>(data.k);
+ void* value = reinterpret_cast<void*>(data.v);
+ // For packed KV, we can use query input directly.
+ if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) {
+ query = reinterpret_cast<void*>(const_cast<T*>(data.query));
+ }
DUMP_TENSOR_INIT();
- DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size);
-
- if (data.use_memory_efficient_attention || data.use_flash_attention) {
- // unpack kv to BSNH. Note that there is no bias so we need not output query to q.
- constexpr int format = 4;
- T* qkv_add_bias = nullptr;
- const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size);
- LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, qk_head_size,
- data.key, kv_bias, k,
- true, v_head_size, qkv_add_bias, 2);
- DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
- } else {
- if (data.fused_cross_attention_kernel == nullptr) {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, NOT_IMPLEMENTED,
- "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options.");
- }
+ DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query),
+ parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size);
+ DUMP_TENSOR_D("k(BSNH)", data.k,
+ parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size);
+ DUMP_TENSOR_D("v(BSNH)", data.v,
+ parameters.batch_size, parameters.total_sequence_length,
+ parameters.num_heads, parameters.v_head_size);
+
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
+ device_prop, stream, query, key, value, data.output, reinterpret_cast<void*>(data.scratch),
+ parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
+ parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional));
+
+ DUMP_TENSOR("flash attention output", data.output,
+ parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size);
- qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
- }
return Status::OK();
}
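Both FlashAttention and EfficientAttention receive a pre-resolved scale. Elsewhere in this file the convention is that a scale of 0 means "use the default 1/sqrt(head_size)"; a one-line helper capturing that rule (hypothetical name, shown only to make the convention explicit):

#include <cmath>

inline float ResolveAttentionScale(float scale, int qk_head_size) {
  return scale == 0.0f ? 1.0f / std::sqrt(static_cast<float>(qk_head_size)) : scale;
}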
-// For MultiHeadAttention without past state, with Q, K and V inputs
-template <typename T>
-Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters,
- AttentionData& data,
- cudaStream_t stream,
- int max_threads_per_block,
- T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
- const int batch_size = parameters.batch_size;
- const int sequence_length = parameters.sequence_length;
- const int kv_sequence_length = parameters.kv_sequence_length;
- const int num_heads = parameters.num_heads;
- const int qk_head_size = parameters.head_size;
- const int v_head_size = parameters.v_head_size;
- void* fused_runner = data.fused_runner;
-
- T* qkv = data.workspace;
-
- bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
- bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
-
- // gemm_buffer == nullptr and not packed
- assert(data.query != nullptr && data.key != nullptr && data.value != nullptr);
-
- DUMP_TENSOR_INIT();
- DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size);
-
-#if DUMP_TENSOR_LEVEL > 1
- if (data.bias != nullptr) {
- DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size);
- DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size);
- DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
- }
+template <>
+Status FlashAttention<float>(
+ const cudaDeviceProp& device_prop,
+ cudaStream_t stream,
+ contrib::AttentionParameters& parameters,
+ AttentionData<float>& data,
+ float scale) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "flash attention does not support float tensor");
+}
#endif
- if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) {
- DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias,
- num_heads, sequence_length, kv_sequence_length);
- }
-
- if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) {
- DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1);
- }
-
- if (data.fused_cross_attention_kernel != nullptr) {
- assert(qk_head_size == v_head_size);
-
- // For fused cross attention, besides adding bias, K and V needed to be packed:
- // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH
- LaunchAddBiasTransposeTrt(
- stream, max_threads_per_block,
- batch_size, sequence_length,
- num_heads, qk_head_size,
- data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length);
-
- qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
- }
-#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
- else if (data.use_memory_efficient_attention || data.use_flash_attention) {
- LaunchAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length,
- num_heads, qk_head_size, v_head_size,
- data.bias, data.query, data.key, data.value, q, k, v);
-
- DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
- }
-#endif
- else if (use_fused_kernel) {
- assert(qk_head_size == v_head_size);
+#if USE_MEMORY_EFFICIENT_ATTENTION
+template <typename T>
+Status EfficientAttention(
+ const cudaDeviceProp& device_prop,
+ cudaStream_t stream,
+ contrib::AttentionParameters& parameters,
+ AttentionData<T>& data,
+ float scale) {
+ // We only enable fused cross attention when there is no key padding mask.
+ // Otherwise, the key has an effective batch size of 2 * batch_size, which differs from the batch size of the query.
+ assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
- // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
- LaunchAddBiasTransposeTrt(
- stream, max_threads_per_block,
- batch_size, sequence_length,
- num_heads, qk_head_size,
- data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length);
- DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size);
-
- qkv_format = AttentionQkvFormat::QKV_BSN3H;
- } else { // unfused kernel
- ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal");
-
- // Query (BxSxNxH) => Q (BxNxSxH)
- constexpr int format = 0;
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, sequence_length, num_heads, qk_head_size,
- data.query, data.bias, q,
- true, -1);
-
- // Key (BxLxNxH) => K (BxNxLxH)
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, qk_head_size,
- data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
- true, -1);
-
- // Value (BxLxNxH_v) => K (BxNxLxH_v)
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, v_head_size,
- data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v,
- true, -1);
-
- DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
- DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size);
- DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ const void* query = data.q;
+ const void* key = data.k;
+ const void* value = data.v;
+ // For packed KV, we can use query input directly.
+ if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) {
+ assert(data.bias == nullptr);
+ query = data.query;
}
- return Status::OK();
-}
-template <typename T>
-Status PrepareQkv(contrib::AttentionParameters& parameters,
- AttentionData& data,
- cudaStream_t stream,
- int max_threads_per_block,
- T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
- if (nullptr != data.gemm_buffer) { // Attention operator
- ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, qkv_format));
- } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state
- ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
- } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv
- ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
- } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv
- ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
- } else { // multihead attention operator, no past, separated Q/K/V inputs
- ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
- }
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query),
+ parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size);
+ DUMP_TENSOR_D("k(BSNH)", data.k,
+ parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size);
+ DUMP_TENSOR_D("v(BSNH)", data.v,
+ parameters.batch_size, parameters.total_sequence_length,
+ parameters.num_heads, parameters.v_head_size);
+
+ MemoryEfficientAttentionParams p;
+ p.sm = device_prop.major * 10 + device_prop.minor;
+ p.is_half = sizeof(T) == 2;
+ p.batch_size = parameters.batch_size;
+ p.num_heads = parameters.num_heads;
+ p.sequence_length = parameters.sequence_length;
+ p.kv_sequence_length = parameters.total_sequence_length;
+ p.qk_head_size = parameters.head_size;
+ p.v_head_size = parameters.v_head_size;
+ p.causal = parameters.is_unidirectional;
+ p.scale = scale;
+ p.seqlen_k_ptr = nullptr == data.mask_index
+ ? nullptr
+ : const_cast<int*>(reinterpret_cast<const int*>(data.mask_index));
+ p.seqstart_q_ptr = nullptr == data.mask_index
+ ? nullptr
+ : const_cast<int*>(reinterpret_cast<const int*>(
+ data.mask_index + parameters.batch_size));
+ p.seqstart_k_ptr = nullptr == data.mask_index
+ ? nullptr
+ : const_cast<int*>(reinterpret_cast<const int*>(
+ data.mask_index + 2 * parameters.batch_size + 1));
+ p.query = query;
+ p.key = key;
+ p.value = value;
+ p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
+ p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
+ p.output = data.output;
+ p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float))
+ ? data.scratch
+ : nullptr;
+ p.stream = stream;
+ run_memory_efficient_attention(p);
+ DUMP_TENSOR("efficient attention output", data.output,
+ parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size);
- CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
}
+#endif
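The seqlen_k/seqstart_q/seqstart_k pointer arithmetic in EfficientAttention implies a 1D mask_index layout of 3 * batch_size + 2 int32 values: batch_size key lengths, then batch_size + 1 cumulative query start positions, then batch_size + 1 cumulative key start positions. A small read-only view documenting that layout (illustrative, not part of this diff):

// Assumes mask_index points at 3 * batch_size + 2 int32 values, as implied above.
struct MaskIndexView {
  const int* data;
  int batch_size;

  const int* seqlen_k() const { return data; }                         // batch_size values
  const int* seqstart_q() const { return data + batch_size; }          // batch_size + 1 values
  const int* seqstart_k() const { return data + 2 * batch_size + 1; }  // batch_size + 1 values
};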
template <typename T>
-Status QkvToContext(
+Status UnfusedAttention(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
- AttentionData& data) {
+ AttentionData<T>& data,
+ float scale) {
+ assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
+
auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
- constexpr size_t element_size = sizeof(T);
- const int max_threads_per_block = device_prop.maxThreadsPerBlock;
+
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
- const int kv_sequence_length = parameters.kv_sequence_length;
const int total_sequence_length = parameters.total_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
- const bool past_present_share_buffer = parameters.past_present_share_buffer;
- const float mask_filter_value = parameters.mask_filter_value;
- void* fused_runner = data.fused_runner;
-
- // At most one fused kernel is enabled.
- assert((int(data.use_flash_attention) +
- int(data.use_memory_efficient_attention) +
- int(fused_runner != nullptr) +
- int(data.fused_cross_attention_kernel != nullptr)) <= 1);
-
const int batches = batch_size * num_heads;
- T* qkv = nullptr;
- T* q = nullptr;
- T* k = nullptr;
- T* v = nullptr;
- T* scratch1 = data.workspace;
- if (data.has_qkv_workspace) {
- const int size_per_batch_q = sequence_length * qk_head_size;
- const int size_per_batch_k = kv_sequence_length * qk_head_size;
- const int size_per_batch_v = kv_sequence_length * v_head_size;
- const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q);
- const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k);
- const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v);
- qkv = data.workspace;
- q = qkv;
- k = q + elements_q;
- v = k + elements_k;
- scratch1 = v + elements_v;
- }
-
- bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
- bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
-
- AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
- ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
-
- int present_size_per_batch_k = 0;
- int present_size_per_batch_v = 0;
- if (!past_present_share_buffer) {
- // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length.
- // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH)
- // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH)
- // When there is past state, the head size for Q/K/V shall be same: H == H_v.
- present_size_per_batch_k = total_sequence_length * qk_head_size;
- present_size_per_batch_v = total_sequence_length * v_head_size;
-
- if (nullptr != data.present) {
- assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH || qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
- ORT_RETURN_IF_ERROR(
- LaunchConcatPastToPresent(
- stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, data.past, k, data.present));
-
- // Update pointers to present_k and present_v.
- k = data.present;
- v = data.present + batches * present_size_per_batch_k;
- }
-
- if (nullptr != data.past_key || nullptr != data.present_key) {
- if (nullptr != data.past_key && nullptr == data.present_key) {
- k = const_cast(data.past_key);
- v = const_cast(data.past_value);
- } else if (nullptr == data.past_key && nullptr != data.present_key) {
- if (qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
- k = data.present_key;
- v = data.present_value;
- } else {
- assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
- k = data.temp_k_workspace;
- v = data.temp_v_workspace;
- }
- } else if (parameters.pass_past_in_kv) {
- // past_key and past_value are used directly as key and value in attention computations
- k = const_cast(data.past_key);
- v = const_cast(data.past_value);
-
- // This path has a memory copy from past_key and past_value to present_key and present_value
- // Avoid this path since the memory copy is unnecessary because past_key == present_key and
- // past_value == present_value
- int64_t k_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * qk_head_size;
- int64_t v_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * v_head_size;
- cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
- cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
- } else {
- ORT_RETURN_IF_ERROR(
- LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length,
- batch_size, qk_head_size, num_heads,
- max_threads_per_block, 1, data.past_key, k, data.present_key));
- ORT_RETURN_IF_ERROR(
- LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length,
- batch_size, v_head_size, num_heads,
- max_threads_per_block, 1, data.past_value, v, data.present_value));
- // Update pointers to present_k and present_v.
- k = data.present_key;
- v = data.present_value;
- }
- }
- } else { // past_present_share_buffer
- assert(qk_head_size == v_head_size);
- assert(data.fused_cross_attention_kernel == nullptr);
- assert(!use_fused_kernel);
- assert(data.gemm_buffer != nullptr);
- assert(!data.use_memory_efficient_attention);
- assert(!data.use_flash_attention);
- assert(data.has_qkv_workspace);
-
- if (nullptr != data.past_key || nullptr != data.present_key) {
- // TODO: support this case.
- ORT_THROW("buffer sharing for no bias case between past and present is not supported yet.");
- }
-
- if (data.present != data.past) {
- // For easy testing. Production should better avoid this path.
- int64_t kv_size = 2LL * (int64_t)batch_size * num_heads * parameters.max_sequence_length * qk_head_size;
- cudaMemcpyAsync(data.present, data.past, kv_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
- }
-
- // append last k v to present
- ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent(
- stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length,
- batch_size, qk_head_size, num_heads, max_threads_per_block,
- use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer
- data.gemm_buffer, data.present));
-
- present_size_per_batch_k = parameters.max_sequence_length * qk_head_size;
- present_size_per_batch_v = present_size_per_batch_k;
- k = data.present;
- v = data.present + batches * present_size_per_batch_k;
- }
-
- // Q, K and V are ready now
- DUMP_TENSOR_INIT();
-
- if (data.fused_cross_attention_kernel != nullptr) {
- assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H);
-
- // We only enable fused cross attention when there is no key padding mask.
- // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
- assert(data.mask_index == nullptr);
-
- int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache,
- data.mask_index, batch_size, sequence_length, stream,
- scratch1);
-
- DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1);
-
- int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int));
- kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache,
- data.mask_index, batch_size, kv_sequence_length, stream,
- kv_sequence_offset);
- CUDA_RETURN_IF_ERROR(cudaGetLastError());
-
- DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1);
-
- FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel =
- reinterpret_cast(data.fused_cross_attention_kernel);
-
- // When there is no bias, we can directly use q and packed kv from inputs.
- void const* query = q;
- void const* packed_kv = k;
- if (data.value == nullptr && data.bias == nullptr) {
- query = data.query;
- packed_kv = data.key;
- }
-
- run_fused_cross_attention(
- query, // Q
- packed_kv, // packed KV
- q_sequence_offset, // cumulated sequence length of Q
- kv_sequence_offset, // cumulated sequence length of KV
- data.output, // output
- cross_attention_kernel, // kernels
- batch_size, // batch size
- num_heads, // number of heads
- qk_head_size, // head size of Q/K/V
- sequence_length, // sequence length of Q
- kv_sequence_length, // sequence length of KV
- stream);
-
- DUMP_TENSOR("trt cross output", data.output, batch_size, sequence_length, num_heads, v_head_size);
- return Status::OK();
- }
-
- // Run TRT fused attention.
- if (use_fused_kernel || use_fused_causal) {
- int* sequence_offset = reinterpret_cast(scratch1);
- if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) {
- DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length);
- LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
- } else {
- sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache,
- data.mask_index, batch_size, sequence_length, stream,
- sequence_offset);
- }
- DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1);
- CUDA_RETURN_IF_ERROR(cudaGetLastError());
-
- FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(fused_runner);
-
- const int S = use_fused_causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
-
- // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed.
- const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size);
-
- fused_fp16_runner->setup(S, B);
-
- if (use_fused_kernel) {
- assert(qkv_format == AttentionQkvFormat::QKV_BSN3H);
-
- // When there is no bias, we can directly use packed qkv from inputs.
- void const* packed_qkv = qkv;
- if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) {
- packed_qkv = data.query;
- }
-
- fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream);
- DUMP_TENSOR("fused output", data.output, batch_size, sequence_length, num_heads, v_head_size);
- } else {
- assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
- fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream);
- DUMP_TENSOR("fused causal output", data.output, batch_size, sequence_length, num_heads, v_head_size);
- }
- return Status::OK();
- }
-
- // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
- const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size))
- : parameters.scale;
-
-#if USE_FLASH_ATTENTION
- if (data.use_flash_attention) {
- assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
- assert(nullptr == data.mask_index);
- assert(nullptr == data.relative_position_bias);
- assert(parameters.head_size == parameters.v_head_size);
-
- void* query = reinterpret_cast(q);
- void* key = reinterpret_cast(k);
- void* value = reinterpret_cast(v);
- // For packed KV, we can use query input directly.
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) {
- query = reinterpret_cast(const_cast(data.query));
- }
-
- DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size);
-
- constexpr bool is_causal = false;
- ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
- device_prop, stream, query, key, value, data.output, reinterpret_cast(scratch1),
- parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
- parameters.sequence_length, parameters.total_sequence_length, scale, is_causal));
-
- DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, v_head_size);
-
- return Status::OK();
- }
-#endif
-
-#if USE_MEMORY_EFFICIENT_ATTENTION
- if (data.use_memory_efficient_attention) {
- // We only enable fused cross attention when there is no key padding mask.
- // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
- assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
-
- const void* query = q;
- const void* key = k;
- const void* value = v;
- // For packed KV, we can use query input directly.
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) {
- assert(data.bias == nullptr);
- query = data.query;
- }
-
- DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size);
-
- MemoryEfficientAttentionParams p;
- p.sm = device_prop.major * 10 + device_prop.minor;
- p.is_half = sizeof(T) == 2;
- p.batch_size = parameters.batch_size;
- p.num_heads = parameters.num_heads;
- p.sequence_length = parameters.sequence_length;
- p.kv_sequence_length = parameters.total_sequence_length;
- p.qk_head_size = parameters.head_size;
- p.v_head_size = parameters.v_head_size;
- p.causal = parameters.is_unidirectional;
- p.scale = scale;
- p.seqlen_k_ptr = nullptr == data.mask_index
- ? nullptr
- : const_cast(reinterpret_cast(data.mask_index));
- p.seqstart_q_ptr = nullptr == data.mask_index
- ? nullptr
- : const_cast(reinterpret_cast(data.mask_index + batch_size));
- p.seqstart_k_ptr = nullptr == data.mask_index
- ? nullptr
- : const_cast(reinterpret_cast(data.mask_index + 2 * batch_size + 1));
- p.query = query;
- p.key = key;
- p.value = value;
- p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
- p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
- p.output = data.output;
- p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float))
- ? scratch1
- : nullptr;
- p.stream = stream;
- run_memory_efficient_attention(p);
- DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, v_head_size);
-
- return Status::OK();
- }
-#endif
-
- // The following are unfused attention.
- assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
const int* mask_index = data.mask_index;
gsl::span<const int64_t>& mask_index_dims = data.mask_index_dims;
// Raw attention mask could be 2D (BxT) or 3D (BxSxT) or 4D(Bx1xMxM), where M is the max sequence length.
bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2);
- // Compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxT
+ // Compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch: BxNxSxT
// Q: BxNxSxH, K (present_k): BxNxTxH, Q*K': BxNxSxT
float one = 1.0f;
float zero = 0.f;
@@ -1077,22 +440,31 @@ Status QkvToContext(
cublasSetStream(cublas, stream);
- DUMP_TENSOR_D("q[BNSH]", q, batch_size, num_heads, sequence_length, qk_head_size);
- DUMP_TENSOR_D("k[BNSH]", k, batch_size, num_heads, total_sequence_length, qk_head_size);
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR_D("q[BNSH]", data.q, batch_size, num_heads, sequence_length, qk_head_size);
+ DUMP_TENSOR_D("k[BNSH]", data.k, batch_size, num_heads, total_sequence_length, qk_head_size);
+
+ const int present_sequence_length = parameters.past_present_share_buffer
+ ? parameters.max_sequence_length
+ : total_sequence_length;
+ const int present_size_per_batch_k = present_sequence_length * qk_head_size;
+ const int present_size_per_batch_v = present_sequence_length * v_head_size;
+
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
total_sequence_length, sequence_length, qk_head_size,
- &alpha, k, qk_head_size, present_size_per_batch_k,
- q, qk_head_size, sequence_length * qk_head_size,
- &zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop));
+ &alpha, data.k, qk_head_size, present_size_per_batch_k,
+ data.q, qk_head_size, sequence_length * qk_head_size,
+ &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop));
- DUMP_TENSOR_D("Q", q, batch_size, num_heads, sequence_length, qk_head_size);
- DUMP_TENSOR_D("K", k, batch_size, num_heads, qk_head_size, sequence_length);
- DUMP_TENSOR_D("QK", scratch1, batch_size, num_heads, sequence_length, total_sequence_length);
+ DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size);
+ DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length);
+ DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length);
+ constexpr size_t element_size = sizeof(T);
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
sequence_length, total_sequence_length);
- T* scratch2 = scratch1 + (bytes / element_size);
+ T* scratch2 = data.scratch + (bytes / element_size);
// Apply softmax and store result R to scratch2: BxNxSxT
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
@@ -1102,14 +474,15 @@ Status QkvToContext(
const TransformerOptions* options = TransformerOptions::GetInstance();
bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax();
- T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax.
+ // replace Q*K' in place with masked score for persistent softmax.
+ T* persistent_softmax_workspace = data.scratch;
ORT_RETURN_IF_ERROR(
ComputeSoftmaxWithRawMask(
ort_stream, total_sequence_length, sequence_length, batch_size, num_heads,
mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias,
- scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension,
+ data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension,
parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace,
- mask_filter_value));
+ parameters.mask_filter_value));
} else if (nullptr != mask_index) { // 1d mask index
assert(mask_index_dims.size() == 1);
// mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the latter has start positions.
@@ -1117,277 +490,123 @@ Status QkvToContext(
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D(
stream, total_sequence_length, sequence_length, batch_size, num_heads,
mask_index, mask_start, data.relative_position_bias, parameters.broadcast_res_pos_bias,
- scratch1, scratch2, parameters.is_unidirectional));
+ data.scratch, scratch2, parameters.is_unidirectional));
} else { // no mask
ORT_RETURN_IF_ERROR(
ComputeSoftmax(
stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias,
- parameters.broadcast_res_pos_bias, scratch1, scratch2, parameters.is_unidirectional));
+ parameters.broadcast_res_pos_bias, data.scratch, scratch2, parameters.is_unidirectional));
}
DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
- DUMP_TENSOR_D("V", v, batch_size, num_heads, sequence_length, v_head_size);
+ DUMP_TENSOR_D("V", data.v, batch_size, num_heads, sequence_length, v_head_size);
// compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
- T* temp_output = qkv;
+ T* temp_output = data.q;
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
v_head_size, sequence_length, total_sequence_length,
- &one, v, v_head_size, present_size_per_batch_v,
+ &one, data.v, v_head_size, present_size_per_batch_v,
scratch2, total_sequence_length, sequence_length * total_sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop));
// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, temp_output, data.output);
+ device_prop.maxThreadsPerBlock, false, temp_output, data.output);
DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size);
return result;
}
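A note on the two strided-batched GEMMs in the unfused path above: each of the B*N batch-head pairs multiplies an SxH query block against an HxT key block into an SxT score block, and the stride arguments encode exactly that layout. The helper below is an illustrative sketch under the shapes used in this hunk; the struct and function names are mine and do not appear in the patch.

// Illustrative only: per batch-head strides matching the Q*K' GEMM arguments above.
#include <cstdint>

struct QkGemmStrides {
  int batches;           // B * N independent GEMMs
  int64_t stride_k;      // elements between K blocks: present_sequence_length * H
  int64_t stride_q;      // elements between Q blocks: S * H
  int64_t stride_score;  // elements between score blocks in data.scratch: S * T
};

inline QkGemmStrides MakeQkGemmStrides(int batch_size, int num_heads, int sequence_length,
                                       int total_sequence_length, int qk_head_size,
                                       int present_sequence_length) {
  QkGemmStrides s;
  s.batches = batch_size * num_heads;
  s.stride_k = static_cast<int64_t>(present_sequence_length) * qk_head_size;
  s.stride_q = static_cast<int64_t>(sequence_length) * qk_head_size;
  s.stride_score = static_cast<int64_t>(sequence_length) * total_sequence_length;
  return s;
}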
template <typename T>
-Status DecoderQkvToContext(
+Status QkvToContext(
const cudaDeviceProp& device_prop,
- Stream* ort_stream,
cublasHandle_t& cublas,
- const size_t element_size,
- const int batch_size,
- const int sequence_length,
- const int kv_sequence_length,
- const int num_heads,
- const int head_size,
- const bool static_kv,
- const bool use_past,
- const bool has_layer_state,
- const bool has_key_padding_mask,
- const float mask_filter_value,
- const T* gemm_query_buffer,
- const T* gemm_kv_buffer,
- const bool* key_padding_mask,
- const T* key_cache,
- const T* value_cache,
- T* qkv_buffer,
- T* workspace_buffer,
- T* output,
- T* new_key_cache,
- T* new_value_cache) {
+ Stream* ort_stream,
+ contrib::AttentionParameters& parameters,
+ AttentionData<T>& data) {
+ auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
- const int BN = batch_size * num_heads;
- const int BHN = BN * head_size;
- const int BNS = BN * sequence_length;
- const int k_buffer_offset = sequence_length * BHN;
- const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN;
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int total_sequence_length = parameters.total_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+ void* fused_runner = data.fused_runner;
- T* temp_qkv_buffer = workspace_buffer;
- auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
+ // At most one fused kernel is enabled.
+ assert((int(data.use_flash_attention) +
+ int(data.use_memory_efficient_attention) +
+ int(fused_runner != nullptr) +
+ int(data.fused_cross_attention_kernel != nullptr)) <= 1);
- const T* q = qkv_buffer;
- // transpose q and copy them to qkv_buffer
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads,
- max_threads_per_block, true, gemm_query_buffer, qkv_buffer));
-
- const T* k = qkv_buffer + k_buffer_offset;
- const T* v = qkv_buffer + v_buffer_offset;
- if (!has_layer_state || !use_past) {
- if (!static_kv) {
- // transpose kv and copy them to qkv_buffer
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads,
- max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
- } else {
- // transpose kv and copy them to qkv_buffer
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads,
- max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
- }
- } else {
- if (!static_kv) {
- // transpose kv and copy them to temp_buffer
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads,
- max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer));
- // concat cache-k with k and copy to qkv_buffer
- if (nullptr != key_cache) {
- ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length,
- sequence_length, batch_size, head_size, num_heads,
- max_threads_per_block, 1,
- key_cache,
- temp_qkv_buffer,
- qkv_buffer + k_buffer_offset));
- }
- // concat cache-v with v and copy to qkv_buffer
- if (nullptr != value_cache) {
- ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length,
- sequence_length, batch_size, head_size, num_heads,
- max_threads_per_block, 1,
- value_cache,
- temp_qkv_buffer + k_buffer_offset,
- qkv_buffer + v_buffer_offset));
- }
+ ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block));
+
+ if (!parameters.past_present_share_buffer) {
+ ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size,
+ sequence_length, total_sequence_length, parameters.pass_past_in_kv,
+ stream, max_threads_per_block, data));
+
+ } else { // past_present_share_buffer
+ assert(qk_head_size == v_head_size);
+ assert(data.fused_cross_attention_kernel == nullptr);
+ assert(nullptr == fused_runner || parameters.is_unidirectional);
+ assert(data.gemm_buffer != nullptr);
+ assert(!data.use_memory_efficient_attention);
+ assert(!data.use_flash_attention);
+ assert(data.has_qkv_workspace);
+
+ if (nullptr != data.past_key || nullptr != data.present_key) {
+ // TODO: support this case.
+ ORT_THROW("buffer sharing for no bias case between past and present is not supported yet.");
}
- }
- if (has_layer_state) {
- if (use_past && static_kv) {
- CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T),
- cudaMemcpyDeviceToDevice, stream));
- CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T),
- cudaMemcpyDeviceToDevice, stream));
- } else {
- CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T),
- cudaMemcpyDeviceToDevice, stream));
- CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T),
- cudaMemcpyDeviceToDevice, stream));
+ if (data.present != data.past) {
+ // For ease of testing; production code should avoid this path.
+ int64_t kv_size = 2LL * (int64_t)batch_size * num_heads * parameters.max_sequence_length * qk_head_size;
+ cudaMemcpyAsync(data.present, data.past, kv_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
}
- }
- // scratch1: BxNxSxL buffer
- // scratch2: BxNxSxL buffer
- // scratch3: BxNxSxH buffer
- T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length;
- T* scratch2 = scratch1 + BNS * kv_sequence_length;
- T* scratch3 = scratch2 + BNS * kv_sequence_length;
-
- // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL
- // Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL
- const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
- const int temp_matrix_size = sequence_length * kv_sequence_length;
- float one = 1.0f;
- float zero = 0.f;
+ // For fused causal, bias has been added to gemm_buffer.
+ const T* bias = (nullptr != fused_runner && parameters.is_unidirectional) ? nullptr : data.bias;
- float alpha = rsqrt_head_size;
- const int strideA = kv_sequence_length * head_size;
- const int strideB = sequence_length * head_size;
- if (use_past && static_kv) {
- CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
- cublas, CUBLAS_OP_T, CUBLAS_OP_N,
- kv_sequence_length, sequence_length, head_size,
- &alpha, key_cache, head_size, strideA,
- q, head_size, strideB,
- &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop));
- } else {
- CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
- cublas, CUBLAS_OP_T, CUBLAS_OP_N,
- kv_sequence_length, sequence_length, head_size,
- &alpha, k, head_size, strideA,
- q, head_size, strideB,
- &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop));
+ // append last k v to present
+ ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent(
+ stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length,
+ batch_size, qk_head_size, num_heads, max_threads_per_block,
+ bias, data.gemm_buffer, data.present));
+
+ data.k = data.present;
+ data.v = data.present + batch_size * num_heads * parameters.max_sequence_length * qk_head_size;
}
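In the past_present_share_buffer branch above, the shared present buffer packs K first and V second, each of shape BxNxMaxSxH. The helper below is a minimal sketch of that offset arithmetic; the function name is mine, and only the formula mirrors the assignments to data.k and data.v.

// Sketch only: V starts right after the B*N*max_S*H elements occupied by K.
#include <cstdint>

inline int64_t SharedPresentVOffset(int batch_size, int num_heads,
                                    int max_sequence_length, int qk_head_size) {
  return static_cast<int64_t>(batch_size) * num_heads * max_sequence_length * qk_head_size;
}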
- constexpr bool is_unidirectional = false;
- const T* add_before_softmax = nullptr;
- if (has_key_padding_mask) {
- constexpr int mask_dimension = 2;
- constexpr int max_sequence_length = 0;
- ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask(
- ort_stream, kv_sequence_length, sequence_length, batch_size,
- num_heads, nullptr, key_padding_mask, add_before_softmax,
- false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional,
- 1.0f, mask_dimension, max_sequence_length, false, nullptr,
- mask_filter_value));
- } else {
- ORT_RETURN_IF_ERROR(ComputeSoftmax(
- stream, kv_sequence_length, sequence_length, batch_size, num_heads,
- add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2,
- is_unidirectional));
+ // Q, K and V are ready now
+ if (data.fused_cross_attention_kernel != nullptr) {
+ return FusedTrtCrossAttention(stream, parameters, data);
}
- // compute P*V (as V*P), and store in scratch3: BxNxSxH
- if (use_past && static_kv) {
- CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
- cublas, CUBLAS_OP_N, CUBLAS_OP_N,
- head_size, sequence_length, kv_sequence_length,
- &one, value_cache, head_size, strideA,
- scratch2, kv_sequence_length, temp_matrix_size,
- &zero, scratch3, head_size, strideB, BN, device_prop));
- } else {
- CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
- cublas, CUBLAS_OP_N, CUBLAS_OP_N,
- head_size, sequence_length, kv_sequence_length,
- &one, v, head_size, strideA,
- scratch2, kv_sequence_length, temp_matrix_size,
- &zero, scratch3, head_size, strideB, BN, device_prop));
+ // Run TRT fused attention.
+ if (nullptr != fused_runner) {
+ return FusedTrtSelfAttention(stream, parameters, data);
}
- // scratch3 is BxNxSxH, transpose to output SxBxNxH
- return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
- max_threads_per_block, true, scratch3, output);
-}
+ // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
+ const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
+ : parameters.scale;
-Status LaunchDecoderAttentionKernel(
- const cudaDeviceProp& device_prop,
- Stream* stream,
- cublasHandle_t& cublas,
- const size_t element_size,
- const int batch_size,
- const int sequence_length,
- const int kv_sequence_length,
- const int num_heads,
- const int head_size,
- const bool static_kv,
- const bool use_past,
- const bool has_layer_state,
- const bool has_key_padding_mask,
- const float mask_filter_value,
- const void* gemm_query_buffer,
- const void* gemm_kv_buffer,
- const bool* key_padding_mask,
- const void* key_cache,
- const void* value_cache,
- void* qkv_buffer,
- void* workspace_buffer,
- void* output,
- void* new_key_cache,
- void* new_value_cache) {
- if (element_size == 2) {
- return DecoderQkvToContext(
- device_prop,
- stream,
- cublas,
- element_size,
- batch_size,
- sequence_length,
- kv_sequence_length,
- num_heads,
- head_size,
- static_kv,
- use_past,
- has_layer_state,
- has_key_padding_mask,
- mask_filter_value,
- reinterpret_cast<const half*>(gemm_query_buffer),
- reinterpret_cast<const half*>(gemm_kv_buffer),
- key_padding_mask,
- reinterpret_cast<const half*>(key_cache),
- reinterpret_cast<const half*>(value_cache),
- reinterpret_cast<half*>(qkv_buffer),
- reinterpret_cast<half*>(workspace_buffer),
- reinterpret_cast<half*>(output),
- reinterpret_cast<half*>(new_key_cache),
- reinterpret_cast<half*>(new_value_cache));
- } else {
- return DecoderQkvToContext(
- device_prop,
- stream,
- cublas,
- element_size,
- batch_size,
- sequence_length,
- kv_sequence_length,
- num_heads,
- head_size,
- static_kv,
- use_past,
- has_layer_state,
- has_key_padding_mask,
- mask_filter_value,
- reinterpret_cast<const float*>(gemm_query_buffer),
- reinterpret_cast<const float*>(gemm_kv_buffer),
- key_padding_mask,
- reinterpret_cast<const float*>(key_cache),
- reinterpret_cast<const float*>(value_cache),
- reinterpret_cast<float*>(qkv_buffer),
- reinterpret_cast<float*>(workspace_buffer),
- reinterpret_cast<float*>(output),
- reinterpret_cast<float*>(new_key_cache),
- reinterpret_cast<float*>(new_value_cache));
+#if USE_FLASH_ATTENTION
+ if (data.use_flash_attention) {
+ return FlashAttention(device_prop, stream, parameters, data, scale);
}
+#endif
+
+#if USE_MEMORY_EFFICIENT_ATTENTION
+ if (data.use_memory_efficient_attention) {
+ return EfficientAttention(device_prop, stream, parameters, data, scale);
+ }
+#endif
+
+ return UnfusedAttention(device_prop, cublas, ort_stream, parameters, data, scale);
}
// Template Instantiation
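The rewritten QkvToContext above reduces to a fixed priority order among the attention back ends. The sketch below restates that order as a standalone function; the enum, struct, and function names are illustrative and not part of the patch, and only the checks mirror the code in this hunk.

// Standalone restatement of the dispatch order: fused cross attention, then fused TRT
// self attention, then flash attention, then memory-efficient attention, then unfused.
enum class AttentionBackend { kFusedCross, kFusedTrt, kFlash, kMemoryEfficient, kUnfused };

struct AttentionBackendFlags {
  bool has_fused_cross_kernel = false;   // data.fused_cross_attention_kernel != nullptr
  bool has_fused_runner = false;         // data.fused_runner != nullptr
  bool use_flash_attention = false;      // only when built with USE_FLASH_ATTENTION
  bool use_memory_efficient = false;     // only when built with USE_MEMORY_EFFICIENT_ATTENTION
};

inline AttentionBackend SelectBackend(const AttentionBackendFlags& f) {
  if (f.has_fused_cross_kernel) return AttentionBackend::kFusedCross;
  if (f.has_fused_runner) return AttentionBackend::kFusedTrt;
  if (f.use_flash_attention) return AttentionBackend::kFlash;
  if (f.use_memory_efficient) return AttentionBackend::kMemoryEfficient;
  return AttentionBackend::kUnfused;
}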
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index af7373dd9fa1b..d0a5fb51a25d6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -2,11 +2,12 @@
// Licensed under the MIT License.
#pragma once
-#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
#include <cuda_fp16.h>
#include <cublas_v2.h>
-#include "contrib_ops/cpu/bert/attention_common.h"
+#include "core/common/gsl.h"
#include "core/framework/allocator.h"
+#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
@@ -49,39 +50,52 @@ size_t GetAttentionWorkspaceSize(
template <typename T>
struct AttentionData {
- T* gemm_buffer;
- const T* bias;
+ T* gemm_buffer = nullptr;
+ const T* bias = nullptr;
- const T* query;
- const T* key;
- const T* value;
- const int* mask_index;
+ const T* query = nullptr;
+ const T* key = nullptr;
+ const T* value = nullptr;
+ const int* mask_index = nullptr;
gsl::span<const int64_t> mask_index_dims;
- const T* past;
- const T* past_key;
- const T* past_value;
- const T* relative_position_bias;
-
- bool has_qkv_workspace;
- T* workspace;
- T* temp_k_workspace;
- T* temp_v_workspace;
-
- T* output;
- T* present;
- T* present_key;
- T* present_value;
-
- void* fused_runner;
- const void* fused_cross_attention_kernel;
-
- bool use_flash_attention;
- bool use_memory_efficient_attention;
-
- mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache;
- mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache;
+ const T* past = nullptr;
+ const T* past_key = nullptr;
+ const T* past_value = nullptr;
+ const T* relative_position_bias = nullptr;
+
+ bool has_qkv_workspace = false;
+ T* workspace = nullptr;
+ T* temp_k_workspace = nullptr;
+ T* temp_v_workspace = nullptr;
+
+ T* output = nullptr;
+ T* present = nullptr;
+ T* present_key = nullptr;
+ T* present_value = nullptr;
+
+ void* fused_runner = nullptr;
+ const void* fused_cross_attention_kernel = nullptr;
+
+ bool use_flash_attention = false;
+ bool use_memory_efficient_attention = false;
+
+ mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr;
+ mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr;
+
+ // Intermediate data
+ T* q = nullptr;
+ T* k = nullptr;
+ T* v = nullptr;
+ T* scratch = nullptr;
+ AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
};
+template <typename T>
+Status PrepareQkv(contrib::AttentionParameters& parameters,
+ AttentionData<T>& data,
+ cudaStream_t stream,
+ int max_threads_per_block);
+
template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
@@ -90,33 +104,6 @@ Status QkvToContext(
contrib::AttentionParameters& parameters,
AttentionData<T>& data);
-Status LaunchDecoderAttentionKernel(
- const cudaDeviceProp& prop, // Device Properties
- Stream* stream, // ORT Stream
- cublasHandle_t& cublas, // Cublas handle
- const size_t element_size, // Element size of input tensor
- const int batch_size, // Batch size (B)
- const int sequence_length, // Sequence length (S)
- const int kv_sequence_length, // Key/Value/Cache sequence length
- const int num_heads, // Number of attention heads (N)
- const int head_size, // Hidden size per head (H)
- const bool static_kv, // Whether cross attention or not
- const bool use_past, // Whether use cache or not
- const bool has_layer_state, // Whether output cache or not
- const bool has_key_padding_mask, // Whether use key_padding_mask or not
- const float mask_filter_value, // Mask filter value
- const void* gemm_query_buffer, // Query buffer
- const void* gemm_kv_buffer, // Key and value buffer
- const bool* key_padding_mask, // Key padding mask
- const void* key_cache, // Input key cache
- const void* value_cache, // Input value cache
- void* qkv_buffer, // Temporary buffer
- void* workspace_buffer, // Temporary buffer
- void* output, // Output tensor
- void* new_key_cache, // New_key_cache tensor
- void* new_value_cache // New_value_cache tensor
-);
-
// BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true)
Status LaunchTransCtx(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
@@ -161,33 +148,32 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream,
const half* tensor_add,
half* tensor_out);
-Status LaunchConcatPastToPresent(cudaStream_t stream,
- const int all_sequence_length,
- const int sequence_length,
- const int batch_size,
- const int head_size,
- const int num_heads,
- const int max_threads_per_block,
- const float* past,
- const float* k_v,
- float* present);
-
-Status LaunchConcatPastToPresent(cudaStream_t stream,
- const int all_sequence_length,
- const int sequence_length,
- const int batch_size,
- const int head_size,
- const int num_heads,
- const int max_threads_per_block,
- const half* past,
- const half* k_v,
- half* present);
+template <typename T>
+Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
+ int sequence_length, int total_sequence_length, bool pass_past_in_kv,
+ cudaStream_t stream,
+ int max_threads_per_block,
+ AttentionData<T>& data);
+
+template <typename T>
+Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
+ const int max_sequence_length,
+ const int past_sequence_length,
+ const int sequence_length,
+ const int batch_size,
+ const int head_size,
+ const int num_heads,
+ const int max_threads_per_block,
+ const T* biases,
+ const T* qkv_buffer,
+ T* present);
template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
T* out, longlong4 out_strides, // coord (b,n,s,h)
int max_threads_per_block);
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu
new file mode 100644
index 0000000000000..89be0f1115f41
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu
@@ -0,0 +1,466 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cuda/bert/attention_impl.h"
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/cu_inc/common.cuh"
+
+using namespace onnxruntime::cuda;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T>
+__global__ void ConcatTensorToTensor(const int tensor_add_sequence_length,
+ const T* tensor_in,
+ const T* tensor_add,
+ T* tensor_out) {
+ const int h = threadIdx.x;
+ const int n = threadIdx.y;
+ const int s = blockIdx.x;
+ const int b = blockIdx.y;
+ const int chunk_id = blockIdx.z;
+
+ const int all_sequence_length = gridDim.x;
+ const int batch_size = gridDim.y;
+ const int num_heads = blockDim.y;
+ const int H = blockDim.x;
+
+ // K: number of identical tensors
+ // tensor_in: K x BxNxPxH
+ // tensor_add: K x BxNxLxH
+ // tensor_out: K x BxNxTxH, where T = P + L
+ const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length;
+
+ const int present_SH = all_sequence_length * H;
+ const int present_NSH = num_heads * present_SH;
+ int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size);
+ if (s < tensor_in_sequence_length) {
+ const int past_SH = tensor_in_sequence_length * H;
+ const int past_NSH = num_heads * past_SH;
+ const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size);
+ tensor_out[out_offset] = tensor_in[in_offset];
+ } else if (s < all_sequence_length) {
+ const int SH = tensor_add_sequence_length * H;
+ const int NSH = num_heads * SH;
+ const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size);
+ tensor_out[out_offset] = tensor_add[in_offset];
+ }
+}
+
+template <typename T>
+__global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length,
+ const int H,
+ const T* tensor_in,
+ const T* tensor_add,
+ T* tensor_out) {
+ // Use when H * num_heads > max_threads_per_block (H here is the packed head size).
+ int h = threadIdx.x;
+ const int n = threadIdx.y;
+ const int s = blockIdx.x;
+ const int b = blockIdx.y;
+ const int chunk_id = blockIdx.z;
+
+ const int all_sequence_length = gridDim.x;
+ const int batch_size = gridDim.y;
+ const int num_heads = blockDim.y;
+ const int stride = blockDim.x;
+
+ // K: number of identical tensors
+ // tensor_in: K x BxNxPxH
+ // tensor_add: K x BxNxLxH
+ // tensor_out: K x BxNxTxH
+ const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length;
+
+ const int present_SH = all_sequence_length * H;
+ const int present_NSH = num_heads * present_SH;
+ while (h < H) {
+ int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size);
+ if (s < tensor_in_sequence_length) {
+ const int past_SH = tensor_in_sequence_length * H;
+ const int past_NSH = num_heads * past_SH;
+ const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size);
+ tensor_out[out_offset] = tensor_in[in_offset];
+ } else if (s < all_sequence_length) {
+ const int SH = tensor_add_sequence_length * H;
+ const int NSH = num_heads * SH;
+ const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size);
+ tensor_out[out_offset] = tensor_add[in_offset];
+ }
+
+ h += stride;
+ }
+}
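As a cross-check on the index arithmetic shared by the two kernels above, here is a plain CPU reference that produces the same K x BxNxTxH output from a K x BxNxPxH past tensor and a K x BxNxLxH appended tensor. It is an illustrative sketch, not part of the new file; the function name and signature are mine.

// CPU reference (illustrative only) for the concat kernels above, using the same
// chunk / batch / head / sequence / head-dim ordering, with T = P + L.
#include <cstddef>
#include <vector>

template <typename T>
void ConcatPastAndNewReference(int matrix_num, int batch_size, int num_heads,
                               int past_len, int add_len, int head_size,
                               const std::vector<T>& tensor_in, const std::vector<T>& tensor_add,
                               std::vector<T>& tensor_out) {
  const int total_len = past_len + add_len;
  for (int c = 0; c < matrix_num; ++c)
    for (int b = 0; b < batch_size; ++b)
      for (int n = 0; n < num_heads; ++n)
        for (int s = 0; s < total_len; ++s)
          for (int h = 0; h < head_size; ++h) {
            const std::size_t out = ((((std::size_t)c * batch_size + b) * num_heads + n) * total_len + s) * head_size + h;
            if (s < past_len) {
              // First P positions come from the past tensor.
              const std::size_t in = ((((std::size_t)c * batch_size + b) * num_heads + n) * past_len + s) * head_size + h;
              tensor_out[out] = tensor_in[in];
            } else {
              // Remaining L positions come from the newly added tensor.
              const std::size_t in = ((((std::size_t)c * batch_size + b) * num_heads + n) * add_len + (s - past_len)) * head_size + h;
              tensor_out[out] = tensor_add[in];
            }
          }
}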
+
+Status LaunchConcatTensorToTensor(cudaStream_t stream,
+ const int all_sequence_length,
+ const int sequence_length,
+ const int batch_size,
+ const int head_size,
+ const int num_heads,
+ const int max_threads_per_block,
+ const int matrix_num,
+ const float* tensor_in,
+ const float* tensor_add,
+ float* tensor_out) {
+ const dim3 grid(all_sequence_length, batch_size, matrix_num);
+ if (0 == (head_size & 1)) {
+ const int H = head_size / 2;
+ if (H * num_heads <= max_threads_per_block) {
+ const dim3 block(H, num_heads, 1);
+ ConcatTensorToTensor<float2><<<grid, block, 0, stream>>>(sequence_length,
+ reinterpret_cast