From ef2ccc477b53ab1300f03cd8ae3e0c0211fb02c9 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 15 Aug 2024 21:32:10 -0700 Subject: [PATCH 1/6] [js/web] Add support for int4/uint4 tensor (#21720) ### Description Add support for int4/uint4 tensor. --- js/common/lib/tensor-impl-type-mapping.ts | 2 + js/common/lib/tensor-impl.ts | 23 +++-- js/common/lib/tensor.ts | 4 + js/web/lib/wasm/jsep/init.ts | 11 +-- .../lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 5 +- js/web/lib/wasm/wasm-common.ts | 49 +++++++++- js/web/lib/wasm/wasm-core-impl.ts | 19 ++-- .../data/ops/dequantize-linear_int4.jsonc | 72 +++++++++++++++ js/web/test/op-test-schema.json | 16 +++- js/web/test/test-runner.ts | 89 +++++++++---------- 10 files changed, 206 insertions(+), 84 deletions(-) create mode 100644 js/web/test/data/ops/dequantize-linear_int4.jsonc diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts index 8e68ba31348ca..14dbdca707220 100644 --- a/js/common/lib/tensor-impl-type-mapping.ts +++ b/js/common/lib/tensor-impl-type-mapping.ts @@ -28,6 +28,8 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map new TensorViewImpl(this.module, dataType, this.output(index, dims), dims); const createTemporaryOutput = (dataType: number, dims: readonly number[]): TensorView => { - const elementSize = getTensorElementSize(dataType); - if (!elementSize) { + const bufferSize = calculateTensorSizeInBytes(dataType, dims); + if (!bufferSize) { throw new Error(`Unsupported data type: ${dataType}`); } - const bufferSize = elementSize * ShapeUtil.size(dims); const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0; return new TensorViewImpl(this.module, dataType, gpuDataId, dims); }; @@ -245,9 +244,7 @@ export const init = async ( LOG_DEBUG( 'verbose', () => - `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ - contextDataOffset - }`, + `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`, ); const context = new ComputeContextImpl(module, backend, contextDataOffset); return backend.computeKernel(kernel, context, errors); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index 121ac8baff04b..b63d253ebbb29 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import { DataType, getTensorElementSize } from '../../../wasm-common'; +import { calculateTensorSizeInBytes, DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; import { ShapeUtil } from '../../util'; import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; @@ -77,8 +77,7 @@ export const createMatMulNBitsProgramInfo = ( const outputNumber = getMaxComponents(dimAOuter); const aComponents = getMaxComponents(attributes.k); const bComponents = getMaxComponents(blobSizeInWords); - const elementSize = getTensorElementSize(dataType)!; - const workgroupOutputSize = dimAOuter * nBlocksPerCol * elementSize; + const workgroupOutputSize = calculateTensorSizeInBytes(dataType, dimAOuter * nBlocksPerCol)!; const maxNumberOfComponents = Math.floor(maxComputeWorkgroupStorageSize / workgroupOutputSize); const useBlockwiseMatMulNBits = nBlocksPerCol <= maxComputeWorkgroupSizes[0] && maxNumberOfComponents > 0; const components = diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index 1ef0630d04c8a..fd5d93675154c 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -32,6 +32,10 @@ export const enum DataType { complex64 = 14, complex128 = 15, bfloat16 = 16, + + // 4-bit data-types + uint4 = 21, + int4 = 22, } /** @@ -65,6 +69,10 @@ export const tensorDataTypeStringToEnum = (type: string): DataType => { return DataType.int64; case 'uint64': return DataType.uint64; + case 'int4': + return DataType.int4; + case 'uint4': + return DataType.uint4; default: throw new Error(`unsupported data type: ${type}`); @@ -102,6 +110,10 @@ export const tensorDataTypeEnumToString = (typeProto: DataType): Tensor.Type => return 'int64'; case DataType.uint64: return 'uint64'; + case DataType.int4: + return 'int4'; + case DataType.uint4: + return 'uint4'; default: throw new Error(`unsupported data type: ${typeProto}`); @@ -109,11 +121,42 @@ export const tensorDataTypeEnumToString = (typeProto: DataType): Tensor.Type => }; /** - * get tensor element size in bytes by the given data type + * get tensor size in bytes by the given data type and dimensions * @returns size in integer or undefined if the data type is not supported */ -export const getTensorElementSize = (dateType: number): number | undefined => - [undefined, 4, 1, 1, 2, 2, 4, 8, undefined, 1, 2, 8, 4, 8, undefined, undefined, undefined][dateType]; +export const calculateTensorSizeInBytes = ( + dateType: number, + dimsOrSize: readonly number[] | number, +): number | undefined => { + const elementSize = [ + -1, // undefined = 0 + 4, // float = 1 + 1, // uint8 = 2 + 1, // int8 = 3 + 2, // uint16 = 4 + 2, // int16 = 5 + 4, // int32 = 6 + 8, // int64 = 7 + -1, // string = 8 + 1, // bool = 9 + 2, // float16 = 10 + 8, // double = 11 + 4, // uint32 = 12 + 8, // uint64 = 13 + -1, // complex64 = 14 + -1, // complex128 = 15 + -1, // bfloat16 = 16 + -1, // FLOAT8E4M3FN = 17 + -1, // FLOAT8E4M3FNUZ = 18 + -1, // FLOAT8E5M2 = 19 + -1, // FLOAT8E5M2FNUZ = 20 + 0.5, // uint4 = 21 + 0.5, // int4 = 22 + ][dateType]; + + const size = typeof dimsOrSize === 'number' ? dimsOrSize : dimsOrSize.reduce((a, b) => a * b, 1); + return elementSize > 0 ? 
Math.ceil(size * elementSize) : undefined; +}; /** * get typed array constructor by the given tensor type diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 8f72a8fcda1c3..6c4e28df62f23 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -17,8 +17,8 @@ import { import { setRunOptions } from './run-options'; import { setSessionOptions } from './session-options'; import { + calculateTensorSizeInBytes, dataLocationStringToEnum, - getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, @@ -360,9 +360,7 @@ export const createSession = async ( } if (enableGraphCapture && location !== 'gpu-buffer') { throw new Error( - `Not supported preferred output location: ${ - location - }. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`, + `Not supported preferred output location: ${location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`, ); } outputPreferredLocations.push(location); @@ -474,8 +472,7 @@ export const prepareInputOutputTensor = ( if (location === 'gpu-buffer') { const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; - const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; - dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; const registerBuffer = wasm.jsepRegisterBuffer; if (!registerBuffer) { @@ -611,9 +608,7 @@ export const run = async ( if (inputNamesUTF8Encoded.length !== inputCount) { throw new Error( - `input count from feeds (${ - inputCount - }) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`, + `input count from feeds (${inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`, ); } @@ -752,8 +747,8 @@ export const run = async ( throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.'); } const gpuBuffer = getBuffer(dataOffset); - const elementSize = getTensorElementSize(dataType); - if (elementSize === undefined || !isGpuBufferSupportedType(type)) { + const bufferSize = calculateTensorSizeInBytes(dataType, size); + if (bufferSize === undefined || !isGpuBufferSupportedType(type)) { throw new Error(`Unsupported data type: ${type}`); } @@ -765,7 +760,7 @@ export const run = async ( dims, { gpuBuffer, - download: wasm.jsepCreateDownloader!(gpuBuffer, size * elementSize, type), + download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type), dispose: () => { wasm._OrtReleaseTensor(tensor); }, diff --git a/js/web/test/data/ops/dequantize-linear_int4.jsonc b/js/web/test/data/ops/dequantize-linear_int4.jsonc new file mode 100644 index 0000000000000..e285b1bcdf64c --- /dev/null +++ b/js/web/test/data/ops/dequantize-linear_int4.jsonc @@ -0,0 +1,72 @@ +[ + { + "name": "DequantizeLinear int4", + "opset": { "domain": "", "version": 21 }, + "operator": "DequantizeLinear", + "attributes": [{ "name": "axis", "data": 0, "type": "int" }], + "cases": [ + { + "name": "T[2,3]", + "inputs": [ + { + "data": [0, 1, 7, -4, -8], + "dims": [5], + "type": "int4" + }, + { + "data": [2], + "dims": [], + "type": "float32" + }, + { + "data": [1], + "dims": [1], + "type": "int4" + } + ], + "outputs": [ + { + "data": [-2, 0, 12, -10, -18], + "dims": [5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "DequantizeLinear uint4", + "opset": { "domain": "", "version": 21 }, + 
"operator": "DequantizeLinear", + "attributes": [{ "name": "axis", "data": 0, "type": "int" }], + "cases": [ + { + "name": "T[2,3]", + "inputs": [ + { + "data": [0, 1, 7, 10, 15], + "dims": [5], + "type": "uint4" + }, + { + "data": [2], + "dims": [], + "type": "float32" + }, + { + "data": [1], + "dims": [1], + "type": "uint4" + } + ], + "outputs": [ + { + "data": [-2, 0, 12, 18, 28], + "dims": [5], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/op-test-schema.json b/js/web/test/op-test-schema.json index 0a0a691c37022..948efc6b09f6b 100644 --- a/js/web/test/op-test-schema.json +++ b/js/web/test/op-test-schema.json @@ -189,7 +189,9 @@ "uint32", "uint64", "bool", - "string" + "string", + "int4", + "uint4" ] }, "data": { @@ -226,7 +228,9 @@ "uint32", "uint64", "bool", - "string" + "string", + "int4", + "uint4" ] }, "data": { @@ -261,7 +265,9 @@ "uint32", "uint64", "bool", - "string" + "string", + "int4", + "uint4" ] }, "data": { @@ -298,7 +304,9 @@ "uint32", "uint64", "bool", - "string" + "string", + "int4", + "uint4" ] }, "data": { diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 84f3d8d9fca2b..aa9555c191501 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -16,7 +16,11 @@ import { onnx } from '../lib/onnxjs/ort-schema/protobuf/onnx'; import { Tensor } from '../lib/onnxjs/tensor'; import { ProtoUtil } from '../lib/onnxjs/util'; import { createView } from '../lib/wasm/jsep/tensor-view'; -import { getTensorElementSize, isGpuBufferSupportedType, tensorDataTypeStringToEnum } from '../lib/wasm/wasm-common'; +import { + calculateTensorSizeInBytes, + isGpuBufferSupportedType, + tensorDataTypeStringToEnum, +} from '../lib/wasm/wasm-common'; import { base64toBuffer, createMockGraph, readFile } from './test-shared'; import { Test } from './test-types'; @@ -372,9 +376,7 @@ export class TensorResultValidator { if (!match) { Logger.error( 'TestRunner', - `Tensor mismatch: \nACTUAL: type=${actual[i].type}; dims=[${actual[i].dims}]; data=[${ - actual[i].data - }]\nEXPECT: type=${expected[i].type}; dims=[${expected[i].dims}]; data=[${expected[i].data}]`, + `Tensor mismatch: \nACTUAL: type=${actual[i].type}; dims=[${actual[i].dims}]; data=[${actual[i].data}]\nEXPECT: type=${expected[i].type}; dims=[${expected[i].dims}]; data=[${expected[i].data}]`, ); } expect(match, 'tensor data should match').to.be.true; @@ -462,6 +464,8 @@ export class TensorResultValidator { case 'uint32': case 'int64': case 'bool': + case 'int4': + case 'uint4': return TensorResultValidator.integerEqual( actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, @@ -586,8 +590,7 @@ function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[] throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`); } - const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(type))!; - const size = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + const size = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!; const device = ort.env.webgpu.device as GPUDevice; const gpuBuffer = device.createBuffer({ @@ -852,22 +855,14 @@ export class ProtoOpTestContext { for (let i = 0; i < inputCount; i++) { if (inputsOmitted[i] !== !testCase.inputs![i].data) { throw new Error( - `Test cases for test: ${test.name} [${ - test.operator - }] must have consistent inputs 
data availability. Data of input[${i}] in testCase #0 and #${ - caseIndex - } should be both available or both omitted.`, + `Test cases for test: ${test.name} [${test.operator}] must have consistent inputs data availability. Data of input[${i}] in testCase #0 and #${caseIndex} should be both available or both omitted.`, ); } } for (let i = 0; i < outputCount; i++) { if (outputsOmitted[i] !== !testCase.outputs![i].data) { throw new Error( - `Test cases for test: ${test.name} [${ - test.operator - }] must have consistent outputs data availability. Data of output[${ - i - }] in testCase #0 and #${caseIndex} should be both available or both omitted.`, + `Test cases for test: ${test.name} [${test.operator}] must have consistent outputs data availability. Data of output[${i}] in testCase #0 and #${caseIndex} should be both available or both omitted.`, ); } } @@ -898,9 +893,7 @@ export class ProtoOpTestContext { // check if all test cases have data if (test.cases.some((testCase) => testCase.inputs!.some((input) => !input.data || !input.dims))) { throw new Error( - `Test cases for test: ${test.name} [${ - test.operator - }] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`, + `Test cases for test: ${test.name} [${test.operator}] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`, ); } @@ -919,18 +912,14 @@ export class ProtoOpTestContext { ) ) { throw new Error( - `Test cases for test: ${test.name} [${ - test.operator - }] must have the same rank for each inputs in different test cases`, + `Test cases for test: ${test.name} [${test.operator}] must have the same rank for each inputs in different test cases`, ); } } else if (test.inputShapeDefinitions === 'static') { // check if all test cases have data if (test.cases.some((testCase) => testCase.inputs!.some((input) => !input.data || !input.dims))) { throw new Error( - `Test cases for test: ${test.name} [${ - test.operator - }] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`, + `Test cases for test: ${test.name} [${test.operator}] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`, ); } @@ -946,9 +935,7 @@ export class ProtoOpTestContext { ) ) { throw new Error( - `Test cases for test: ${test.name} [${ - test.operator - }] must have the same shape for each inputs in different test cases`, + `Test cases for test: ${test.name} [${test.operator}] must have the same shape for each inputs in different test cases`, ); } } else { @@ -1033,18 +1020,33 @@ async function runProtoOpTestcase( ): Promise { const feeds: Record = {}; const fetches: Record> = {}; + + const createTensor = (type: ort.Tensor.Type, data: number[], dims: readonly number[]): ort.Tensor => { + let buffer: number[] | BigUint64Array | BigInt64Array | Uint16Array | Uint8Array = data; + if (type === 'uint64') { + buffer = BigUint64Array.from(data.map(BigInt)); + } else if (type === 'int64') { + buffer = BigInt64Array.from(data.map(BigInt)); + } else if (type === 'float16') { + const dataArr = Float16ArrayPolyfill.from(data); + buffer = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2); + } else if (type === 'uint4' || type === 'int4') { + buffer = new Uint8Array(calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!); + // encode (u)int4 data into Uint8Array + for (let j = 0; j < data.length; j++) { + /* eslint-disable no-bitwise */ + const byteIndex = j >> 1; + const bitOffset = (j & 1) << 2; + buffer[byteIndex] |= data[j] << bitOffset; + /* eslint-enable 
no-bitwise */ + } + } + return new ort.Tensor(type, buffer, dims); + }; + testCase.inputs.forEach((input, i) => { if (input.data) { - let data: number[] | BigUint64Array | BigInt64Array | Uint16Array = input.data; - if (input.type === 'uint64') { - data = BigUint64Array.from(input.data.map(BigInt)); - } else if (input.type === 'int64') { - data = BigInt64Array.from(input.data.map(BigInt)); - } else if (input.type === 'float16') { - const dataArr = Float16ArrayPolyfill.from(input.data); - data = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2); - } - feeds[`input_${i}`] = new ort.Tensor(input.type, data, input.dims); + feeds[`input_${i}`] = createTensor(input.type, input.data, input.dims); } }); @@ -1052,16 +1054,7 @@ async function runProtoOpTestcase( const expectedOutputNames: string[] = []; testCase.outputs.forEach((output, i) => { if (output.data) { - let data: number[] | BigUint64Array | BigInt64Array | Uint16Array = output.data; - if (output.type === 'uint64') { - data = BigUint64Array.from(output.data.map(BigInt)); - } else if (output.type === 'int64') { - data = BigInt64Array.from(output.data.map(BigInt)); - } else if (output.type === 'float16') { - const dataArr = Float16ArrayPolyfill.from(output.data); - data = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2); - } - outputs.push(new ort.Tensor(output.type, data, output.dims)); + outputs.push(createTensor(output.type, output.data, output.dims)); expectedOutputNames.push(`output_${i}`); fetches[`output_${i}`] = { dims: output.dims, type: output.type }; } From c97cc5c1b0dacefc125af4aa9a37c4020c9c2ee2 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 16 Aug 2024 15:51:50 +1000 Subject: [PATCH 2/6] Put all external project targets under the 'External' folder in VS (#21765) ### Description Handle targets in subdirectories for external projects. All targets will now go in a per-project folder under 'External' e.g. gmock and gtest now get handled correctly and are under External/googletest vs. existing setup where they ended up as top-level projects. ![image](https://github.com/user-attachments/assets/99ec259c-47cd-44f3-954d-58569c941cc2) ### Motivation and Context Improve developer experience. --- cmake/external/helper_functions.cmake | 46 +++++++++++++++++++-------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/cmake/external/helper_functions.cmake b/cmake/external/helper_functions.cmake index eefb3ba2e800a..e3f2211f96158 100644 --- a/cmake/external/helper_functions.cmake +++ b/cmake/external/helper_functions.cmake @@ -1,6 +1,19 @@ # Distributed under the OSI-approved BSD 3-Clause License. See accompanying # file Copyright.txt or https://cmake.org/licensing for details. +# Recursively set the folder for all targets in the subdirectories of the given source directory. +function(set_folder_for_subdir_targets srcDir folderName) + get_property(subdirs DIRECTORY "${srcDir}" PROPERTY SUBDIRECTORIES) + foreach(subdir ${subdirs}) + get_property(subdir_import_targets DIRECTORY "${subdir}" PROPERTY BUILDSYSTEM_TARGETS) + foreach(subdir_target ${subdir_import_targets}) + set_target_properties(${subdir_target} PROPERTIES FOLDER ${folderName}) + endforeach() + + set_folder_for_subdir_targets(${subdir} ${folderName}) + endforeach() +endfunction() + # This file was copied from cmake source with modifications: # 1. Add the EXCLUDE_FROM_ALL keyword when this function calls add_subdirectory. It will also resolve the # 'make install' issue. 
@@ -165,23 +178,28 @@ macro(onnxruntime_fetchcontent_makeavailable) else() add_subdirectory(${__cmake_srcdir} ${${__cmake_contentNameLower}_BINARY_DIR} EXCLUDE_FROM_ALL) endif() - get_property(subdir_import_targets DIRECTORY "${__cmake_srcdir}" PROPERTY BUILDSYSTEM_TARGETS) - foreach(subdir_target ${subdir_import_targets}) - if(TARGET ${subdir_target}) - get_target_property(subdir_target_type ${subdir_target} TYPE) - if(subdir_target_type STREQUAL "EXECUTABLE") - get_target_property(subdir_target_osx_arch ${subdir_target} OSX_ARCHITECTURES) - if (subdir_target_osx_arch) - if (NOT ${CMAKE_HOST_SYSTEM_PROCESSOR} IN_LIST subdir_target_osx_arch) - message("Added an executable target ${subdir_target} but it can not run natively on ${CMAKE_HOST_SYSTEM_PROCESSOR}, we will try to modify it") - endif() + + get_property(subdir_import_targets DIRECTORY "${__cmake_srcdir}" PROPERTY BUILDSYSTEM_TARGETS) + + foreach(subdir_target ${subdir_import_targets}) + if(TARGET ${subdir_target}) + get_target_property(subdir_target_type ${subdir_target} TYPE) + if(subdir_target_type STREQUAL "EXECUTABLE") + get_target_property(subdir_target_osx_arch ${subdir_target} OSX_ARCHITECTURES) + if (subdir_target_osx_arch) + if (NOT ${CMAKE_HOST_SYSTEM_PROCESSOR} IN_LIST subdir_target_osx_arch) + message("Added an executable target ${subdir_target} but it can not run natively on ${CMAKE_HOST_SYSTEM_PROCESSOR}, we will try to modify it") endif() endif() - set_target_properties(${subdir_target} PROPERTIES FOLDER "External") - set_target_properties(${subdir_target} PROPERTIES COMPILE_WARNING_AS_ERROR OFF) endif() - endforeach() - set(CMAKE_SKIP_INSTALL_RULES FALSE) + set_target_properties(${subdir_target} PROPERTIES FOLDER "External/${__cmake_contentName}") + set_target_properties(${subdir_target} PROPERTIES COMPILE_WARNING_AS_ERROR OFF) + endif() + endforeach() + set(CMAKE_SKIP_INSTALL_RULES FALSE) + + # set the FOLDER property for all targets contained in source directory and subfolders + set_folder_for_subdir_targets(${__cmake_srcdir} "External/${__cmake_contentName}") endif() unset(__cmake_srcdir) From b2d603abdaf580872d3249455241f53d64d5a2a9 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 16 Aug 2024 13:59:51 +0800 Subject: [PATCH 3/6] [WebNN EP] Remove workaround for scalar (#21704) Currently Chromium has supported scalar with dims = {}, remove legacy workaround for supporting scalar. 
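For context, here is a minimal standalone sketch of the shape mapping this change simplifies. It uses plain `std::vector` stand-ins rather than the real ORT/WebNN types (the helper name `ToWebnnDims` is illustrative, not from the codebase); the `{1}`-padding behavior described in the comment is the workaround being deleted below, and only the pass-through behavior remains after this patch.

```cpp
#include <cstdint>
#include <vector>

// Map an ONNX value's dimensions to the dimensions handed to WebNN.
// An empty input vector represents a rank-0 (scalar) value.
std::vector<int32_t> ToWebnnDims(const std::vector<int64_t>& onnx_shape) {
  std::vector<int32_t> dims;
  // Before this change, the EP padded scalars to {1} and recorded the output
  // name so it could strip the extra dimension when returning data to ORT.
  // Now a scalar simply stays as an empty dims list ({}), which current
  // Chromium WebNN accepts directly.
  dims.reserve(onnx_shape.size());
  for (int64_t d : onnx_shape) {
    dims.push_back(static_cast<int32_t>(d));
  }
  return dims;
}
```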
--- .../core/providers/webnn/builders/model.cc | 4 ---- .../core/providers/webnn/builders/model.h | 8 ------- .../providers/webnn/builders/model_builder.cc | 22 +++++-------------- .../providers/webnn/builders/model_builder.h | 5 +---- .../webnn/webnn_execution_provider.cc | 10 --------- 5 files changed, 6 insertions(+), 43 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index ef807a8c4fa26..8cd2e8d0ffad3 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -142,10 +142,6 @@ Status Model::Predict(const InlinedHashMap& inputs, return Status::OK(); } -bool Model::IsScalarOutput(const std::string& output_name) const { - return Contains(scalar_outputs_, output_name); -} - const OnnxTensorInfo& Model::GetInputOutputInfo(const std::string& name) const { return input_output_info_.at(name); } diff --git a/onnxruntime/core/providers/webnn/builders/model.h b/onnxruntime/core/providers/webnn/builders/model.h index 4af82a2675691..5119dbbbc9858 100644 --- a/onnxruntime/core/providers/webnn/builders/model.h +++ b/onnxruntime/core/providers/webnn/builders/model.h @@ -34,8 +34,6 @@ class Model { onnxruntime::common::Status Predict(const InlinedHashMap& inputs, const InlinedHashMap& outputs); - bool IsScalarOutput(const std::string& output_name) const; - // Mutex for exclusive lock to this model object. OrtMutex& GetMutex() { return mutex_; } @@ -65,8 +63,6 @@ class Model { emscripten::val wnn_inputs_ = emscripten::val::object(); emscripten::val wnn_outputs_ = emscripten::val::object(); - InlinedHashSet scalar_outputs_; - std::vector inputs_; std::vector outputs_; @@ -83,10 +79,6 @@ class Model { input_output_info_ = std::move(input_output_info); } - void SetScalarOutputs(InlinedHashSet&& scalar_outputs) { - scalar_outputs_ = std::move(scalar_outputs); - } - void AllocateInputOutputBuffers(); }; diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index b21f717eedc7a..44bec1fb6fd48 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -104,13 +104,15 @@ Status ModelBuilder::RegisterInitializers() { emscripten::val operand = emscripten::val::object(); if (IsSupportedDataType(data_type, webnn_supported_data_types)) { ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); - auto num_elements = SafeInt(Product(tensor.dims())); + auto num_elements = SafeInt(Product(shape)); emscripten::val view = emscripten::val::undefined(); std::byte* tensor_ptr = nullptr; if (tensor.has_raw_data()) { tensor_ptr = reinterpret_cast(const_cast(tensor.raw_data().c_str())); } else { - std::vector unpacked_tensor; + // Store temporary unpacked_tensor. + unpacked_tensors_.push_back({}); + std::vector& unpacked_tensor = unpacked_tensors_.back(); ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); tensor_ptr = reinterpret_cast(unpacked_tensor.data()); } @@ -187,16 +189,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i ORT_RETURN_IF(shape_proto == nullptr, "shape_proto cannot be null for ", input_output_type, ": ", name); const auto& shape = shape_proto->dim(); - if (shape.empty()) { - // If we have an empty shape, this is a scalar input. 
- dims.push_back(1); - - // We need to change the shapes of these scalar outputs back to {} - // when WebNN EP returns these values to ORT. - if (!is_input) { - AddScalarOutput(name); - } - } else { + if (!shape.empty()) { dims.reserve(shape.size()); for (const auto& dim : shape) { // dim_param free dimensions should have already been excluded by IsInputSupported(). @@ -343,7 +336,6 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_)); model->SetInputs(std::move(input_names_)); model->SetOutputs(std::move(output_names_)); - model->SetScalarOutputs(std::move(scalar_outputs_)); model->SetInputOutputInfo(std::move(input_output_info_)); // Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews // for inputs and outputs because they will be transferred after compute() done. @@ -352,10 +344,6 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { return Status::OK(); } -void ModelBuilder::AddScalarOutput(const std::string& output_name) { - scalar_outputs_.insert(output_name); -} - void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& operand) { wnn_operands_.insert(std::make_pair(name, operand)); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index b1561f009aa25..2d686070cdcc1 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -69,8 +69,8 @@ class ModelBuilder { InlinedHashMap wnn_operands_; std::vector input_names_; std::vector output_names_; + std::vector> unpacked_tensors_; - InlinedHashSet scalar_outputs_; InlinedHashMap input_output_info_; InlinedHashSet skipped_initializers_; @@ -92,9 +92,6 @@ class ModelBuilder { Status RegisterModelOutputs() ORT_MUST_USE_RESULT; Status RegisterModelInputOutput(const NodeArg& node_arg, bool is_input) ORT_MUST_USE_RESULT; - // Record the onnx scalar output names. - void AddScalarOutput(const std::string& output_name); - static const IOpBuilder* GetOpBuilder(const Node& node); }; diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 1cd382c1e75e9..b918daf838c99 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -272,10 +272,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector(input_tensor.GetTensorRawData()); inputs.emplace( input_name, @@ -297,12 +293,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vectorGetInputOutputInfo(output_name); auto output_shape = output_info.shape; auto output_type = output_info.data_type; - - // Since WebNN EP use {1} tensor as scalar, if the model output should have empty shape. - // We are going to replace the {1} shape of the output back to {}. - if (model->IsScalarOutput(output_name)) - output_shape.clear(); - auto output_tensor = ctx.GetOutput(i, output_shape.data(), output_shape.size()); From a4bec3d3745576eefd37b3c12d51e6c8c26bd52e Mon Sep 17 00:00:00 2001 From: Emmanuel <91394589+lainey1570@users.noreply.github.com> Date: Fri, 16 Aug 2024 10:45:22 -0700 Subject: [PATCH 4/6] Enabled Dynamo exporter (#21713) ### Description This PR modifies the run_dynamo_export function to ensure it mirrors the behavior of run_torchscript_merged_export rather than run_torchscript_separate_export. 
Additionally, I made adjustments to the main function to ensure that run_dynamo is correctly invoked. ### Motivation and Context The main motivation for this change is to enable successful export of LLaMA-2 and LLaMA-3 models using the Dynamo exporter to ONNX. Previously, the exporter was saving two copies of the weights, which is inefficient. The modified approach ensures that only one copy of the weights is saved, and the model can support both scenarios. These changes enhance the compatibility of the exporter with LLaMA models and subsequently other models and optimize the export process --- .../models/llama/convert_to_onnx.py | 83 ++++++------------- 1 file changed, 25 insertions(+), 58 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index f701e465b9153..f5446ed718087 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -11,6 +11,7 @@ import shutil import subprocess import sys +import tempfile from itertools import chain import onnx @@ -113,34 +114,6 @@ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: st ) -# Notes: -# 1) Dynamo export will not work automatically until this issue is resolved: https://github.com/microsoft/onnxscript/issues/493 -# -# 2) Dynamo export will run manually if you set the ONNX file path to the same path that you use to save the model after export. -# In other words, the value of `temp_path` should be set as the ONNX file path. You can open the issue in your browser to find -# the location in ONNX Script where you have to make this change. -# -# Once the issue is resolved, we hope to modify the code below as follows for each export. -# -# Before: -# temp_dir = args.output -# temp_path = os.path.join(temp_dir, "temp.onnx") -# ... -# ... -# ... -# del onnx_model -# os.system(f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}") -# -# -# After: -# temp_dir = tempfile.TemporaryDirectory() -# temp_path = os.path.join(temp_dir.name, "temp.onnx") -# ... -# ... -# ... 
-# del onnx_model -# temp_dir.cleanup() -# def run_dynamo_export( args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1 ): @@ -149,35 +122,25 @@ def run_dynamo_export( config.capture_scalar_outputs = True # Dummy values for export - batch_size, sequence_length = 2, 8 - device = torch.device("cpu") - - # Export decoder_model.onnx - input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length) - temp_dir = args.output # tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx") - torch.onnx.dynamo_export( - llama, input_ids, attn_mask, pos_ids, export_options=torch.onnx.ExportOptions(dynamic_shapes=True) - ).save(temp_path) - - # Check decoder_model.onnx and save all external data to one file - onnx.checker.check_model(temp_path) - onnx.shape_inference.infer_shapes_path(temp_path) + batch_size, sequence_length, past_sequence_length = 2, 8, 0 + device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") - output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") - onnx_model = onnx.load_model(temp_path, load_external_data=True) - save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data") - del onnx_model - os.system( - f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" - ) # temp_dir.cleanup() + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 # Export decoder_with_past_model.onnx - input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs( - l_config, device, batch_size, sequence_length, world_size=world_size + input_ids, attn_mask, pos_ids, past_kv = get_merged_sample_with_past_kv_inputs( + l_config, + device, + batch_size, + sequence_length, + past_sequence_length, + max_seq_len=max_sequence_length, + use_fp16=False, + world_size=world_size, ) - temp_dir = args.output # tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx") + temp_dir = tempfile.TemporaryDirectory() + temp_path = os.path.join(temp_dir.name, "temp.onnx") torch.onnx.dynamo_export( llama, input_ids, attn_mask, pos_ids, past_kv, export_options=torch.onnx.ExportOptions(dynamic_shapes=True) ).save(temp_path) @@ -190,9 +153,7 @@ def run_dynamo_export( onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data") del onnx_model - os.system( - f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" - ) # temp_dir.cleanup() + temp_dir.cleanup() logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!") @@ -869,7 +830,7 @@ def main(): # Export to ONNX if missing_separate_exports or missing_merged_export: - if args.use_dynamo_export and missing_separate_exports: + if args.use_dynamo_export: logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") @@ -902,7 +863,10 @@ def main(): decoder_merged_model_fp32_opt_path, ] 
- # Run the optimizer script + if args.use_dynamo_export: + continue + + # Run the optimizer script. logger.info("Optimizing models...") for orig_path, opt_path in zip(old_paths, new_paths): if os.path.exists(orig_path): @@ -1007,6 +971,9 @@ def main(): remove_existing_model(fp_path) barrier() + if args.use_dynamo_export: + return + logger.info("Verifying parity on all ONNX models created") # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models From 63e88499923cc89bb01b9b6844ffe7bb79a0149d Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 16 Aug 2024 11:21:09 -0700 Subject: [PATCH 5/6] build_aar_package.py - Check that executable is present before trying to copy it. (#21730) Check that executable is present before trying to copy it. Accommodate builds where we skip building the test executables. --- tools/ci_build/github/android/build_aar_package.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/android/build_aar_package.py b/tools/ci_build/github/android/build_aar_package.py index ee76bab762552..036db703ccf27 100644 --- a/tools/ci_build/github/android/build_aar_package.py +++ b/tools/ci_build/github/android/build_aar_package.py @@ -118,11 +118,16 @@ def _build_aar(args): os.symlink(os.path.join(abi_build_dir, build_config, lib_name), target_lib_name) # copy executables for each abi, in case we want to publish those as well + # some of them might not exist, e.g., if we skip building the tests abi_exe_dir = os.path.join(exe_dir, abi) for exe_name in ["libonnxruntime.so", "onnxruntime_perf_test", "onnx_test_runner"]: + src_exe_path = os.path.join(abi_build_dir, build_config, exe_name) + if not os.path.exists(src_exe_path): + continue + os.makedirs(abi_exe_dir, exist_ok=True) - target_exe_name = os.path.join(abi_exe_dir, exe_name) - shutil.copyfile(os.path.join(abi_build_dir, build_config, exe_name), target_exe_name) + dest_exe_path = os.path.join(abi_exe_dir, exe_name) + shutil.copyfile(src_exe_path, dest_exe_path) # we only need to define the header files path once if not header_files_path: From d79e3c5791cf1b9481c8d2d64917dc5270658446 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 16 Aug 2024 15:40:04 -0700 Subject: [PATCH 6/6] Extend Attention Bias Broadcast Support (#21710) ### Description Previously, MultiHeadAttention supports relative position bias of shape [1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention supports [1, N, S, T]. This will extend the support to allow [1, N, S, T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs. - [x] Rename the input of "relative position bias" to "attention bias" because it can also be used for other types of bias, like ALiBi (Attention with Linear Biases) or attention mask. - [x] Update unfused kernel to support broadcasting 2nd dimension of attention bias. - [x] Update efficient attention to support broadcasting 2nd dimension of attention bias. - [x] Update operators (MultiHeadAttention, DecoderMaskedMultiHeadAttention, Attention, PackedAttention, PackedMultiHeadAttention) to support broadcast attention bias on CUDA and CPU EPs. - [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that those EPs do not support broadcasting attention_bias for now). - [x] Add attention bias tests for MultiHeadAttention. 
- [x] Update operator documents - [x] Update benchmark script Other changes: * Fix some checks in multihead-attention.ts * Add helper functions to dump tensors given dimensions. --- docs/ContribOperators.md | 28 +- docs/OperatorKernels.md | 22 +- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 44 +- .../jsep/webgpu/ops/multihead-attention.ts | 133 ++- .../test/data/ops/multihead-attention.jsonc | 36 +- onnxruntime/contrib_ops/cpu/bert/attention.cc | 6 +- .../contrib_ops/cpu/bert/attention_base.cc | 55 +- .../contrib_ops/cpu/bert/attention_base.h | 4 +- .../contrib_ops/cpu/bert/attention_common.h | 8 +- .../contrib_ops/cpu/bert/attention_cpu_base.h | 137 +-- .../cpu/bert/multihead_attention.cc | 10 +- .../cpu/bert/multihead_attention_helper.h | 80 +- .../cpu/quantization/attention_quant.cc | 2 +- .../contrib_ops/cpu/utils/console_dumper.h | 38 + .../contrib_ops/cpu/utils/dump_tensor.cc | 29 + .../contrib_ops/cpu/utils/dump_tensor.h | 5 + .../contrib_ops/cuda/bert/attention.cc | 19 +- .../contrib_ops/cuda/bert/attention_impl.cu | 53 +- .../contrib_ops/cuda/bert/attention_impl.h | 4 +- .../cuda/bert/attention_prepare_qkv.cu | 22 +- .../cuda/bert/attention_softmax.cu | 852 ++++++++++-------- .../contrib_ops/cuda/bert/attention_softmax.h | 13 +- .../bert/cutlass_fmha/fmha_launch_template.h | 18 +- .../cutlass_fmha/memory_efficient_attention.h | 11 +- .../cuda/bert/decoder_attention_impl.cu | 14 +- .../decoder_masked_multihead_attention.cc | 14 +- .../bert/decoder_masked_self_attention.cc | 8 +- ...decoder_masked_multihead_attention_impl.cu | 23 +- .../decoder_masked_multihead_attention_impl.h | 4 +- .../cuda/bert/group_query_attention_impl.cu | 1 - .../cuda/bert/multihead_attention.cc | 20 +- .../contrib_ops/cuda/bert/packed_attention.cc | 64 +- .../contrib_ops/cuda/bert/packed_attention.h | 6 +- .../cuda/bert/packed_attention_impl.cu | 16 +- .../cuda/bert/packed_attention_impl.h | 2 +- .../cuda/bert/packed_multihead_attention.cc | 69 +- .../cuda/bert/packed_multihead_attention.h | 2 +- .../bert/packed_multihead_attention_impl.cu | 13 +- .../bert/packed_multihead_attention_impl.h | 3 +- .../quantization/attention_quantization.cc | 2 +- .../qordered_ops/qordered_attention.cc | 2 +- .../qordered_attention_input_enum.h | 2 +- .../cuda/utils/dump_cuda_tensor.cc | 42 + .../contrib_ops/cuda/utils/dump_cuda_tensor.h | 14 +- .../contrib_ops/rocm/bert/attention.cu | 8 +- ...ed_gemm_softmax_gemm_permute_pipelines.cuh | 3 +- .../rocm/bert/multihead_attention.cu | 12 +- .../core/graph/contrib_ops/bert_defs.cc | 30 +- .../graph/contrib_ops/quantization_defs.cc | 6 +- .../core/providers/cpu/cpu_provider_shared.cc | 4 +- .../core/providers/cpu/cpu_provider_shared.h | 2 +- .../External/DirectMLHelpers/DirectMLSchema.h | 4 +- .../src/Operators/DmlOperatorAttention.cpp | 43 +- .../DmlOperatorMultiHeadAttention.cpp | 27 +- .../src/Operators/DmlOperatorQAttention.cpp | 4 +- .../provider_bridge_provider.cc | 4 +- .../python/tools/transformers/constants.py | 4 +- .../tools/transformers/convert_generation.py | 2 +- .../transformers/convert_to_packing_mode.py | 20 +- .../transformers/fusion_rotary_attention.py | 2 +- .../test/contrib_ops/attention_op_test.cc | 32 +- .../contrib_ops/attention_op_test_helper.cc | 138 +-- .../contrib_ops/attention_op_test_helper.h | 20 +- .../multihead_attention_op_test.cc | 54 +- .../multihead_attention_op_test_data_gen.py | 6 +- .../contrib_ops/packed_attention_op_test.cc | 28 +- .../packed_multihead_attention_op_test.cc | 76 +- .../contrib_ops/qordered_attention_test.cc | 2 +- 
.../python/transformers/benchmark_mha.cmd | 8 + .../test/python/transformers/benchmark_mha.py | 287 ++++-- .../test/python/transformers/benchmark_mha.sh | 9 + .../test/python/transformers/test_mha.py | 235 +++-- .../test_parity_neox_attention.py | 2 +- .../python/transformers/test_parity_t5_mha.py | 14 +- .../attention/attention_test_data.txt | 76 +- .../packed_multihead_attention_test_data.txt | 34 +- 76 files changed, 1791 insertions(+), 1355 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index c60b25f3418f6..0048190f9063b 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -180,8 +180,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size) or (3 * batch_size + 2)
past (optional) : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
-
relative_position_bias (optional) : T
-
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
+
attention_bias (optional) : T
+
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
past_sequence_length (optional) : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
@@ -1166,7 +1166,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Value with shape (batch_size, 1, v_hidden_size) for self attention or past_value with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention
mask_index (optional) : M
Mask values of shape (batch_size, total_sequence_length) or (batch_size, kv_sequence_length)
-
relative_position_bias (optional) : T
+
attention_bias (optional) : T
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
past_key (optional) : T
past state for key with shape (batch_size, num_heads, past_sequence_length, head_size) for self attentionWhen past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size). The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.
@@ -1256,8 +1256,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Mask values of shape (batch_size, total_sequence_length)
past : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size). The first `batch_size * num_heads * max_sequence_length * head_size` elements correspond to keys and the next `batch_size * num_heads * max_sequence_length * head_size` elements correspond to values. The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.
-
relative_position_bias (optional) : T
-
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
+
attention_bias (optional) : T
+
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
past_sequence_length : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
beam_width (optional) : M
@@ -3202,8 +3202,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)
-
relative_position_bias (optional) : T
-
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
+
attention_bias (optional) : T
+
bias added to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
past_key (optional) : T
past state for self attention key with shape (batch_size, num_heads, past_sequence_length, head_size)
past_value (optional) : T
@@ -3516,8 +3516,8 @@ This version of the operator has been available since version 1 of the 'com.micr
In packing mode, it specifies the offset of each token(batch_size, sequence_length).
cumulative_sequence_length : M
A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.
-
relative_position_bias (optional) : T
-
A tensor with shape (batch_size, num_heads, sequence_length, sequence_length)or (1, num_heads, sequence_length, sequence_length).It specifies the additional bias to QxK'
+
attention_bias (optional) : T
+
A tensor with shape (batch_size or 1, num_heads or 1, sequence_length, sequence_length).It specifies the additional bias to QxK'
#### Outputs @@ -3591,8 +3591,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Offset of each token before packing, with shape (batch_size, sequence_length).
cumulative_sequence_length : M
A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.
-
relative_position_bias (optional) : T
-
It specifies the additional bias to QxK'. The shape is (batch_size, num_heads, sequence_length, sequence_length) or (1, num_heads, sequence_length, sequence_length)
+
attention_bias (optional) : T
+
It specifies the additional bias to QxK'. The shape is (batch_size or 1, num_heads or 1, sequence_length, sequence_length)
#### Outputs @@ -4468,7 +4468,7 @@ This version of the operator has been available since version 1 of the 'com.micr left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past and present state are optional. Present state could appear in output even when past state is not in input. - Current version does not support past/present, relative_position_bias and qkv_hidden_sizes. + Current version does not support past/present, attention_bias and qkv_hidden_sizes. TODO: Support them if needed in the future. #### Version @@ -4533,8 +4533,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).
past (optional) : Q
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).
-
relative_position_bias (optional) : S
-
additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).
+
attention_bias (optional) : S
+
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length).
#### Outputs diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index f0aa332ff39eb..96173b5a4ea4a 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -460,7 +460,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| |AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)| |BeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| @@ -490,7 +490,7 @@ Do not modify directly.* |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(uint8)
**T4** = tensor(int32)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| -|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| +|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcMaxPool|*in* x:**T**
*out* y:**T**|1+|**T** = tensor(int8), tensor(uint8)| @@ -848,7 +848,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |BeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasAdd|*in* X:**T**
*in* bias:**T**
*in* skip:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| @@ -861,8 +861,8 @@ Do not modify directly.* |ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |DecoderAttention|*in* query:**T**
*in* key:**T**
*in* q_weight:**T**
*in* kv_weight:**T**
*in* bias:**T**
*in* key_padding_mask:**B**
*in* key_cache:**T**
*in* value_cache:**T**
*in* static_kv:**B**
*in* use_past:**B**
*in* has_layer_state:**B**
*in* has_key_padding_mask:**B**
*out* output:**T**
*out* new_key_cache:**T**
*out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)| -|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**V**|1+|**T** = tensor(float), tensor(float16)| -|DecoderMaskedSelfAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| +|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**V**|1+|**T** = tensor(float), tensor(float16)| +|DecoderMaskedSelfAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(float16)| |DequantizeWithOrder|*in* input:**Q**
*in* scale_input:**S**
*out* output:**F**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| |DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)| @@ -884,14 +884,14 @@ Do not modify directly.* |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| +|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|PackedAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|PackedAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| |QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float16)
**T1** = tensor(uint8)| -|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* relative_position_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| +|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* attention_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLongformerAttention|*in* input:**Q**
*in* scale_input:**S**
*in* weight:**Q**
*in* scale_weight:**S**
*in* bias:**S**
*in* scale_bias:**S**
*in* scale_qkv_gemm:**S**
*in* mask:**F**
*in* global_weight:**Q**
*in* scale_global_weight:**S**
*in* global_bias:**S**
*in* scale_global_gemm:**S**
*in* global:**G**
*in* scale_output:**S**
*out* output:**Q**|1+|**F** = tensor(float16)
**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| @@ -1296,7 +1296,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |BiasAdd|*in* X:**T**
*in* bias:**T**
*in* skip:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasSplitGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -1312,7 +1312,7 @@ Do not modify directly.* |GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| -|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 0008fd1aff62e..8840ef97b4279 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -101,7 +101,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte // bias (Q/K/V) : (D + D + D_v) // mask_index : see below // past (K/V) : (2, B, N, P, H) or NULL - // relative_position_bias : (B, N, S, T) or NULL + // attention_bias : (B, N, S, T) or NULL // For mask_index, the following shapes are supported: // NULL, (B, 1), (1, 1) @@ -118,10 +118,10 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte const bias = inputs[2]; const maskIndex = inputs[3]; const past = inputs[4]; - const relativePositionBias = inputs[5]; + const attentionBias = inputs[5]; - if (past && relativePositionBias) { - throw new Error('Attention cannot have both past and relative_position_bias'); + if (past && attentionBias) { + throw new Error('Attention cannot have both past and attention_bias'); } if (input.dims.length !== 3) { @@ -217,6 +217,22 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte throw new Error('past is not supported'); } + if (attentionBias) { + if (attentionBias.dims.length !== 4) { + throw new Error('Input "attention_bias" must have 4 dimensions'); + } + + // TODO: support broadcasting the first and second dimensions of attention_bias + if ( + attentionBias.dims[0] !== batchSize || + attentionBias.dims[1] !== attributes.numHeads || + attentionBias.dims[2] !== sequenceLength || + attentionBias.dims[3] !== totalSequenceLength + ) { + throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)'); + } + } + return { batchSize, sequenceLength, @@ -348,7 +364,7 @@ const createAttentionProbsProgramInfo = ( q: TensorView, key: TensorView, pastKey: TensorView | undefined, - relativePositionBias: TensorView | undefined, + attentionBias: TensorView | undefined, parameters: AttentionParameters, attributes: AttentionAttrs, pastSequenceLength: number, @@ -385,7 +401,7 @@ const createAttentionProbsProgramInfo = ( if (pastKey) { inputDependencies.push('type'); } - if (relativePositionBias) { + if (attentionBias) { inputDependencies.push('type'); } const outputs = [{ dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default }]; @@ -400,8 +416,8 @@ const createAttentionProbsProgramInfo = ( const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components); inputVars.push(pastKeyInput); } - if (relativePositionBias) { - inputVars.push(inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims)); + if (attentionBias) { + inputVars.push(inputVariable('attention_bias', attentionBias.dataType, attentionBias.dims)); } const output = outputVariable('output', q.dataType, probsShape); const outputVars = [output]; @@ -491,7 +507,7 @@ const createAttentionProbsProgramInfo = ( } })()}; output[outputIdx] = ${output.type.value} (sum * uniforms.alpha) + ${ - relativePositionBias ? 'relative_position_bias[outputIdx]' : '0.0' + attentionBias ? 
'attention_bias[outputIdx]' : '0.0' }; } }`; @@ -499,7 +515,7 @@ const createAttentionProbsProgramInfo = ( return { name: 'AttentionProbs', shaderCache: { - hint: `${components};${relativePositionBias !== undefined};${pastKey !== undefined};${context.outputCount}`, + hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${context.outputCount}`, inputDependencies, }, getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }), @@ -648,7 +664,7 @@ export const applyAttention = ( _past: TensorView | undefined, pastKey: TensorView | undefined, pastValue: TensorView | undefined, - relativePositionBias: TensorView | undefined, + attentionBias: TensorView | undefined, parameters: AttentionParameters, attributes: AttentionAttrs, ) => { @@ -657,8 +673,8 @@ export const applyAttention = ( const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const inputsK = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey ? [q, k, pastKey] : [q, k]; - if (relativePositionBias) { - inputsK.push(relativePositionBias); + if (attentionBias) { + inputsK.push(attentionBias); } // Run AttentionProbs @@ -668,7 +684,7 @@ export const applyAttention = ( q, k, outputCount > 1 ? pastKey : undefined, - relativePositionBias, + attentionBias, parameters, attributes, pastSequenceLength, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts index 1e0902eb0ff56..72e09303ba76f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts @@ -26,53 +26,60 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr const value = getInput(inputs, 2); const bias = getInput(inputs, 3); const keyPaddingMask = getInput(inputs, 4); - const relativePositionBias = getInput(inputs, 5); + const attentionBias = getInput(inputs, 5); const pastKey = getInput(inputs, 6); const pastValue = getInput(inputs, 7); - // Abbreviation and Meanings: - // B: batch_size - // S: sequence_length (input sequence length of query) - // P: past_sequence_length (past sequence length of key or value) - // L: kv_sequence_length (input sequence length of key or value) - // M: max_sequence_length - // T: total_sequence_length = past_sequence_length + kv_sequence_length - // N: num_heads - // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size - // H_v: v_head_size - // D_i: input hidden size - // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size - // D_v: v_hidden_size = num_heads * v_head_size - - // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None - // relative_position_bias : (B, 1, S, L) - // past_key : (B, N, S*, H) - // past_value : (B, N, S*, H) - // When no packing for q/k/v: + // --------------------------------------------------------------- + // Notations: + // B: batch_size + // N: num_heads + // H: head_size of Q and K + // H_v: head_size of V + // D: hidden_size for Q and K, where D = N * H + // D_v: hidden_size of V, where D_v = N * H_v + // S: q_sequence_length + // P: past_sequence_length of kv cache + // L: kv_sequence_length + // T: total_sequence_length = P + L + // M: max_sequence_length of kv cache when past and present share buffer + // --------------------------------------------------------------- + // MultiHeadAttention inputs: + // --------------------------------------------------------------- + // Q_K_V_BSNH - no packing: // query (Q) 
: (B, S, D) - // key (K) : (B, L, D) or (B, N, S*, H) - // value (V) : (B, L, D_v) or (B, N, S*, H) - // bias (Q/K/V) : (D + D + D_v) - // When packed kv is used: + // key (K) : (B, L, D) + // value (V) : (B, L, D_v) + // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache is not used, L == T, D == D_v): // query (Q) : (B, S, D) - // key (K) : (B, L, N, 2, H) - // value (V) : None - // bias (Q/K/V) : None - // When packed qkv is used: - // query (Q) : (B, L, N, 3, H) or (B, S, 3*D) - // key (K) : None - // value (V) : None + // key (K) : (B, N, L, H) + // value (V) : (B, N, L, H_v) + // Q_KV_BSNH_BSN2H - packed kv (kv cache is not used, bias is not allowed for packed kv): + // query (Q) : (B, S, D) + // key (K/V) : (B, L, N, 2, H) + // value : None + // QKV_BSN3H - packed qkv (kv cache is not used, S == L, D == D_v): + // query (Q/K/V) : (B, S, N, 3, H) + // key : None + // value : None + // + // Other inputs: // bias (Q/K/V) : None or (D + D + D_v) + // key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, T) or (B, S, T) + // attention_bias : None or (B, N, S, T), (1, N, S, T), (B, 1, S, T) or (1, 1, S, T) + // past_key : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. + // past_value : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. + // + // Not Supported: + // key_padding_mask, packed kv, packed qkv, and broadcast for attention_bias. if (query.dims.length !== 3 && query.dims.length !== 5) { throw new Error('Input query is expected to have 3 or 5 dimensions'); } - const dmmhaPacking = false; const batchSize = query.dims[0]; const sequenceLength = query.dims[1]; - const hiddenSize = - query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4]; + const hiddenSize = query.dims.length === 3 ? 
query.dims[2] : attributes.numHeads * query.dims[4]; let kvSequenceLength = sequenceLength; let pastSequenceLength = 0; @@ -137,15 +144,15 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); } - qkvFormat = AttentionQkvFormat.unknown; + qkvFormat = AttentionQkvFormat.unknown; // Q_K_V_BSNH_BNSH_BNSH kvSequenceLength = key.dims[2]; } } else { // packed QKV - if (query.dims.length !== 3 && query.dims.length !== 5) { - throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty'); + if (query.dims.length !== 5) { + throw new Error('Input "query" is expected to have 5 dimensions when key is empty'); } - if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) { + if (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3) { throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv'); } @@ -157,13 +164,15 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr throw new Error('Input "bias" is expected to have 1 dimension'); } - if (value) { - if (query.dims.length === 5 && query.dims[3] === 2) { + if (key) { + if (key.dims.length === 5 && key.dims[3] === 2) { throw new Error('bias is not allowed for packed kv.'); } } } + const totalSequenceLength = pastSequenceLength + kvSequenceLength; + let maskType: AttentionMaskType = AttentionMaskType.none; if (keyPaddingMask) { maskType = AttentionMaskType.maskUnknown; @@ -174,11 +183,11 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr } else if (maskDims[0] === 3 * batchSize + 2) { maskType = AttentionMaskType.mask1DKeySeqLenStart; } - } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) { + } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === totalSequenceLength) { maskType = AttentionMaskType.mask2dKeyPadding; } if (maskType === AttentionMaskType.maskUnknown) { - throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, kv_sequence_length)'); + throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, total_sequence_length)'); } throw new Error('Mask not supported'); } @@ -200,32 +209,34 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr } vHiddenSize = value.dims[2]; } else { + // Q_K_V_BSNH_BNSH_BNSH if (kvSequenceLength !== value.dims[2]) { - throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)'); + throw new Error('Input "key" and "value" shall have the same dim 2 (kv_sequence_length)'); } vHiddenSize = value.dims[1] * value.dims[3]; passPastInKv = true; } } - const totalSequenceLength = pastSequenceLength + kvSequenceLength; const broadcastResPosBias = false; if (keyPaddingMask) { throw new Error('Key padding mask is not supported'); } - if (relativePositionBias) { - if (relativePositionBias.dims.length !== 4) { - throw new Error('Input "relative_position_bias" is expected to have 4 dimensions'); + if (attentionBias) { + if (attentionBias.dims.length !== 4) { + throw new Error('Input "attention_bias" is expected to have 4 dimensions'); } + + // TODO: support broadcasting the first and second dimensions of attention_bias. 
if ( - (relativePositionBias.dims[0] !== batchSize && relativePositionBias.dims[0] !== 1) || - relativePositionBias.dims[1] !== attributes.numHeads || - relativePositionBias.dims[2] !== sequenceLength || - relativePositionBias.dims[3] !== totalSequenceLength + attentionBias.dims[0] !== batchSize || + attentionBias.dims[1] !== attributes.numHeads || + attentionBias.dims[2] !== sequenceLength || + attentionBias.dims[3] !== totalSequenceLength ) { - throw new Error('Input "relative_position_bias" shape (batch_size, 1, sequence_length, kv_sequence_length)'); + throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)'); } } @@ -360,7 +371,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio const value = getInput(context.inputs, 2); const bias = getInput(context.inputs, 3); const keyPaddingMask = getInput(context.inputs, 4); - const relativePositionBias = getInput(context.inputs, 5); + const attentionBias = getInput(context.inputs, 5); const pastKey = getInput(context.inputs, 6); const pastValue = getInput(context.inputs, 7); if (query.dims.length === 5) { @@ -395,7 +406,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio undefined, pastKey, pastValue, - relativePositionBias, + attentionBias, params, attributes, ); @@ -425,17 +436,5 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio 2 * params.hiddenSize, ); - applyAttention( - context, - Q, - K, - V, - keyPaddingMask, - undefined, - pastKey, - pastValue, - relativePositionBias, - params, - attributes, - ); + applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes); }; diff --git a/js/web/test/data/ops/multihead-attention.jsonc b/js/web/test/data/ops/multihead-attention.jsonc index 6ce6a5e0a8ce6..ed937a22c0b84 100644 --- a/js/web/test/data/ops/multihead-attention.jsonc +++ b/js/web/test/data/ops/multihead-attention.jsonc @@ -228,7 +228,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": null, "type": "float32" @@ -293,7 +293,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": null, "type": "float32" @@ -322,7 +322,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=1 with optional RelativePositionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs", + "name": "MultiHeadAttention Basic, one head and head-size=1 with optional AttentionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -358,7 +358,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": null, "type": "float32" @@ -397,7 +397,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs", + "name": "MultiHeadAttention Basic, one head and head-size=4 with attentionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -433,7 +433,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": null, "type": 
"float32" @@ -474,7 +474,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=1 with relativePositionBias, pastKey and pastValue", + "name": "MultiHeadAttention Basic, one head and head-size=1 with attentionBias, pastKey and pastValue", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -510,7 +510,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": [10, 20], "dims": [1, 1, 1, 2], @@ -540,7 +540,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=4 with relativePositionBias, and pastValue", + "name": "MultiHeadAttention Basic, one head and head-size=4 with attentionBias, and pastValue", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -576,7 +576,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": [100, 200], "dims": [1, 1, 1, 2], @@ -642,7 +642,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": null, "type": "float32" @@ -717,7 +717,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": null, "type": "float32" @@ -767,7 +767,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size one with RelativePositionBias, pastKey, pastValue, presentKey and presentValue", + "name": "MultiHeadAttention Basic, one head and head-size one with attentionBias, pastKey, pastValue, presentKey and presentValue", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -803,7 +803,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": [10, 20], "dims": [1, 1, 1, 2], @@ -843,7 +843,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue inputs and PresentKey and PresentValue outputs", + "name": "MultiHeadAttention Basic, one head and head-size=4 with attentionBias, PastKey, PastValue inputs and PresentKey and PresentValue outputs", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -879,7 +879,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": [100, 200], "dims": [1, 1, 1, 2], @@ -957,7 +957,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": [10, 20], "dims": [1, 1, 1, 2], @@ -1033,7 +1033,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": [50, 100], "dims": [1, 1, 1, 2], diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 768676259aa14..ad14fb8258656 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -198,7 +198,7 @@ Status Attention::Compute(OpKernelContext* context) const { const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); - const Tensor* relative_position_bias = context->Input(5); + const Tensor* attention_bias = context->Input(5); const TensorShape& weights_shape = (weights ? 
weights->Shape() : weight_shape_); @@ -208,7 +208,7 @@ Status Attention::Compute(OpKernelContext* context) const { bias->Shape(), mask_index, past, - relative_position_bias, + attention_bias, ¶meters)); if (parameters.do_rotary) { @@ -338,7 +338,7 @@ Status Attention::Compute(OpKernelContext* context) const { output, nullptr /* present_key */, nullptr /* present_value */, batch_size, sequence_length, sequence_length, parameters.head_size, parameters.v_head_size, parameters.v_hidden_size, - relative_position_bias, context); + attention_bias, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index f7d8fedc734e4..52dcb990ab67f 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" #include "core/providers/common.h" namespace onnxruntime { @@ -12,7 +13,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const Tensor* past_seq_len) const { // Abbreviation and Meanings: @@ -37,7 +38,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // bias (Q/K/V) : (D + D + D_v) // mask_index : see below // past (K/V) : (2, B, N, P, H) or NULL - // relative_position_bias : (B, N, S, T) or NULL + // attention_bias : (B or 1, N or 1, S, T) or NULL // For mask_index, the following shapes are supported: // NULL, (B, 1), (1, 1) @@ -49,9 +50,9 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger // than hidden dimension of Q, K and V. 
- if (past != nullptr && relative_position_bias != nullptr) { - // past is used on GPT-2 model with past state, we don't have a case for relative position bias yet - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and relative_position_bias"); + if (past != nullptr && attention_bias != nullptr) { + // past is used on GPT-2 model with past state, we don't have a case for attention bias yet + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and attention_bias"); } const auto& dims = input_shape.GetDims(); @@ -191,39 +192,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, } } - bool broadcast_res_pos_bias = false; - if (relative_position_bias != nullptr) { - const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); + gsl::span attention_bias_dims; + if (attention_bias != nullptr) { + attention_bias_dims = attention_bias->Shape().GetDims(); - if (relative_position_bias_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' is expected to have 4 dimensions, got ", - relative_position_bias_dims.size()); - } - - if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 0 should be same as batch_size or 1, got ", - relative_position_bias_dims[0]); - } - if (relative_position_bias_dims[0] == 1) { - broadcast_res_pos_bias = true; - } - if (relative_position_bias_dims[1] != num_heads_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", - relative_position_bias_dims[1]); - } - if (relative_position_bias_dims[2] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", - relative_position_bias_dims[2]); - } - if (relative_position_bias_dims[3] != total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", - relative_position_bias_dims[3]); - } + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckAttentionBias( + attention_bias_dims, batch_size, num_heads_, sequence_length, total_sequence_length)); } if (past != nullptr && past_present_share_buffer_) { @@ -257,7 +231,8 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, output_parameters->mask_filter_value = mask_filter_value_; output_parameters->scale = scale_; output_parameters->mask_type = mask_type; - output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias; + output_parameters->broadcast_attn_bias_dim_0 = attention_bias_dims.size() > 0 && attention_bias_dims[0] == 1; + output_parameters->broadcast_attn_bias_dim_1 = attention_bias_dims.size() > 1 && attention_bias_dims[1] == 1; output_parameters->qkv_format = Q_K_V_BNSH; } @@ -329,7 +304,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) const { @@ -337,7 +312,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no 
larger than ", max_threads_per_block); } - return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, relative_position_bias, parameters, past_seq_len); + return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, attention_bias, parameters, past_seq_len); } Tensor* AttentionBase::GetPresent(OpKernelContext* context, diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index a6782daa58f1a..05756cd54d842 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -18,7 +18,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr. const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const int max_threads_per_block, // for CUDA const Tensor* past_seq_len = nullptr) const; @@ -63,7 +63,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr. const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const Tensor* past_seq_len = nullptr) const; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 88127387d08ea..5a5899166f5ba 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once +#include namespace onnxruntime { namespace contrib { @@ -68,7 +69,8 @@ struct AttentionParameters { bool is_unidirectional; bool past_present_share_buffer; bool do_rotary; - bool broadcast_res_pos_bias; + bool broadcast_attn_bias_dim_0; + bool broadcast_attn_bias_dim_1; float mask_filter_value; float scale; bool use_tf32; @@ -88,8 +90,8 @@ struct PackedAttentionParameters { int num_heads; float scale; int token_count; - bool has_relative_position_bias; - bool broadcast_res_pos_bias; + bool broadcast_attn_bias_dim_0; + bool broadcast_attn_bias_dim_1; bool use_tf32; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index dd52001c2ac6b..ae2eaf0204026 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -19,23 +19,23 @@ class AttentionCPUBase : public AttentionBase { : AttentionBase(info, require_same_hidden_size) {} template - Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH - const T* K, // K data with shape BxNxLxH - const T* V, // V value with size BxNxLxH_v - const Tensor* mask_index, // mask index. 
nullptr if no mask or its size is B - const Tensor* past, // past state - const Tensor* past_key, // past K input tensor (if not using past state) - const Tensor* past_value, // past V input tensor (if not using past state) - Tensor* output, // output tensor - Tensor* present_key, // present K output tensor (if separating present KV) - Tensor* present_value, // present V output tensor (if separating present KV) - int batch_size, // batch size (B) - int sequence_length, // sequence length of Q (S) - int kv_sequence_length, // sequence length of K or V (L) - int qk_head_size, // head size of Q or K (H) - int v_head_size, // head size of V (H_v) - int v_hidden_size, // hidden size of V (D_v) - const Tensor* relative_position_bias, // bias addition in QK. Its size is BxNxSxT + Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxNxLxH + const T* V, // V value with size BxNxLxH_v + const Tensor* mask_index, // mask index. nullptr if no mask or its size is B + const Tensor* past, // past state + const Tensor* past_key, // past K input tensor (if not using past state) + const Tensor* past_value, // past V input tensor (if not using past state) + Tensor* output, // output tensor + Tensor* present_key, // present K output tensor (if separating present KV) + Tensor* present_value, // present V output tensor (if separating present KV) + int batch_size, // batch size (B) + int sequence_length, // sequence length of Q (S) + int kv_sequence_length, // sequence length of K or V (L) + int qk_head_size, // head size of Q or K (H) + int v_head_size, // head size of V (H_v) + int v_hidden_size, // hidden size of V (D_v) + const Tensor* attn_bias, // additive bias applied on scaled QK. OpKernelContext* context) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -66,10 +66,14 @@ class AttentionCPUBase : public AttentionBase { gsl::span mask_index_dims = mask_index != nullptr ? mask_index->Shape().GetDims() : gsl::span{}; + DUMP_CPU_TENSOR_INIT(); + DUMP_CPU_TENSOR("Mask", mask_index_data, mask_index_dims); + if (mask_data != nullptr) { + // Convert mask from boolean (0/1) to float (mask_filter_value/0.0f). + // Merge padding mask with causual mask, and broadcast to 3D (BxSxT). PrepareMask(mask_index_data, mask_index_dims, static_cast(mask_data), causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_); - DUMP_CPU_TENSOR_INIT(); DUMP_CPU_TENSOR("Mask3D", static_cast(mask_data), batch_size, sequence_length, total_sequence_length); } @@ -82,10 +86,8 @@ class AttentionCPUBase : public AttentionBase { const T* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; - const T* relative_position_bias_data = nullptr; - if (relative_position_bias != nullptr) { - relative_position_bias_data = relative_position_bias->Data(); - } + const T* attn_bias_data = (attn_bias != nullptr) ? attn_bias->Data() : nullptr; + auto attn_bias_dims = (attn_bias != nullptr) ? attn_bias->Shape().GetDims() : gsl::span{}; // Compute the attention score. size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T); @@ -95,7 +97,7 @@ class AttentionCPUBase : public AttentionBase { static_cast(mask_data), batch_size, sequence_length, kv_sequence_length, past_sequence_length, qk_head_size == 0 ? 
v_head_size : qk_head_size, past_data, past_key_data, - present_data, present_key_data, tp, scale, relative_position_bias_data); + present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_dims); // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) auto out_tmp_data = @@ -115,22 +117,23 @@ class AttentionCPUBase : public AttentionBase { // 1 x mask_data(B, N, S, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxLxH - T* mask_data, // buffer for mask data. - int batch_size, // batch size of self-attention - int sequence_length, // sequence length of self-attention (S) - int kv_sequence_length, // sequence length of cross-attention (L) - int past_sequence_length, // sequence length of past state - int head_size, // head size of self-attention - const T* past, // past state - const T* past_key, // past key only (if not using past state) - T* present, // present state - T* present_key, // present key only (if not using present state) - ThreadPool* tp, // thread pool - float scale, // scale factor - const T* relative_position_bias_data // bias addition matrix with shape BxNxSxT + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + T* mask_data, // buffer for mask data. + int batch_size, // batch size of self-attention + int sequence_length, // sequence length of self-attention (S) + int kv_sequence_length, // sequence length of cross-attention (L) + int past_sequence_length, // sequence length of past state + int head_size, // head size of self-attention + const T* past, // past state + const T* past_key, // past key only (if not using past state) + T* present, // present state + T* present_key, // present key only (if not using present state) + ThreadPool* tp, // thread pool + float scale, // scale factor + const T* attn_bias_data, // attention bias + gsl::span attn_bias_dims // attention bias shape ) const { const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // P x H @@ -138,14 +141,20 @@ class AttentionCPUBase : public AttentionBase { const size_t kv_input_chunk_length = static_cast(kv_sequence_length) * head_size; // L x H const size_t present_chunk_length = past_chunk_length + kv_input_chunk_length; // T x H + DUMP_CPU_TENSOR_INIT(); + DUMP_CPU_TENSOR("Q", Q, batch_size, num_heads_, sequence_length, head_size); + DUMP_CPU_TENSOR("K", K, batch_size, num_heads_, total_sequence_length, head_size); + DUMP_CPU_TENSOR("Attn_Bias", attn_bias_data, attn_bias_dims); + { const int loop_len = batch_size * num_heads_; const float alpha = scale; TensorOpCost unit_cost; - const ptrdiff_t probs_matrix_bytes = SafeInt(sequence_length) * total_sequence_length * sizeof(T); + const ptrdiff_t probs_matrix_size = SafeInt(sequence_length) * total_sequence_length; + const ptrdiff_t probs_matrix_bytes = probs_matrix_size * sizeof(T); unit_cost.compute_cycles = - static_cast(SafeInt(2) * sequence_length * head_size * total_sequence_length); + static_cast(SafeInt(2) * head_size * probs_matrix_size); unit_cost.bytes_loaded = static_cast((sequence_length + total_sequence_length) * head_size * sizeof(T)); unit_cost.bytes_stored = 
static_cast(probs_matrix_bytes); @@ -160,8 +169,8 @@ class AttentionCPUBase : public AttentionBase { unit_cost.bytes_stored += bytes_to_copy_key; } - if (relative_position_bias_data != nullptr) { - unit_cost.compute_cycles += static_cast(sequence_length * total_sequence_length); + if (attn_bias_data != nullptr) { + unit_cost.compute_cycles += static_cast(probs_matrix_size); unit_cost.bytes_loaded += probs_matrix_bytes * 2; unit_cost.bytes_stored += probs_matrix_bytes; } @@ -169,13 +178,34 @@ class AttentionCPUBase : public AttentionBase { ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const int batch_index = static_cast(i) / num_heads_; + const std::ptrdiff_t head_index = i % static_cast(num_heads_); + + const ptrdiff_t output_offset = SafeInt(i) * probs_matrix_size; + const ptrdiff_t mask_offset = SafeInt(batch_index) * probs_matrix_size; - const ptrdiff_t output_offset = SafeInt(i) * sequence_length * total_sequence_length; - const ptrdiff_t mask_offset = SafeInt(batch_index) * sequence_length * total_sequence_length; T* output = attention_probs + output_offset; - // Broadcast mask data: (Bx)SxT -> (BxNx)SxT - if (mask_data != nullptr) { + if (attn_bias_data != nullptr) { + // Attention bias has shape (B or 1, N or 1, S, T) + // Here we handle the broadcast of batch_size and num_heads dimensions. + ptrdiff_t attn_bias_offset = 0; + if (attn_bias_dims[0] != 1) { + attn_bias_offset += SafeInt(batch_index) * num_heads_ * probs_matrix_size; + } + if (attn_bias_dims[1] != 1) { + attn_bias_offset += head_index * probs_matrix_size; + } + + memcpy(output, attn_bias_data + attn_bias_offset, probs_matrix_bytes); + + if (mask_data != nullptr) { + // This can be optimized with vectorized add using MlasAddFloat32x4. + for (ptrdiff_t j = 0; j < probs_matrix_size; j++) { + output[j] += mask_data[mask_offset + j]; + } + } + } else if (mask_data != nullptr) { + // Broadcast mask data: (Bx)SxT -> (BxNx)SxT memcpy(output, mask_data + mask_offset, probs_matrix_bytes); } @@ -193,20 +223,13 @@ class AttentionCPUBase : public AttentionBase { // B: K' (B x N x) T x H (B x N x) H x T H x T // C: attention_probs (B x N x) S x T (B x N x) S x T S x T math::Gemm(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha, - Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f, output, - nullptr); - - if (relative_position_bias_data != nullptr) { - for (int j = 0; j < sequence_length * total_sequence_length; j++) { - output[j] += relative_position_bias_data[output_offset + j]; - } - } + Q + q_input_chunk_length * i, k, + (mask_data != nullptr || attn_bias_data != nullptr) ? 
1.0f : 0.0f, + output, nullptr); } }); } - DUMP_CPU_TENSOR_INIT(); - DUMP_CPU_TENSOR("Q", Q, batch_size, num_heads_, sequence_length, head_size); DUMP_CPU_TENSOR("QK (scaled)", attention_probs, batch_size, num_heads_, sequence_length, total_sequence_length); // attention_probs(B, N, S, T) = Softmax(attention_probs) diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 0d77376779230..ca818f09c4b1e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -57,7 +57,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { const Tensor* value = context->Input(2); const Tensor* bias = context->Input(3); const Tensor* key_padding_mask = context->Input(4); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* attn_bias = context->Input(5); const Tensor* past_key = context->Input(6); const Tensor* past_value = context->Input(7); @@ -75,7 +75,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { value, bias, key_padding_mask, - extra_add_qk, + attn_bias, past_key, past_value, nullptr, @@ -135,7 +135,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { value->Data(), key_padding_mask, nullptr /* past */, past_key, past_value, output, present_k, present_v, batch_size, q_sequence_length, kv_sequence_length, - qk_head_size, v_head_size, v_hidden_size, extra_add_qk, context); + qk_head_size, v_head_size, v_hidden_size, attn_bias, context); } OrtValue K; @@ -149,7 +149,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { !disable_flash_ && !is_unidirectional_ && key_padding_mask == nullptr && - extra_add_qk == nullptr && + attn_bias == nullptr && past_key == nullptr && past_value == nullptr && present_k == nullptr && @@ -215,7 +215,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { V.GetMutable()->MutableData(), key_padding_mask, nullptr /* past */, past_key, past_value, output, present_k, present_v, batch_size, q_sequence_length, kv_sequence_length, - qk_head_size, v_head_size, v_hidden_size, extra_add_qk, context); + qk_head_size, v_head_size, v_hidden_size, attn_bias, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index cfb8d36843777..0cfe90963c334 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -179,39 +179,35 @@ Status CheckPast(const T* past_key, const T* past_value, const T* past_seq_len, return Status::OK(); } -template -Status CheckRelativePositionBias( - const T* relative_position_bias, int batch_size, int num_heads, int sequence_length, int total_sequence_length, - bool& broadcast_res_pos_bias) { - const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - - if (relative_position_bias_dims.size() != 4) { +inline Status CheckAttentionBias( + const gsl::span& attention_bias_dims, + int64_t batch_size, int64_t num_heads, int64_t sequence_length, int64_t total_sequence_length) { + if (attention_bias_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' is expected to have 4 dimensions, got ", - relative_position_bias_dims.size()); + "Input 'attention_bias' is expected to have 4 dimensions, got ", + 
attention_bias_dims.size()); } - if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { + + if (attention_bias_dims[0] != batch_size && attention_bias_dims[0] != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", - relative_position_bias_dims[0]); + "Input 'attention_bias' dimension 0 should be batch_size or 1, got ", + attention_bias_dims[0]); } - if (relative_position_bias_dims[0] == 1) { - broadcast_res_pos_bias = true; - } - if (relative_position_bias_dims[1] != num_heads) { + + if (attention_bias_dims[1] != num_heads && attention_bias_dims[1] != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", - relative_position_bias_dims[1]); + "Input 'attention_bias' dimension 1 should be same as number of heads or 1, got ", + attention_bias_dims[1]); } - if (relative_position_bias_dims[2] != sequence_length) { + if (attention_bias_dims[2] != sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", - relative_position_bias_dims[2]); + "Input 'attention_bias' dimension 2 should be same as sequence_length, got ", + attention_bias_dims[2]); } - if (relative_position_bias_dims[3] != total_sequence_length) { + if (attention_bias_dims[3] != total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", - relative_position_bias_dims[3]); + "Input 'attention_bias' dimension 3 should be same as total_sequence_length, got ", + attention_bias_dims[3]); } return Status::OK(); } @@ -243,7 +239,7 @@ Status CheckInputs(const T* query, const T* value, const T* bias, const T* key_padding_mask, - const T* relative_position_bias, + const T* attention_bias, const T* past_key, const T* past_value, const T* past_seq_len, @@ -258,13 +254,15 @@ Status CheckInputs(const T* query, // Notations: // B: batch_size // N: num_heads - // H: head_size (V might have different head size than Q and K) - // D: hidden_size = N * H + // H: head_size of Q and K. + // H_v: head_size of V. 
+ // D: hidden_size of Q and K, where D = N * H + // D_v: hidden_size of V, where D_v = N * H_v // S: q_sequence_length - // P: past_sequence_length + // P: past_sequence_length of kv cache // L: kv_sequence_length // T: total_sequence_length = P + L - // M: max_sequence_length + // M: max_sequence_length of kv cache when past and present share buffer // --------------------------------------------------------------- // MultiHeadAttention inputs: // --------------------------------------------------------------- @@ -275,7 +273,7 @@ Status CheckInputs(const T* query, // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache is not used, L == T, D == D_v): // query (Q) : (B, S, D) // key (K) : (B, N, L, H) - // value (V) : (B, N, L, H) + // value (V) : (B, N, L, H_v) // Q_KV_BSNH_BSN2H - packed kv (kv cache is not used, bias is not allowed for packed kv): // query (Q) : (B, S, D) // key (K/V) : (B, L, N, 2, H) @@ -288,7 +286,7 @@ Status CheckInputs(const T* query, // Other inputs: // bias (Q/K/V) : None or (D + D + D_v) // key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, T) or (B, S, T) - // relative_position_bias : (B, N, S, T) or (1, N, S, T) + // attention_bias : (B, N, S, T), (1, N, S, T), (B, 1, S, T) or (1, 1, S, T) // past_key : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. // past_value : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. // --------------------------------------------------------------- @@ -298,7 +296,7 @@ Status CheckInputs(const T* query, // query (Q) : (B, S, D) // key (K) : (B, L, D) // value (V) : (B, L, D) - // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache and relative_position_bias are not used. L == T): + // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache and attention_bias are not used. L == T): // query (Q) : (B, S, D) // key (K) : (B, N, L, H) // value (V) : (B, N, L, H) @@ -310,7 +308,7 @@ Status CheckInputs(const T* query, // Other inputs: // bias (Q/K/V) : None or (3 * D) // key_padding_mask (K/V) : None or (B, T) - // relative_position_bias : (1, N, S, T), or (B, N, S, T) where only 1 x N x S x T data is used in CUDA. + // attention_bias : (1, N, S, T), or (B, N, S, T) where only 1 x N x S x T data is used in CUDA. // // The following inputs are not used in cross attention (so they are None for cross attention): // past_key : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True. 
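
For reference, the contract enforced by the new CheckAttentionBias helper and by the broadcast_attn_bias_dim_0 / broadcast_attn_bias_dim_1 flags set in the surrounding hunks is: attention_bias has shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length), and a leading dimension of size 1 is broadcast when the bias is added to the scaled QK matrix. A minimal standalone C++ sketch of that shape check and of the broadcast-aware slice offset is shown below; the struct and function names are illustrative only, not the actual ONNX Runtime helpers.

```cpp
// Minimal sketch (illustrative, not ORT code): validate an attention_bias of shape
// (B or 1, N or 1, S, T) and compute the element offset of the S x T slice used for a
// given (batch, head), broadcasting dimension 0 and/or 1 when its size is 1.
#include <cstdint>
#include <stdexcept>

struct BiasShape {
  int64_t dim0;  // batch_size or 1
  int64_t dim1;  // num_heads or 1
  int64_t s;     // sequence_length
  int64_t t;     // total_sequence_length
};

inline void CheckBiasShape(const BiasShape& b, int64_t batch_size, int64_t num_heads,
                           int64_t sequence_length, int64_t total_sequence_length) {
  if (b.dim0 != batch_size && b.dim0 != 1)
    throw std::invalid_argument("attention_bias dim 0 must be batch_size or 1");
  if (b.dim1 != num_heads && b.dim1 != 1)
    throw std::invalid_argument("attention_bias dim 1 must be num_heads or 1");
  if (b.s != sequence_length || b.t != total_sequence_length)
    throw std::invalid_argument("attention_bias dims 2 and 3 must be (S, T)");
}

// Offset (in elements) of the S x T bias matrix for a given batch and head.
inline int64_t BiasSliceOffset(const BiasShape& b, int64_t batch, int64_t head) {
  const int64_t matrix_size = b.s * b.t;
  int64_t offset = 0;
  if (b.dim0 != 1) offset += batch * b.dim1 * matrix_size;  // no broadcast over batch
  if (b.dim1 != 1) offset += head * matrix_size;            // no broadcast over heads
  return offset;
}
```

For example, a bias of shape (1, N, S, T) yields BiasSliceOffset == head * S * T for every batch, which corresponds to broadcast_attn_bias_dim_0 being set in the attention parameters.
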
@@ -401,10 +399,11 @@ Status CheckInputs(const T* query, } } - bool broadcast_res_pos_bias = false; - if (relative_position_bias != nullptr) { - ORT_RETURN_IF_ERROR(CheckRelativePositionBias( - relative_position_bias, batch_size, num_heads, sequence_length, total_sequence_length, broadcast_res_pos_bias)); + gsl::span attention_bias_dims; + if (attention_bias != nullptr) { + attention_bias_dims = attention_bias->Shape().GetDims(); + ORT_RETURN_IF_ERROR(CheckAttentionBias( + attention_bias_dims, batch_size, num_heads, sequence_length, total_sequence_length)); } assert(qkv_format != UNKNOWN); @@ -428,7 +427,8 @@ Status CheckInputs(const T* query, output_parameters->mask_filter_value = mask_filter_value; output_parameters->mask_type = mask_type; output_parameters->scale = scale; - output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias; + output_parameters->broadcast_attn_bias_dim_0 = attention_bias_dims.size() > 0 && attention_bias_dims[0] == 1; + output_parameters->broadcast_attn_bias_dim_1 = attention_bias_dims.size() > 1 && attention_bias_dims[1] == 1; output_parameters->qkv_format = qkv_format; } @@ -441,7 +441,7 @@ Status CheckInputs(const T* query, const T* value, const T* bias, const T* key_padding_mask, - const T* relative_position_bias, + const T* attention_bias, const T* past_key, const T* past_value, const T* past_seq_len, @@ -457,7 +457,7 @@ Status CheckInputs(const T* query, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, + return CheckInputs(query, key, value, bias, key_padding_mask, attention_bias, past_key, past_value, past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional, past_present_share_buffer, operator_type); } diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 6201b892a89b0..2c897f183164f 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -160,7 +160,7 @@ Status QAttention::Compute(OpKernelContext* context) const { bias->Shape(), mask_index, past_tensor, - nullptr, // relative_position_bias + nullptr, // attention_bias nullptr // parameters )); diff --git a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h index 2782a59d4326d..12cbc5049a02a 100644 --- a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h +++ b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h @@ -32,6 +32,11 @@ class IConsoleDumper { virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; + virtual void Print(const char* name, const int32_t* tensor, gsl::span& dims) const = 0; + virtual void Print(const char* name, const int64_t* tensor, gsl::span& dims) const = 0; + virtual void Print(const char* name, const float* tensor, gsl::span& dims) const = 0; + virtual void Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const = 0; + virtual void Print(const char* name, const Tensor& value) const = 0; virtual void Print(const char* name, const OrtValue& value) const = 0; virtual void Print(const char* name, int index, bool end_line) const = 0; @@ -43,5 +48,38 @@ class IConsoleDumper { 
bool is_enabled_; }; +template +void PrintTensorByDims(const TConsoleDumper* dumper, + const char* name, + const T* tensor, + gsl::span& dims) { + if (dumper->IsEnabled() && (tensor == nullptr || dims.size() == 0)) { + std::cout << std::string(name) << " is None" << std::endl; + return; + } + + auto num_dims = dims.size(); + if (num_dims == 1) { + dumper->Print(name, tensor, 1, static_cast(dims[0])); + } else if (num_dims == 2) { + dumper->Print(name, tensor, static_cast(dims[0]), static_cast(dims[1])); + } else if (num_dims == 3) { + dumper->Print(name, tensor, static_cast(dims[0]), static_cast(dims[1]), static_cast(dims[2])); + } else if (num_dims == 4) { + dumper->Print(name, tensor, + static_cast(dims[0]), + static_cast(dims[1]), + static_cast(dims[2]), + static_cast(dims[3])); + } else if (num_dims == 5) { + dumper->Print(name, tensor, + static_cast(dims[0]) * static_cast(dims[1]), + static_cast(dims[2]), + static_cast(dims[3]), + static_cast(dims[4])); + } else { + ORT_ENFORCE(false, "Unsupported tensor dims"); + } +} } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc index 87a9cd3965763..7755f9505d99d 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc @@ -246,7 +246,24 @@ void CpuTensorConsoleDumper::Print(const char* name, const std::string& value, b } } +void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + #else + CpuTensorConsoleDumper::CpuTensorConsoleDumper() { } @@ -303,6 +320,18 @@ void CpuTensorConsoleDumper::Print(const char*, int, bool) const { void CpuTensorConsoleDumper::Print(const char*, const std::string&, bool) const { } + +void CpuTensorConsoleDumper::Print(const char*, const int32_t*, gsl::span&) const { +} + +void CpuTensorConsoleDumper::Print(const char*, const int64_t*, gsl::span&) const { +} + +void CpuTensorConsoleDumper::Print(const char*, const float*, gsl::span&) const { +} + +void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, gsl::span&) const { +} #endif } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h index f102eae6ec709..6fc4dfd4a0671 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h @@ -30,6 +30,11 @@ class CpuTensorConsoleDumper : public IConsoleDumper { void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; + void Print(const char* name, const int32_t* tensor, gsl::span& dims) const override; + void Print(const char* name, const int64_t* tensor, gsl::span& dims) const override; + void Print(const char* name, const float* tensor, gsl::span& dims) const override; + void Print(const char* name, const 
MLFloat16* tensor, gsl::span& dims) const override; + void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; void Print(const char* name, int index, bool end_line) const override; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 5c0989bced70c..1d1416995a673 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -59,7 +59,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(kPastInputIndex); - const Tensor* relative_position_bias = context->Input(5); + const Tensor* attention_bias = context->Input(5); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); auto& device_prop = GetDeviceProp(); @@ -74,7 +74,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bias != nullptr ? bias->Shape() : bias_shape, mask_index, past, - relative_position_bias, + attention_bias, ¶meters, device_prop.maxThreadsPerBlock, past_seq_len)); @@ -104,7 +104,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && - (nullptr == relative_position_bias) && + (nullptr == attention_bias) && nullptr == past && nullptr == present && parameters.hidden_size == parameters.v_hidden_size && @@ -146,7 +146,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // where past state is empty. bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == relative_position_bias && + nullptr == attention_bias && parameters.past_sequence_length == 0 && parameters.hidden_size == parameters.v_hidden_size && FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, @@ -169,7 +169,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { (nullptr == mask_index || is_mask_1d_seq_len) && nullptr == past && nullptr == present && - nullptr == relative_position_bias && + nullptr == attention_bias && parameters.hidden_size == parameters.v_hidden_size && FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, enable_trt_flash_attention_, false); @@ -201,12 +201,9 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == present && (nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && + (nullptr == attention_bias || parameters.sequence_length % (4 * sizeof(T)) == 0) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); - if (use_memory_efficient_attention) { - bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; - use_memory_efficient_attention = (nullptr == relative_position_bias || is_good_for_rpb); - } #else constexpr bool use_memory_efficient_attention = false; #endif @@ -277,8 +274,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (nullptr != past) { data.past = reinterpret_cast(past->Data()); } - if (nullptr != relative_position_bias) { - 
data.relative_position_bias = reinterpret_cast(relative_position_bias->Data()); + if (nullptr != attention_bias) { + data.attention_bias = reinterpret_cast(attention_bias->Data()); } data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index f9eabe27d97e4..28e2b7b28764b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -290,7 +290,7 @@ Status FlashAttention( assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); assert(nullptr == data.mask_index); - assert(nullptr == data.relative_position_bias); + assert(nullptr == data.attention_bias); assert(parameters.head_size == parameters.v_head_size); constexpr bool is_bf16 = false; @@ -332,6 +332,8 @@ Status EfficientAttention( // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + assert(parameters.mask_type == AttentionMaskType::MASK_NONE || + parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -345,22 +347,25 @@ Status EfficientAttention( p.v_head_size = parameters.v_head_size; p.causal = parameters.is_unidirectional; p.scale = scale; - p.seqlen_k_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast(data.mask_index)); - p.seqstart_q_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast( - data.mask_index + parameters.batch_size)); - p.seqstart_k_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast( - data.mask_index + 2 * parameters.batch_size + 1)); + + if (nullptr == data.mask_index) { + p.seqlen_k_ptr = nullptr; + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; + } else { + p.seqlen_k_ptr = const_cast(reinterpret_cast(data.mask_index)); + p.seqstart_q_ptr = p.seqlen_k_ptr + parameters.batch_size; + p.seqstart_k_ptr = p.seqlen_k_ptr + 2 * parameters.batch_size + 1; + } + p.query = data.q; p.key = data.k; p.value = data.v; - p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; - p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; + + p.attn_bias = (nullptr == data.attention_bias) ? 
nullptr : data.attention_bias; + p.broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0; + p.broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1; + p.output = data.output; p.is_kv_bsnh = data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) @@ -415,6 +420,12 @@ Status UnfusedAttention( const int present_size_per_batch_k = present_sequence_length * qk_head_size; const int present_size_per_batch_v = present_sequence_length * v_head_size; + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k", data.k, batch_size, num_heads, total_sequence_length, qk_head_size); + DUMP_TENSOR_D("v", data.v, batch_size, num_heads, total_sequence_length, v_head_size); + DUMP_TENSOR_D("mask_index", mask_index, mask_index_dims); + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, total_sequence_length, sequence_length, qk_head_size, @@ -423,7 +434,6 @@ Status UnfusedAttention( &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop, parameters.use_tf32)); - DUMP_TENSOR_INIT(); DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length); constexpr size_t element_size = sizeof(T); @@ -431,6 +441,9 @@ Status UnfusedAttention( sequence_length, total_sequence_length); T* scratch2 = data.scratch + (bytes / element_size); + const bool broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0; + const bool broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1; + // Apply softmax and store result R to scratch2: BxNxSxT if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask const int mask_dimension = static_cast(mask_index_dims.size()); @@ -444,7 +457,7 @@ Status UnfusedAttention( ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask( ort_stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias, + mask_index, nullptr, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension, parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, parameters.mask_filter_value)); @@ -454,17 +467,17 @@ Status UnfusedAttention( const int* mask_start = (mask_index_dims[0] > batch_size) ? 
mask_index + batch_size : nullptr; ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D( stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, mask_start, data.relative_position_bias, parameters.broadcast_res_pos_bias, + mask_index, mask_start, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, data.scratch, scratch2, parameters.is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR( ComputeSoftmax( - stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, - parameters.broadcast_res_pos_bias, data.scratch, scratch2, parameters.is_unidirectional)); + stream, total_sequence_length, sequence_length, batch_size, num_heads, + data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + data.scratch, scratch2, parameters.is_unidirectional)); } DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length); - DUMP_TENSOR_D("V", data.v, batch_size, num_heads, sequence_length, v_head_size); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v T* temp_output = data.q; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index fad353dcfeb07..a6760f84e69f3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -69,7 +69,7 @@ struct AttentionData { const T* past = nullptr; const T* past_key = nullptr; const T* past_value = nullptr; - const T* relative_position_bias = nullptr; + const T* attention_bias = nullptr; bool has_qkv_workspace = false; T* workspace = nullptr; @@ -115,7 +115,7 @@ struct AttentionData { << ", fused_runner=" << (fused_runner != nullptr) << ", fused_cross=" << (fused_cross_attention_kernel != nullptr) << ", bias=" << (bias != nullptr) - << ", attn_bias=" << (relative_position_bias != nullptr) + << ", attn_bias=" << (attention_bias != nullptr) << ", mask_dims=" << mask_index_dims.size() << ", has_qkv_workspace=" << has_qkv_workspace << ", workspace=" << workspace_bytes diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 05c592ec61059..575e65ebef0e9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -77,18 +77,22 @@ void DumpInputs(contrib::AttentionParameters& parameters, AttentionData& data DUMP_TENSOR_D("V_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); } - if (data.relative_position_bias != nullptr) { - DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, - parameters.broadcast_res_pos_bias ? 1 : batch_size, - num_heads, sequence_length, kv_sequence_length); + if (data.attention_bias != nullptr) { + DUMP_TENSOR_D("attention_bias", data.attention_bias, + parameters.broadcast_attn_bias_dim_0 ? 1 : batch_size, + parameters.broadcast_attn_bias_dim_1 ? 
1 : num_heads, + sequence_length, + parameters.total_sequence_length); } if (data.mask_index != nullptr) { if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { - DUMP_TENSOR_D("mask", data.mask_index, batch_size, parameters.total_sequence_length); + DUMP_TENSOR_D("mask (2D)", data.mask_index, batch_size, parameters.total_sequence_length); } if (parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { - DUMP_TENSOR_D("mask", data.mask_index, 3 * batch_size + 2, 1); + DUMP_TENSOR_D("mask (seqlen_k)", data.mask_index, 1, batch_size); + DUMP_TENSOR_D("mask (cu_seqlen_q)", data.mask_index + batch_size, 1, batch_size + 1); + DUMP_TENSOR_D("mask (cu_seqlen_k)", data.mask_index + 2 * batch_size + 1, 1, batch_size + 1); } } } @@ -258,7 +262,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, if (data.fused_cross_attention_kernel != nullptr) { assert(qk_head_size == v_head_size); - assert(data.relative_position_bias == nullptr); + assert(data.attention_bias == nullptr); assert(data.mask_index == nullptr); assert(parameters.hidden_size == parameters.v_hidden_size); @@ -290,7 +294,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, #endif else if (data.fused_runner != nullptr) { assert(qk_head_size == v_head_size); - assert(data.relative_position_bias == nullptr); + assert(data.attention_bias == nullptr); // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H) LaunchAddBiasTransposeTrt( @@ -524,7 +528,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, true, v_head_size, qkv_add_bias, 3); data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; } else if (nullptr != data.fused_runner) { - assert(nullptr == data.relative_position_bias); + assert(nullptr == data.attention_bias); if (data.bias == nullptr) { // When there is no bias, we can directly use the original packed QKV input. // Need revisit this when we add support for causal. diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index 01ea02f48d3ab..52f94247a8b2b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -29,12 +29,45 @@ namespace onnxruntime { namespace contrib { namespace attention_softmax_cuda { -template -__device__ inline void Softmax(const int all_sequence_length, +#define DISPATCH_BIAS(attn_bias, HAS_BIAS, ...) \ + [&] { \ + const dim3 grid(num_heads* sequence_length, batch_size, 1); \ + if (attn_bias != nullptr) { \ + constexpr static bool HAS_BIAS = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool HAS_BIAS = false; \ + return __VA_ARGS__(); \ + } \ + }() + +// Macro to declare variables: +// offset: offset in input/output +// bias_offset: offset in attn_bias +// b: batch index +// s: sequence index +// grid size is (num_heads * sequence_length, batch_size, 1) +// input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) +// bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) +#define DECLARE_SOFTMAX_VARS() \ + [[maybe_unused]] const int s = blockIdx.x % sequence_length; \ + const int b = blockIdx.y; \ + int64_t offset = static_cast(b * gridDim.x + blockIdx.x) * static_cast(total_sequence_length); \ + [[maybe_unused]] int64_t bias_offset = 0; \ + if constexpr (HAS_BIAS) { \ + const int j = (broadcast_attn_bias_dim_0 ? 0 : (b * gridDim.x)) + (broadcast_attn_bias_dim_1 ? 
s : blockIdx.x); \ + bias_offset = static_cast(j) * static_cast(total_sequence_length); \ + } + +// This kernel is for non causal, attention mask 1D or None, and total_sequence_length > 1024. +template +__device__ inline void Softmax(const int total_sequence_length, + const int sequence_length, const int valid_end, const int valid_start, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output) { using BlockReduce = cub::BlockReduce; @@ -45,28 +78,22 @@ __device__ inline void Softmax(const int all_sequence_length, float thread_data_max(-CUDART_INF_F); - const bool no_rpb = (rel_pos_bias == nullptr); + DECLARE_SOFTMAX_VARS(); // e^x is represented as infinity if x is large enough, like 100.f. // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. // a math transform as below is leveraged to get a stable softmax: // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - const int size_per_batch = gridDim.x * all_sequence_length; for (int i = threadIdx.x; i < valid_end; i += TPB) { if (i >= valid_start) { - const int index = offset + i; - float input_at_idx = no_rpb - ? float(input[index]) - : float(input[index] + (broadcast_rel_pos_bias - ? rel_pos_bias[index % size_per_batch] - : rel_pos_bias[index])); - if (thread_data_max < input_at_idx) { - thread_data_max = input_at_idx; + float input_data = HAS_BIAS + ? float(input[offset + i]) + float(attn_bias[bias_offset + i]) + : float(input[offset + i]); + if (thread_data_max < input_data) { + thread_data_max = input_data; } } } - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max()); // Store max value @@ -78,9 +105,11 @@ __device__ inline void Softmax(const int all_sequence_length, float thread_data_sum(0.f); for (int i = threadIdx.x; i < valid_end; i += TPB) { if (i >= valid_start) { - const int index = offset + i; - float val = no_rpb ? input[index] : input[index] + rel_pos_bias[index % size_per_batch]; - thread_data_sum += expf(val - max_block); + float input_data = HAS_BIAS + ? float(input[offset + i]) + float(attn_bias[bias_offset + i]) + : float(input[offset + i]); + + thread_data_sum += expf(input_data - max_block); } } @@ -90,21 +119,25 @@ __device__ inline void Softmax(const int all_sequence_length, } __syncthreads(); - for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { const int index = offset + i; - float input_at_idx = no_rpb ? float(input[index]) : float(input[index] + rel_pos_bias[index % size_per_batch]); - const float val = (i >= valid_start && i < valid_end) ? expf(input_at_idx - max_block) * sum_reverse_block : 0.f; + float input_data = HAS_BIAS + ? float(input[index]) + float(attn_bias[bias_offset + i]) + : float(input[index]); + const float val = (i >= valid_start && i < valid_end) ? expf(input_data - max_block) * sum_reverse_block : 0.f; output[index] = T(val); } } -template -__device__ inline void SoftmaxSmall(const int all_sequence_length, +// This kernel is for non causal, attention mask 1D or None, and total_sequence_length <= 1024. 
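Note: `CheckInputs` (above) sets `broadcast_attn_bias_dim_0` / `broadcast_attn_bias_dim_1` when the corresponding dimension of the attention bias equals 1, and `DECLARE_SOFTMAX_VARS` turns those flags into a row offset into a bias of shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length). As a rough host-side sketch of the same broadcasting rule (helper name and signature are ours, not part of this patch), a broadcast dimension simply contributes index 0:

```cpp
#include <cstdint>

// Illustrative only: map score element (b, n, s, t) of shape (B, N, S, T) to the
// matching element of an attention bias shaped (B or 1, N or 1, S, T).
inline int64_t BiasElementIndex(int b, int n, int s, int t,
                                int num_heads, int sequence_length, int total_sequence_length,
                                bool broadcast_dim_0, bool broadcast_dim_1) {
  const int bias_num_heads = broadcast_dim_1 ? 1 : num_heads;  // head count stored in the bias
  const int bb = broadcast_dim_0 ? 0 : b;                      // batch index into the bias
  const int nn = broadcast_dim_1 ? 0 : n;                      // head index into the bias
  return ((static_cast<int64_t>(bb) * bias_num_heads + nn) * sequence_length + s) *
             total_sequence_length +
         t;
}
```

For example, with batch_size = 2 and a bias of shape (1, num_heads, S, T), `broadcast_dim_0` is true and both batches read the same bias rows.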
+template +__device__ inline void SoftmaxSmall(const int total_sequence_length, const int sequence_length, const int valid_end, const int valid_start, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, bool causal) { @@ -114,34 +147,30 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length, __shared__ float sum_reverse_block; __shared__ float max_block; - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; + DECLARE_SOFTMAX_VARS(); + const int index = offset + threadIdx.x; // Update end position for causal. int end = valid_end; if (causal) { - const int end_causal = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; + const int end_causal = total_sequence_length - sequence_length + s + 1; if (end_causal < end) { end = end_causal; } } const bool is_valid = (threadIdx.x >= valid_start && threadIdx.x < end); + float input_data = is_valid ? (HAS_BIAS + ? float(input[index]) + float(attn_bias[bias_offset + threadIdx.x]) + : float(input[index])) + : float(-CUDART_INF_F); // e^x is represented as infinity if x is large enough, like 100.f. // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. // a math transform as below is leveraged to get a stable softmax: // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - const bool no_rpb = (rel_pos_bias == nullptr); - const int size_per_batch = gridDim.x * all_sequence_length; - float input_data = no_rpb - ? float(input[index]) - : float(input[index] + (broadcast_rel_pos_bias - ? rel_pos_bias[index % size_per_batch] - : rel_pos_bias[index])); - float thread_data_max = is_valid ? input_data : float(-CUDART_INF_F); - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end); + const auto max = BlockReduce(tmp_storage).Reduce(input_data, cub::Max(), end); // Store max value if (threadIdx.x == 0) { @@ -162,23 +191,25 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length, } __syncthreads(); - // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. - if (threadIdx.x < all_sequence_length) { + // threadIdx.x might be larger than total_sequence_length due to alignment to 32x. + if (threadIdx.x < total_sequence_length) { output[index] = is_valid ? T(thread_data_exp * sum_reverse_block) : T(0.f); } } -template -__global__ void SoftmaxLargeKernel(const int all_sequence_length, +// This kernel is for causal or not, attention mask 1D or None, and total_sequence_length <= 1024. 
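The comment repeated in these kernels describes the usual stability trick: softmax(x)_i = exp(x_i - max) / sum_j exp(x_j - max), so large logits no longer produce inf/inf = NaN. A plain C++ reference for one row, with an optional additive attention bias and zeros outside the valid window, might look like the sketch below (an illustration of the math only, not the CUDA kernels; it assumes a non-empty window, since the empty-window fallback is handled by the masked kernels further down):

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Reference softmax over one row of length total_sequence_length.
// Elements outside [valid_start, valid_end) get probability 0, as in the kernels.
std::vector<float> SoftmaxRow(const std::vector<float>& scores,
                              const float* bias,  // nullptr when there is no attention bias
                              int valid_start, int valid_end) {
  const int n = static_cast<int>(scores.size());
  std::vector<float> out(n, 0.0f);
  if (valid_end <= valid_start) return out;  // the real kernels widen the window instead

  std::vector<float> x(n, -std::numeric_limits<float>::infinity());
  for (int i = valid_start; i < valid_end; ++i) {
    x[i] = scores[i] + (bias != nullptr ? bias[i] : 0.0f);
  }
  const float max_val = *std::max_element(x.begin() + valid_start, x.begin() + valid_end);
  float sum = 0.0f;
  for (int i = valid_start; i < valid_end; ++i) sum += std::exp(x[i] - max_val);
  for (int i = valid_start; i < valid_end; ++i) out[i] = std::exp(x[i] - max_val) / sum;
  return out;
}
```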
+template +__global__ void SoftmaxLargeKernel(const int total_sequence_length, const int sequence_length, const int valid_end, const int valid_start, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, bool causal) { - extern __shared__ float cached_data[]; // float[all_sequence_length] + extern __shared__ float cached_data[]; // float[total_sequence_length] using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp_storage; @@ -186,36 +217,26 @@ __global__ void SoftmaxLargeKernel(const int all_sequence_length, __shared__ float sum_reverse_block; __shared__ float max_block; + DECLARE_SOFTMAX_VARS(); + // Update end position for causal. int end = valid_end; if (causal) { - int end_causal = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; + int end_causal = total_sequence_length - sequence_length + s + 1; if (end_causal < end) { end = end_causal; } } - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - const int size_per_batch = gridDim.x * all_sequence_length; - float thread_data_max = -CUDART_INF_F; - for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) { - const int index = offset + seq_idx; - const bool is_valid = (seq_idx >= valid_start && seq_idx < end); - - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - float input_data = is_valid - ? (rel_pos_bias - ? float(input[index] + (broadcast_rel_pos_bias - ? rel_pos_bias[index % size_per_batch] - : rel_pos_bias[index])) - : float(input[index])) - : float(-CUDART_INF_F); - cached_data[seq_idx] = input_data; + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { + const int index = offset + i; + const bool is_valid = (i >= valid_start && i < end); + float input_data = is_valid ? (HAS_BIAS + ? float(input[index]) + float(attn_bias[bias_offset + i]) + : float(input[index])) + : float(-CUDART_INF_F); + cached_data[i] = input_data; thread_data_max = max(thread_data_max, input_data); } const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end); @@ -227,10 +248,10 @@ __global__ void SoftmaxLargeKernel(const int all_sequence_length, __syncthreads(); float thread_data_exp(0.f); - for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) { - const bool is_valid = (seq_idx >= valid_start && seq_idx < end); - cached_data[seq_idx] = is_valid ? expf(cached_data[seq_idx] - max_block) : 0.0f; - thread_data_exp += cached_data[seq_idx]; + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { + const bool is_valid = (i >= valid_start && i < end); + cached_data[i] = is_valid ? expf(cached_data[i] - max_block) : 0.0f; + thread_data_exp += cached_data[i]; } const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end); @@ -240,20 +261,22 @@ __global__ void SoftmaxLargeKernel(const int all_sequence_length, } __syncthreads(); - // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. 
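Both the small and the large kernels clamp the valid window for causal attention with `end_causal = total_sequence_length - sequence_length + s + 1`: query row s of the newly appended chunk may attend to every past key plus itself. A one-line sketch of that arithmetic (helper name is ours):

```cpp
// Exclusive causal end for query row s, with (total_sequence_length - sequence_length) past tokens.
inline int CausalEndExclusive(int total_sequence_length, int sequence_length, int s) {
  return total_sequence_length - sequence_length + s + 1;
}

// Example: 6 past tokens, sequence_length = 2, total_sequence_length = 8.
// Row s = 0 may attend to keys [0, 6], so the end is 7; row s = 1 may attend to [0, 7], end 8.
```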
- for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) { - const bool is_valid = (seq_idx >= valid_start && seq_idx < end); - output[offset + seq_idx] = is_valid ? T(cached_data[seq_idx] * sum_reverse_block) : T(0.f); + // threadIdx.x might be larger than total_sequence_length due to alignment to 32x. + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { + const bool is_valid = (i >= valid_start && i < end); + output[offset + i] = is_valid ? T(cached_data[i] * sum_reverse_block) : T(0.f); } } -template -__global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, +// This kernel is for causal or not, raw attention mask (2D, 3D or 4D) and total_sequence_length > 1024. +template +__global__ void SoftmaxWithRawMaskLargeKernel(const int total_sequence_length, const int sequence_length, const int* attention_mask, // 2D, 3D or 4D attention mask const bool* key_padding_mask, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal, @@ -262,7 +285,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, const int max_sequence_length, const bool skip_softmax, const float mask_filter_value) { - extern __shared__ float cached_data[]; // float[all_sequence_length] + extern __shared__ float cached_data[]; // float[total_sequence_length] using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp_storage; @@ -271,37 +294,30 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, __shared__ float max_block; float max_thread_data = -CUDART_INF_F; - const int size_per_batch = gridDim.x * all_sequence_length; - - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - int base_index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) { - float thread_data = -CUDART_INF_F; - int index = base_index + seq_idx; - if (rel_pos_bias == nullptr) { - thread_data = float(input[index]) * rsqrt_head_size; - } else { - T rel_pos_bias_value = broadcast_rel_pos_bias ? rel_pos_bias[index % size_per_batch] : rel_pos_bias[index]; - thread_data = float(input[index] + rel_pos_bias_value) * rsqrt_head_size; - } - const int sequence_index = blockIdx.x % sequence_length; + DECLARE_SOFTMAX_VARS(); + + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { + int index = offset + i; + float input_data = HAS_BIAS + ? float(input[index]) + float(attn_bias[bias_offset + i]) + : float(input[index]); + float thread_data = input_data * rsqrt_head_size; if (causal) { - int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length. - if (seq_idx > from_index) { + int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length. 
+ if (i > from_index) { thread_data = -CUDART_INF_F; } } int mask_offset = 0; - const int batch_index = blockIdx.y; if (mask_dimension == 2) { - mask_offset = batch_index * all_sequence_length + seq_idx; + mask_offset = b * total_sequence_length + i; } else if (mask_dimension == 3) { - mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + seq_idx; + mask_offset = (b * sequence_length + s) * total_sequence_length + i; } else if (mask_dimension == 4) { - int from_index = all_sequence_length - sequence_length + sequence_index; - mask_offset = (batch_index * max_sequence_length + from_index) * max_sequence_length + seq_idx; + int from_index = total_sequence_length - sequence_length + s; + mask_offset = (b * max_sequence_length + from_index) * max_sequence_length + i; } if (nullptr == key_padding_mask) { @@ -318,7 +334,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, if (skip_softmax) { output[index] = T(thread_data); } - cached_data[seq_idx] = thread_data; + cached_data[i] = thread_data; max_thread_data = max(max_thread_data, thread_data); } @@ -326,7 +342,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, return; } - const float max = BlockReduce(tmp_storage).Reduce(max_thread_data, cub::Max(), all_sequence_length); + const float max = BlockReduce(tmp_storage).Reduce(max_thread_data, cub::Max(), total_sequence_length); // Store max value if (threadIdx.x == 0) { @@ -335,9 +351,9 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, __syncthreads(); float sum_thread_data_exp = 0.0f; - for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) { - auto ev = expf(cached_data[seq_idx] - max_block); - cached_data[seq_idx] = ev; + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { + auto ev = expf(cached_data[i] - max_block); + cached_data[i] = ev; sum_thread_data_exp += ev; } const auto sum = BlockReduce(tmp_storage).Reduce(sum_thread_data_exp, cub::Sum(), TPB); @@ -348,18 +364,20 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, } __syncthreads(); - for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) { - output[base_index + seq_idx] = T(cached_data[seq_idx] * sum_reverse_block); + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { + output[offset + i] = T(cached_data[i] * sum_reverse_block); } } -template -__device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, +// This kernel is for causal or not, raw attention mask (2D, 3D or 4D), and total_sequence_length <= 1024. 
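The raw-mask kernels index the attention mask differently per `mask_dimension`; the same formulas appear again in `SoftmaxWithRawMaskSmall` below. Summarised as a host-side helper (a sketch with our own naming, assuming mask layouts 2D = (B, T), 3D = (B, S, T), 4D = (B, 1, max_sequence_length, max_sequence_length)):

```cpp
#include <cstdint>

// Illustrative sketch of the mask_offset arithmetic used by the raw-mask softmax kernels.
// b: batch, s: query row in [0, sequence_length), i: key position in [0, total_sequence_length).
inline int64_t RawMaskOffset(int mask_dimension, int b, int s, int i,
                             int sequence_length, int total_sequence_length,
                             int max_sequence_length) {
  if (mask_dimension == 2) {  // mask shape (B, T)
    return static_cast<int64_t>(b) * total_sequence_length + i;
  }
  if (mask_dimension == 3) {  // mask shape (B, S, T)
    return (static_cast<int64_t>(b) * sequence_length + s) * total_sequence_length + i;
  }
  // mask_dimension == 4: mask shape (B, 1, max_seq, max_seq); the query row is shifted
  // by the past length, i.e. from_index = total_sequence_length - sequence_length + s.
  const int from_index = total_sequence_length - sequence_length + s;
  return (static_cast<int64_t>(b) * max_sequence_length + from_index) * max_sequence_length + i;
}
```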
+template +__device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length, const int sequence_length, const int* attention_mask, // 2D, 3D or 4D attention mask const bool* key_padding_mask, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal, @@ -374,31 +392,29 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, __shared__ float sum_reverse_block; __shared__ float max_block; - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; - const int size_per_batch = gridDim.x * all_sequence_length; + DECLARE_SOFTMAX_VARS(); + + int64_t index = offset + threadIdx.x; float thread_data = -CUDART_INF_F; - if (threadIdx.x < all_sequence_length) { + if (threadIdx.x < total_sequence_length) { thread_data = float(input[index]) * rsqrt_head_size; - const int sequence_index = blockIdx.x % sequence_length; if (causal) { - int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length. + int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length. if (threadIdx.x > from_index) { thread_data = -CUDART_INF_F; } } int mask_offset = 0; - const int batch_index = blockIdx.y; if (mask_dimension == 2) { - mask_offset = batch_index * all_sequence_length + threadIdx.x; + mask_offset = b * total_sequence_length + threadIdx.x; } else if (mask_dimension == 3) { - mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x; + mask_offset = (b * sequence_length + s) * total_sequence_length + threadIdx.x; } else if (mask_dimension == 4) { - int from_index = all_sequence_length - sequence_length + sequence_index; - mask_offset = (batch_index * max_sequence_length + from_index) * max_sequence_length + threadIdx.x; + int from_index = total_sequence_length - sequence_length + s; + mask_offset = (b * max_sequence_length + from_index) * max_sequence_length + threadIdx.x; } if (nullptr == key_padding_mask) { @@ -412,20 +428,19 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } } - if (rel_pos_bias != nullptr) { - float bias = broadcast_rel_pos_bias ? float(rel_pos_bias[index % size_per_batch]) : float(rel_pos_bias[index]); - thread_data += bias; + if (HAS_BIAS) { + thread_data += float(attn_bias[bias_offset + threadIdx.x]); } } if (skip_softmax) { - if (threadIdx.x < all_sequence_length) { + if (threadIdx.x < total_sequence_length) { output[index] = T(thread_data); } return; } - const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), all_sequence_length); + const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), total_sequence_length); // Store max value if (threadIdx.x == 0) { @@ -433,8 +448,8 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } __syncthreads(); - float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f; - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), all_sequence_length); + float thread_data_exp = threadIdx.x < total_sequence_length ? 
expf(thread_data - max_block) : 0.0f; + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), total_sequence_length); // Store value of 1.0/sum if (threadIdx.x == 0) { @@ -442,84 +457,97 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } __syncthreads(); - if (threadIdx.x < all_sequence_length) { + if (threadIdx.x < total_sequence_length) { output[index] = T(thread_data_exp * sum_reverse_block); } } -template -__global__ void SoftmaxKernelSmall(const int all_sequence_length, +template +__global__ void SoftmaxKernelSmall(const int total_sequence_length, const int sequence_length, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, bool causal) { - SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); + SoftmaxSmall(total_sequence_length, sequence_length, total_sequence_length, 0, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); } -template -__global__ void SoftmaxKernel(const int all_sequence_length, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, +template +__global__ void SoftmaxKernel(const int total_sequence_length, + const int sequence_length, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output) { - Softmax(all_sequence_length, all_sequence_length, 0, - rel_pos_bias, broadcast_rel_pos_bias, input, output); + Softmax(total_sequence_length, sequence_length, total_sequence_length, 0, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } template -Status ComputeSoftmax(cudaStream_t stream, const int all_sequence_length, const int sequence_length, - const int batch_size, const int num_heads, const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, T* input, T* output, bool causal) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - if (all_sequence_length <= 32) { - const int blockSize = 32; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 64) { - const int blockSize = 64; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 128) { - const int blockSize = 128; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 256) { - const int blockSize = 256; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 512) { - const int blockSize = 512; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (!causal) { - const int blockSize = 1024; - SoftmaxKernel<<>>( - all_sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output); - } else { - const int blockSize = 256; - const int sh_bytes = sizeof(float) * 
all_sequence_length; - SoftmaxLargeKernel<<>>( - all_sequence_length, sequence_length, all_sequence_length, 0, rel_pos_bias, broadcast_rel_pos_bias, - input, output, true); - } - +Status ComputeSoftmax(cudaStream_t stream, const int total_sequence_length, const int sequence_length, + const int batch_size, const int num_heads, const T* attn_bias, + const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, + T* input, T* output, bool causal) { + DISPATCH_BIAS(attn_bias, HAS_BIAS, [&] { + if (total_sequence_length <= 32) { + const int blockSize = 32; + SoftmaxKernelSmall<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 64) { + const int blockSize = 64; + SoftmaxKernelSmall<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 128) { + const int blockSize = 128; + SoftmaxKernelSmall<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 256) { + const int blockSize = 256; + SoftmaxKernelSmall<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 512) { + const int blockSize = 512; + SoftmaxKernelSmall<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 1024) { + const int blockSize = 1024; + SoftmaxKernelSmall<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (!causal) { + const int blockSize = 1024; + SoftmaxKernel<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); + } else { + const int blockSize = 256; + const int sh_bytes = sizeof(float) * total_sequence_length; + SoftmaxLargeKernel<<>>( + total_sequence_length, sequence_length, total_sequence_length, 0, attn_bias, + broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, true); + } + }); return CUDA_CALL(cudaGetLastError()); } -template -__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, +template +__global__ void MaskedSoftmaxKernelSmall(const int total_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, bool causal) { @@ -529,25 +557,27 @@ __global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, if (threadIdx.x == 0) { const int batch = blockIdx.y; start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); + end_position = min(total_sequence_length, mask_end[batch]); // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. 
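The 1D-mask kernels read a per-batch key window from `mask_end` / `mask_start` (the callers pass `mask_start = mask_index + batch_size` only when the mask carries 2 * batch_size entries), and an empty window falls back to the full row for parity with the CPU kernel, as the comment above notes. Host-side, the logic amounts to this sketch (illustrative naming):

```cpp
#include <algorithm>
#include <utility>

// Per-batch valid key window used by the 1D-mask softmax kernels.
// mask_end[batch] holds the key length; mask_start may be nullptr (start = 0).
inline std::pair<int, int> ValidKeyWindow(const int* mask_end, const int* mask_start,
                                          int batch, int total_sequence_length) {
  int start = (mask_start != nullptr) ? std::max(0, mask_start[batch]) : 0;
  int end = std::min(total_sequence_length, mask_end[batch]);
  if (start >= end) {
    // "Attend to no word" behaves like "attend to all words".
    start = 0;
    end = total_sequence_length;
  }
  return {start, end};
}
```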
if (start_position >= end_position) { start_position = 0; - end_position = all_sequence_length; + end_position = total_sequence_length; } } __syncthreads(); - SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); + SoftmaxSmall(total_sequence_length, sequence_length, end_position, start_position, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); } -template -__device__ inline void SoftmaxSmallPacked(const int sequence_length, +template +__device__ inline void SoftmaxSmallPacked(const int total_sequence_length, + const int sequence_length, const int end, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output) { using BlockReduce = cub::BlockReduce; @@ -556,23 +586,13 @@ __device__ inline void SoftmaxSmallPacked(const int sequence_length, __shared__ float sum_reverse_block; __shared__ float max_block; - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * sequence_length; - const int index = offset + threadIdx.x; + DECLARE_SOFTMAX_VARS(); + + int64_t index = offset + threadIdx.x; bool is_valid = threadIdx.x < end; - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - const bool no_rpb = (rel_pos_bias == nullptr); - const int size_per_batch = gridDim.x * sequence_length; - float input_data = no_rpb - ? float(input[index]) - : float(input[index] + (broadcast_rel_pos_bias - ? rel_pos_bias[index % size_per_batch] - : rel_pos_bias[index])); + float input_data = HAS_BIAS ? float(input[index]) + float(attn_bias[bias_offset + threadIdx.x]) : float(input[index]); float thread_data_max = is_valid ? input_data : float(-CUDART_INF_F); const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end); @@ -596,16 +616,20 @@ __device__ inline void SoftmaxSmallPacked(const int sequence_length, } __syncthreads(); - // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. + // threadIdx.x might be larger than total_sequence_length due to alignment to 32x. 
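As the comment above notes, threadIdx.x can run past the row because the launchers round the block size up to a power of two between 32 and 1024; rows longer than 1024 fall back to the strided kernels with a float[total_sequence_length] shared-memory cache. The selection rule, as a stand-alone sketch (helper is ours):

```cpp
// Smallest power-of-two block size in [32, 1024] that covers n key positions,
// or -1 when a strided "large" kernel with dynamic shared memory is needed.
inline int SmallSoftmaxBlockSize(int n) {
  for (int block_size = 32; block_size <= 1024; block_size *= 2) {
    if (n <= block_size) {
      return block_size;
    }
  }
  return -1;
}
```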
if (threadIdx.x < sequence_length) { output[index] = T(thread_data_exp * sum_reverse_block); } } -template +template __global__ void SoftmaxKernelSmallWithCumSeqLen(const T* input, - const T* rel_pos_bias, const bool broadcast_rel_pos_bias, - const int* cum_seq_length, const int sequence_length, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, + const int* cum_seq_length, + const int total_sequence_length, + const int sequence_length, T* output) { __shared__ int end_position; @@ -615,15 +639,18 @@ __global__ void SoftmaxKernelSmallWithCumSeqLen(const T* input, } __syncthreads(); - SoftmaxSmallPacked(sequence_length, end_position, - rel_pos_bias, broadcast_rel_pos_bias, - input, output); + SoftmaxSmallPacked(total_sequence_length, sequence_length, end_position, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } -template +template __global__ void SoftmaxKernelWithCumSeqLen(const T* input, - const T* rel_pos_bias, const bool broadcast_rel_pos_bias, - const int* cum_seq_length, const int sequence_length, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, + const int* cum_seq_length, + const int total_sequence_length, + const int sequence_length, T* output) { __shared__ int end_position; @@ -633,16 +660,19 @@ __global__ void SoftmaxKernelWithCumSeqLen(const T* input, } __syncthreads(); - Softmax(sequence_length, end_position, 0 /*start_position*/, - rel_pos_bias, broadcast_rel_pos_bias, input, output); + constexpr int start_position = 0; + Softmax(total_sequence_length, sequence_length, end_position, start_position, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } -template -__global__ void MaskedSoftmaxKernel(const int all_sequence_length, +template +__global__ void MaskedSoftmaxKernel(const int total_sequence_length, + const int sequence_length, const int* mask_end, const int* mask_start, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output) { __shared__ int start_position; __shared__ int end_position; @@ -650,27 +680,28 @@ __global__ void MaskedSoftmaxKernel(const int all_sequence_length, if (threadIdx.x == 0) { const int batch = blockIdx.y; start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); + end_position = min(total_sequence_length, mask_end[batch]); // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. 
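`DISPATCH_BIAS` (defined above) turns the runtime `attn_bias != nullptr` check into a compile-time `HAS_BIAS` template parameter, so the launchers below instantiate each kernel for both values and the per-element branch disappears from the inner loops. The pattern, reduced to a minimal stand-alone C++17 example (toy function, not the macro itself):

```cpp
#include <cstdio>

// Toy version of the DISPATCH_BIAS idea: pick a template instantiation from a runtime
// pointer so the body can use `if constexpr (HAS_BIAS)` instead of a per-element branch.
template <bool HAS_BIAS>
void AddBiasAndPrint(const float* x, const float* bias, int n) {
  for (int i = 0; i < n; ++i) {
    float v = x[i];
    if constexpr (HAS_BIAS) {
      v += bias[i];
    }
    std::printf("%f\n", v);
  }
}

inline void DispatchBias(const float* x, const float* bias, int n) {
  if (bias != nullptr) {
    AddBiasAndPrint<true>(x, bias, n);
  } else {
    AddBiasAndPrint<false>(x, bias, n);
  }
}
```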
if (start_position >= end_position) { start_position = 0; - end_position = all_sequence_length; + end_position = total_sequence_length; } } __syncthreads(); - Softmax(all_sequence_length, end_position, start_position, - rel_pos_bias, broadcast_rel_pos_bias, input, output); + Softmax(total_sequence_length, sequence_length, end_position, start_position, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } -template -__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, +template +__global__ void SoftmaxWithRawMaskSmallKernel(const int total_sequence_length, const int sequence_length, const int* attention_mask, const bool* key_padding_mask, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal, @@ -679,9 +710,9 @@ __global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int max_sequence_length, const bool skip_softmax, const float mask_filter_value) { - SoftmaxWithRawMaskSmall( - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, output, + SoftmaxWithRawMaskSmall( + total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal, rsqrt_head_size, mask_dimension, max_sequence_length, skip_softmax, mask_filter_value); } @@ -689,107 +720,120 @@ __global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, template Status ComputeSoftmaxWithCumSeqLength( const T* input, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const int32_t* cum_seq_length, const int batch_size, const int sequence_length, + const int total_sequence_length, const int num_heads, T* output, cudaStream_t stream) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - - if (sequence_length <= 32) { - const int blockSize = 32; - SoftmaxKernelSmallWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); - - } else if (sequence_length <= 64) { - const int blockSize = 64; - SoftmaxKernelSmallWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); - } else if (sequence_length <= 128) { - const int blockSize = 128; - SoftmaxKernelSmallWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); - } else if (sequence_length <= 256) { - const int blockSize = 256; - SoftmaxKernelSmallWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); - } else if (sequence_length <= 512) { - const int blockSize = 512; - SoftmaxKernelSmallWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); - } else if (sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxKernelSmallWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); - } else { - SoftmaxKernelWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); - } + DISPATCH_BIAS(attn_bias, HAS_BIAS, [&] { + if (sequence_length <= 32) { + const int blockSize = 32; + 
SoftmaxKernelSmallWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } else if (sequence_length <= 64) { + const int blockSize = 64; + SoftmaxKernelSmallWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } else if (sequence_length <= 128) { + const int blockSize = 128; + SoftmaxKernelSmallWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } else if (sequence_length <= 256) { + const int blockSize = 256; + SoftmaxKernelSmallWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } else if (sequence_length <= 512) { + const int blockSize = 512; + SoftmaxKernelSmallWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } else if (sequence_length <= 1024) { + const int blockSize = 1024; + SoftmaxKernelSmallWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } else { + const int blockSize = 1024; + SoftmaxKernelWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } + }); return CUDA_CALL(cudaGetLastError()); } template Status ComputeSoftmaxWithMask1D(cudaStream_t stream, - const int all_sequence_length, + const int total_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const int* mask_index, const int* mask_start, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - - if (all_sequence_length <= 32) { - const int blockSize = 32; - MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 64) { - const int blockSize = 64; - MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 128) { - const int blockSize = 128; - MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 256) { - const int blockSize = 256; - MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 512) { - const int blockSize = 512; - MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 1024) { - const int blockSize = 1024; - MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, - rel_pos_bias, 
broadcast_rel_pos_bias, input, output, causal); - } else if (!causal) { - const int blockSize = 1024; - MaskedSoftmaxKernel - <<>>(all_sequence_length, mask_index, mask_start, - rel_pos_bias, broadcast_rel_pos_bias, input, output); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention CUDA operator does not support total sequence length > 1024."); + DISPATCH_BIAS(attn_bias, HAS_BIAS, [&] { + if (total_sequence_length <= 32) { + const int blockSize = 32; + MaskedSoftmaxKernelSmall + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 64) { + const int blockSize = 64; + MaskedSoftmaxKernelSmall + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 128) { + const int blockSize = 128; + MaskedSoftmaxKernelSmall + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 256) { + const int blockSize = 256; + MaskedSoftmaxKernelSmall + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 512) { + const int blockSize = 512; + MaskedSoftmaxKernelSmall + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 1024) { + const int blockSize = 1024; + MaskedSoftmaxKernelSmall + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (!causal) { + const int blockSize = 1024; + MaskedSoftmaxKernel + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output); + } + }); + + if (total_sequence_length > 1024 && causal) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "ComputeSoftmaxWithMask1D does not support causal with total sequence length > 1024."); } return CUDA_CALL(cudaGetLastError()); @@ -797,14 +841,15 @@ Status ComputeSoftmaxWithMask1D(cudaStream_t stream, template Status ComputeSoftmaxWithRawMask(Stream* ort_stream, - const int all_sequence_length, + const int total_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const int* attention_mask, const bool* key_padding_mask, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal, @@ -815,69 +860,70 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream, T* persistent_softmax_workspace, const float mask_filter_value) { auto stream = static_cast(ort_stream->GetHandle()); - const dim3 grid(sequence_length * num_heads, batch_size, 1); - T* out = use_persistent_softmax ? 
persistent_softmax_workspace : output; - if (all_sequence_length <= 32) { - const int blockSize = 32; - SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } else if (all_sequence_length <= 64) { - const int blockSize = 64; - SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } else if (all_sequence_length <= 128) { - const int blockSize = 128; - SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } else if (all_sequence_length <= 256) { - const int blockSize = 256; - SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } else if (all_sequence_length <= 512) { - const int blockSize = 512; - SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } else if (all_sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } else { - const int blockSize = 256; - const int sh_bytes = sizeof(float) * all_sequence_length; - SoftmaxWithRawMaskLargeKernel - <<>>( - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } + + DISPATCH_BIAS(attn_bias, HAS_BIAS, [&] { + if (total_sequence_length <= 32) { + const int blockSize = 32; + SoftmaxWithRawMaskSmallKernel + <<>>(total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } else if (total_sequence_length <= 64) { + const int blockSize = 64; + SoftmaxWithRawMaskSmallKernel + <<>>(total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } else if (total_sequence_length <= 128) { + const int blockSize = 128; + SoftmaxWithRawMaskSmallKernel + <<>>(total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, 
rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } else if (total_sequence_length <= 256) { + const int blockSize = 256; + SoftmaxWithRawMaskSmallKernel + <<>>(total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } else if (total_sequence_length <= 512) { + const int blockSize = 512; + SoftmaxWithRawMaskSmallKernel + <<>>(total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } else if (total_sequence_length <= 1024) { + const int blockSize = 1024; + SoftmaxWithRawMaskSmallKernel + <<>>(total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } else { + const int blockSize = 256; + const int sh_bytes = sizeof(float) * total_sequence_length; + SoftmaxWithRawMaskLargeKernel + <<>>( + total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } + }); if (use_persistent_softmax) { return onnxruntime::cuda::dispatch_warpwise_softmax_forward( ort_stream, output, persistent_softmax_workspace, - all_sequence_length, - all_sequence_length, + total_sequence_length, + total_sequence_length, batch_size * num_heads * sequence_length); } @@ -886,70 +932,79 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream, // Template Instantiation template Status ComputeSoftmax( - cudaStream_t stream, const int all_sequence_length, const int sequence_length, - const int batch_size, const int num_heads, const float* rel_pos_bias, - const bool broadcast_rel_pos_bias, float* input, float* output, bool causal); + cudaStream_t stream, const int total_sequence_length, const int sequence_length, + const int batch_size, const int num_heads, const float* attn_bias, + const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, + float* input, float* output, bool causal); template Status ComputeSoftmax( - cudaStream_t stream, const int all_sequence_length, const int sequence_length, - const int batch_size, const int num_heads, const half* rel_pos_bias, - const bool broadcast_rel_pos_bias, half* input, half* output, bool causal); + cudaStream_t stream, const int total_sequence_length, const int sequence_length, + const int batch_size, const int num_heads, const half* attn_bias, + const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, + half* input, half* output, bool causal); template Status ComputeSoftmaxWithCumSeqLength( const float* input, - const float* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const float* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const int32_t* cum_seq_length, const int batch_size, const int sequence_length, + const int total_sequence_length, const int num_heads, float* output, cudaStream_t stream); template Status 
ComputeSoftmaxWithCumSeqLength( const half* input, - const half* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const half* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const int32_t* cum_seq_length, const int batch_size, const int sequence_length, + const int total_sequence_length, const int num_heads, half* output, cudaStream_t stream); template Status ComputeSoftmaxWithMask1D(cudaStream_t stream, - const int all_sequence_length, + const int total_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const int* mask_index, const int* mask_start, - const float* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const float* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const float* input, float* output, const bool causal); template Status ComputeSoftmaxWithMask1D(cudaStream_t stream, - const int all_sequence_length, + const int total_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const int* mask_index, const int* mask_start, - const half* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const half* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const half* input, half* output, const bool causal); template Status ComputeSoftmaxWithRawMask(Stream* ort_stream, - const int all_sequence_length, + const int total_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const int* attention_mask, const bool* key_padding_mask, - const float* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const float* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const float* input, float* output, const bool causal, @@ -961,14 +1016,15 @@ template Status ComputeSoftmaxWithRawMask(Stream* ort_stream, const float mask_filter_value); template Status ComputeSoftmaxWithRawMask(Stream* ort_stream, - const int all_sequence_length, + const int total_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const int* attention_mask, const bool* key_padding_mask, - const half* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const half* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const half* input, half* output, const bool causal, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h index 46d2423fa7009..f7fab268b4607 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h @@ -10,16 +10,19 @@ namespace attention_softmax_cuda { template Status ComputeSoftmax(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, T* input, T* output, bool causal); + const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, + T* input, T* output, bool causal); template Status ComputeSoftmaxWithCumSeqLength( const T* input, const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const int32_t* cum_seq_length, const int batch_size, const int sequence_length, + const int total_sequence_length, const int num_heads, T* output, cudaStream_t stream); @@ -32,7 +35,8 @@ Status ComputeSoftmaxWithMask1D(cudaStream_t 
stream, const int* mask_index, const int* mask_start, const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal); @@ -46,7 +50,8 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream, const int* attention_mask, const bool* key_padding_mask, const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal, diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index a5de20e44be1a..1598a7e8bcf1e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -184,35 +184,41 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.q_strideH = params.qk_head_size; p.k_strideH = params.qk_head_size; p.v_strideH = params.v_head_size; - p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; p.q_strideM = params.num_heads * params.qk_head_size; p.k_strideM = params.num_heads * params.qk_head_size; p.v_strideM = params.num_heads * params.v_head_size; p.o_strideM = params.num_heads * params.v_head_size; - p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; p.k_strideB = static_cast(p.k_strideM) * params.max_sequence_length; p.v_strideB = static_cast(p.v_strideM) * params.max_sequence_length; - p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; } else { // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH p.q_strideH = params.qk_head_size; p.k_strideH = params.max_sequence_length * params.qk_head_size; p.v_strideH = params.max_sequence_length * params.v_head_size; - p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; p.q_strideM = params.num_heads * params.qk_head_size; p.k_strideM = params.qk_head_size; p.v_strideM = params.v_head_size; p.o_strideM = params.num_heads * params.v_head_size; - p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length; p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length; p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; - p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } + + if (params.attn_bias != nullptr) { + p.bias_strideH = params.broadcast_attn_bias_dim_1 ? 0 : p.num_queries * p.num_keys; + p.bias_strideM = p.num_keys; + p.bias_strideB = params.broadcast_attn_bias_dim_0 + ? 0 + : ((params.broadcast_attn_bias_dim_1 ? 
1 : params.num_heads) * p.num_queries * p.num_keys); + } else { + p.bias_strideH = 0; + p.bias_strideM = 0; + p.bias_strideB = 0; } } diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index 08a562a12b844..a9777800f6038 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -25,8 +25,6 @@ struct MemoryEfficientAttentionParams { int32_t qk_head_size; int32_t v_head_size; bool causal; - // The default shape of attn_bias is [1, N, S, S*]. Sometimes we need to use [B, N, S, S*] in custom models. - bool is_attn_bias_batched; float scale; @@ -37,9 +35,12 @@ struct MemoryEfficientAttentionParams { const void* query; // [B, S, N, H] const void* key; // [B, L, N, H], where L is kv_sequence_length const void* value; // [B, L, N, H_v] - const void* attn_bias; // [N, S, S*] or null - void* output; // [B, S, N, H_v] - void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise + const void* attn_bias; // [B or 1, N or 1, S, L] or null + bool broadcast_attn_bias_dim_0; + bool broadcast_attn_bias_dim_1; + + void* output; // [B, S, N, H_v] + void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise cudaStream_t stream; static bool need_workspace(size_t v_head_size, bool is_float) { diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu index c0b1996789183..65d2c113576f6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -140,21 +140,25 @@ Status DecoderQkvToContext( } constexpr bool is_unidirectional = false; - const T* add_before_softmax = nullptr; + const T* attention_bias = nullptr; + constexpr bool broadcast_attn_bias_dim_0 = false; + constexpr bool broadcast_attn_bias_dim_1 = false; + if (has_key_padding_mask) { constexpr int mask_dimension = 2; constexpr int max_sequence_length = 0; ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( ort_stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, key_padding_mask, add_before_softmax, - false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, + num_heads, nullptr, key_padding_mask, + attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + scratch1, scratch2, is_unidirectional, 1.0f, mask_dimension, max_sequence_length, false, nullptr, mask_filter_value)); } else { ORT_RETURN_IF_ERROR(ComputeSoftmax( stream, kv_sequence_length, sequence_length, batch_size, num_heads, - add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, - is_unidirectional)); + attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + scratch1, scratch2, is_unidirectional)); } // compute P*V (as V*P), and store in scratch3: BxNxSxH diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 037a4fdf3d9a0..350c4718c437e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -60,7 +60,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* const Tensor* key = context->Input(1); const Tensor* value = context->Input(2); const Tensor* 
mask_index = context->Input(3); - const Tensor* relative_position_bias = context->Input(4); + const Tensor* attention_bias = context->Input(4); const Tensor* past_key = context->Input(kPastInputIndex); const Tensor* past_value = context->Input(kPastInputIndex + 1); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); @@ -80,7 +80,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* value, bias, mask_index, - relative_position_bias, + attention_bias, past_key, past_value, past_seq_len, @@ -141,16 +141,16 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* // Update the q buffers parameters.q = const_cast(query->Data()); - // Update the relative position bias for self attention - if (relative_position_bias != nullptr) { - parameters.relative_attention_bias = const_cast(relative_position_bias->Data()); + // Update the attention bias for self attention + if (attention_bias != nullptr) { + parameters.attention_bias = const_cast(attention_bias->Data()); } // Decoder cross-attention if (past_key == nullptr && present_key == nullptr) { - if (relative_position_bias != nullptr) { + if (attention_bias != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "DecoderMaskedMultiHeadAttention does not support relative position bias for cross-attention"); + "DecoderMaskedMultiHeadAttention does not support attention bias for cross-attention"); } parameters.is_cross_attention = true; diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc index 07a6fbd60e171..e7d117686a538 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc @@ -45,7 +45,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(kPastInputIndex); - const Tensor* relative_position_bias = context->Input(5); + const Tensor* attention_bias = context->Input(5); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); const Tensor* beam_width = context->Input(kBeamWidthInputIndex); const Tensor* cache_indir = context->Input(kCacheIndirectionInputIndex); @@ -61,7 +61,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont bias->Shape(), mask_index, past, - relative_position_bias, + attention_bias, ¶meters, device_prop.maxThreadsPerBlock, past_seq_len)); @@ -85,8 +85,8 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont } // TODO(hasesh): If there is a need, we will support this later - if (relative_position_bias != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "DecoderMaskedSelfAttention does not support relative position bias currently"); + if (attention_bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "DecoderMaskedSelfAttention does not support attention bias currently"); } // TODO(hasesh): Support more mask types. Currently, it only supports the HuggingFace GreedySearch/BeamSearch pattern. 
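Note on the fmha_launch_template.h hunk above: the single is_attn_bias_batched switch is gone, and the CUTLASS bias strides are now derived from the two broadcast flags, so a broadcast dimension simply contributes a zero stride. A minimal, self-contained sketch of that derivation (the struct and function names below are illustrative only, not ORT code):

    #include <cstdint>

    struct BiasStrides {
      int64_t strideB;  // step between batches
      int64_t strideH;  // step between heads
      int64_t strideM;  // step between query rows
    };

    // attn_bias has logical shape [B or 1, N or 1, S, L].
    inline BiasStrides ComputeCutlassBiasStrides(bool has_bias,
                                                 bool broadcast_dim_0,
                                                 bool broadcast_dim_1,
                                                 int64_t num_heads,
                                                 int64_t num_queries,  // S
                                                 int64_t num_keys) {   // L
      if (!has_bias) {
        return {0, 0, 0};
      }
      BiasStrides s;
      s.strideM = num_keys;
      s.strideH = broadcast_dim_1 ? 0 : num_queries * num_keys;
      s.strideB = broadcast_dim_0
                      ? 0
                      : (broadcast_dim_1 ? 1 : num_heads) * num_queries * num_keys;
      return s;
    }

For a bias of shape [1, N, S, L] this yields strideB == 0 (all batches reuse the same data); for [B, 1, S, L] it yields strideH == 0 with strideB == S * L, which matches the assignments in the hunk above.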
diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
index 2f8d277cb7342..8edae863ff44e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
@@ -154,6 +154,15 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
   // The offset in the Q and K buffer also accounts for the batch.
   int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
 
+  // The offset of attention bias for current head.
+  // Support broadcasting the first and second dimensions of attention bias with shape
+  // [batch_size or 1, num_heads or 1, seq_len, total_seq_len], and assume seq_len == 1 for this operator.
+  int attn_bias_offset = (params.attention_bias == nullptr)
+                             ? 0
+                             : (((params.broadcast_attn_bias_dim_0 ? 0 : (bbi * params.num_heads)) +
+                                 (params.broadcast_attn_bias_dim_1 ? 0 : hi)) *
+                                params.total_sequence_length);
+
   // Trigger the loads from the Q and K buffers.
   Qk_vec_k q;
   zero(q);
@@ -286,9 +295,8 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
   if (tidx == 0) {
     // Normalize qk.
     qk *= inv_sqrt_dh;
-    if (params.relative_attention_bias != nullptr) {
-      qk = add_vec(qk,
-                   reinterpret_cast<T*>(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + tlength]);
+    if (params.attention_bias != nullptr) {
+      qk = add_vec(qk, reinterpret_cast<T*>(params.attention_bias)[attn_bias_offset + tlength]);
     }
     qk_max = qk;
     qk_smem[tlength] = qk;
@@ -386,9 +394,8 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
   // Store the product to shared memory. There's one qk value per timestep. Update the max.
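The attn_bias_offset added in the first hunk above is the per-(batch, head) starting position in the bias buffer for this decoding step (the operator assumes seq_len == 1). A host-side sketch of the same arithmetic, with illustrative names (bbi is the batch index and hi the head index, as in the kernel):

    // Mirrors the kernel's offset formula for a bias of shape
    // [batch_size or 1, num_heads or 1, 1, total_sequence_length].
    inline int AttentionBiasRowOffset(bool has_bias,
                                      bool broadcast_dim_0, bool broadcast_dim_1,
                                      int bbi, int hi,
                                      int num_heads, int total_sequence_length) {
      if (!has_bias) {
        return 0;
      }
      const int batch_term = broadcast_dim_0 ? 0 : bbi * num_heads;
      const int head_term = broadcast_dim_1 ? 0 : hi;
      return (batch_term + head_term) * total_sequence_length;
    }

For a full [B, N, 1, T] bias this is (bbi * N + hi) * T; for a [1, N, 1, T] bias every batch maps to hi * T. The kernel then adds attention_bias[attn_bias_offset + t] to the scaled qk value of timestep t, as the remaining hunks of this file show.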
if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - if (params.relative_attention_bias != nullptr) { - qk = add_vec(qk, - reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + ti]); + if (params.attention_bias != nullptr) { + qk = add_vec(qk, reinterpret_cast(params.attention_bias)[attn_bias_offset + ti]); } qk_max = fmaxf(qk_max, qk); qk_smem[ti] = qk; @@ -479,9 +486,9 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio #pragma unroll for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) { if (time_bounds_cond[k_unroll] && (tidx % THREADS_PER_KEY == 0)) { - if (params.relative_attention_bias != nullptr) { + if (params.attention_bias != nullptr) { qk[k_unroll] = add_vec(qk[k_unroll], - reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + time_step[k_unroll]]); + reinterpret_cast(params.attention_bias)[attn_bias_offset + time_step[k_unroll]]); } qk_max = fmaxf(qk_max, qk[k_unroll]); qk_smem[time_step[k_unroll]] = qk[k_unroll]; diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h index 1a17757d1ec2d..efad33855328f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h @@ -37,7 +37,7 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { void* v = nullptr; void* v_bias = nullptr; - void* relative_attention_bias = nullptr; + void* attention_bias = nullptr; void* k_cache = nullptr; void* v_cache = nullptr; @@ -68,4 +68,4 @@ void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cud } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 3099b52cce13e..b694de48d2961 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -836,7 +836,6 @@ Status EfficientAttention( p.key = key; p.value = value; p.attn_bias = nullptr; - p.is_attn_bias_batched = false; p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; p.output = data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 2835192abd298..b2fd9b5e89de1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -74,7 +74,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* value = context->Input(2); const Tensor* bias = context->Input(3); const Tensor* key_padding_mask = context->Input(4); - const Tensor* relative_position_bias = context->Input(5); + const Tensor* attention_bias = context->Input(5); const Tensor* past_key = context->Input(6); const Tensor* past_value = context->Input(7); @@ -87,7 +87,7 @@ Status 
MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { value, bias, key_padding_mask, - relative_position_bias, + attention_bias, past_key, past_value, nullptr, // past_seq_len @@ -150,7 +150,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && - nullptr == relative_position_bias && + nullptr == attention_bias && nullptr == key_padding_mask && parameters.head_size == parameters.v_head_size && onnxruntime::flash::is_supported(device_prop, @@ -188,7 +188,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_fused_cross_attention_ && nullptr == key_padding_mask && - nullptr == relative_position_bias && + nullptr == attention_bias && nullptr == past_key && nullptr == present_key && (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) && parameters.hidden_size == parameters.v_hidden_size && @@ -212,7 +212,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_fused_self_attention_ && fused_cross_attention_kernel == nullptr && - nullptr == relative_position_bias && + nullptr == attention_bias && (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) && nullptr == past_key && nullptr == present_key && is_mask_none_or_1d_k_len && @@ -243,16 +243,14 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length >= length_threshold || parameters.kv_sequence_length >= length_threshold; - // Check whether the relative position bias alignment is good for memory efficient attention. - bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; - bool use_memory_efficient_attention = !use_flash_attention && fused_runner == nullptr && fused_cross_attention_kernel == nullptr && !disable_memory_efficient_attention_ && is_long_sequence && - (relative_position_bias == nullptr || is_good_for_rpb) && + // Check whether the attention bias alignment is good for memory efficient attention. + (attention_bias == nullptr || parameters.sequence_length % (4 * sizeof(T)) == 0) && (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && has_memory_efficient_attention(sm, std::is_same::value, parameters.head_size, parameters.v_head_size); @@ -270,7 +268,9 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); - data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); + if (nullptr != attention_bias) { + data.attention_bias = reinterpret_cast(attention_bias->Data()); + } data.output = reinterpret_cast(output->MutableData()); data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? 
nullptr : reinterpret_cast(present_value->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index d1c6993d48e62..0e5300f32da3c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -9,6 +9,7 @@ #include "contrib_ops/cuda/bert/packed_attention_impl.h" #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -42,11 +43,12 @@ TrtFusedAttention::TrtFusedAttention(const OpKernelInfo& info) template MHARunner* TrtFusedAttention::GetFusedRunner(const cudaDeviceProp& device_prop, + bool has_attention_bias, const PackedAttentionParameters& parameters) const { MHARunner* fused_runner = nullptr; bool use_fused_runner = !disable_fused_runner_ && - !parameters.has_relative_position_bias && + !has_attention_bias && parameters.hidden_size == parameters.v_hidden_size; if (!use_fused_runner) { @@ -104,7 +106,7 @@ Status PackedAttention::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const TensorShape& token_offset_shape, const TensorShape& cu_seq_len_shape, - const Tensor* relative_position_bias, + const Tensor* attention_bias, PackedAttentionParameters& parameters) const { // Abbreviation and Meanings: // T: token_count @@ -123,7 +125,7 @@ Status PackedAttention::CheckInputs(const TensorShape& input_shape, // bias (Q/K/V) : (D + D + D_v) // token_offset : (B, S) // cu_seq_len_shape : (B + 1) - // relative_position_bias : (B, N, S, S), (1, N, S, S) or NULL + // attention_bias : (B, N, S, S), (1, N, S, S) or NULL const auto& input_dims = input_shape.GetDims(); if (input_dims.size() != 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -204,43 +206,14 @@ Status PackedAttention::CheckInputs(const TensorShape& input_shape, v_hidden_size, "bias_dims[0]=", bias_dims[0]); } - bool broadcast_res_pos_bias = false; - if (relative_position_bias != nullptr) { - const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - - if (relative_position_bias_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' is expected to have 4 dimensions, got ", - relative_position_bias_dims.size()); - } - - if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 0 should be same as batch_size or 1, got ", - relative_position_bias_dims[0]); - } - if (relative_position_bias_dims[0] == 1) { - broadcast_res_pos_bias = true; - } - - if (relative_position_bias_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", - relative_position_bias_dims[1]); - } - - if (relative_position_bias_dims[2] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", - relative_position_bias_dims[2]); - } - - if (relative_position_bias_dims[3] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 3 should be same as sequence_length, got ", - 
relative_position_bias_dims[3]); - } + gsl::span attention_bias_dims; + if (attention_bias != nullptr) { + attention_bias_dims = attention_bias->Shape().GetDims(); + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckAttentionBias( + attention_bias_dims, batch_size, num_heads, sequence_length, sequence_length)); } + parameters.broadcast_attn_bias_dim_0 = attention_bias_dims.size() > 0 && attention_bias_dims[0] == 1; + parameters.broadcast_attn_bias_dim_1 = attention_bias_dims.size() > 1 && attention_bias_dims[1] == 1; parameters.batch_size = static_cast(batch_size); parameters.sequence_length = static_cast(sequence_length); @@ -252,8 +225,6 @@ Status PackedAttention::CheckInputs(const TensorShape& input_shape, parameters.num_heads = num_heads; parameters.scale = this->GetScale(); parameters.token_count = static_cast(token_count); - parameters.has_relative_position_bias = nullptr != relative_position_bias; - parameters.broadcast_res_pos_bias = broadcast_res_pos_bias; return Status::OK(); } @@ -265,7 +236,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* token_offset = context->Input(3); const Tensor* cumulative_sequence_length = context->Input(4); - const Tensor* relative_position_bias = context->Input(5); + const Tensor* attention_bias = context->Input(5); PackedAttentionParameters parameters; parameters.use_tf32 = this->UseTF32(); @@ -274,22 +245,21 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { bias->Shape(), token_offset->Shape(), cumulative_sequence_length->Shape(), - relative_position_bias, + attention_bias, parameters)); TensorShapeVector output_shape{parameters.token_count, parameters.v_hidden_size}; Tensor* output = context->Output(0, output_shape); auto& device_prop = this->GetDeviceProp(); - MHARunner* fused_runner = this->GetFusedRunner(device_prop, parameters); + MHARunner* fused_runner = this->GetFusedRunner(device_prop, attention_bias != nullptr, parameters); bool use_memory_efficient_attention = false; #if USE_MEMORY_EFFICIENT_ATTENTION if (nullptr == fused_runner) { int sm = device_prop.major * 10 + device_prop.minor; - bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0; use_memory_efficient_attention = - is_good_for_rpb && + (attention_bias == nullptr || parameters.sequence_length % (4 * sizeof(T)) == 0) && sizeof(T) == 2 && // only enable for fp16 has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); } @@ -346,7 +316,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { PackedAttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); data.bias = reinterpret_cast(bias->Data()); - data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); + data.attention_bias = (nullptr == attention_bias) ? 
nullptr : reinterpret_cast(attention_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.token_offset = token_offset->Data(); data.cumulative_sequence_length = cumulative_sequence_length->Data(); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h index 67b420764169a..6fcacd4d46ada 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h @@ -23,7 +23,9 @@ class TrtFusedAttention : public CudaKernel { TrtFusedAttention(const OpKernelInfo& info); protected: - MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, const PackedAttentionParameters& parameters) const; + MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, + bool has_attention_bias, + const PackedAttentionParameters& parameters) const; protected: const AttentionKernelOptions* kernel_options_; @@ -46,7 +48,7 @@ class PackedAttention final : public TrtFusedAttention { const TensorShape& bias_shape, const TensorShape& packing_token_offset_shape, const TensorShape& cu_seq_len_shape, - const Tensor* relative_position_bias, + const Tensor* attention_bias, PackedAttentionParameters& parameters) const; int GetNumHeads() const { return num_heads_; } diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index 2521cd49b5482..849a57512dc3d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -523,8 +523,11 @@ Status FusedScaledDotProductAttentionCutlass( p.query = query; p.key = key; p.value = value; - p.attn_bias = data.relative_position_bias; - p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; + + p.attn_bias = data.attention_bias; + p.broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0; + p.broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1; + p.output = data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? 
accum_workspace : nullptr; p.stream = stream; @@ -603,14 +606,19 @@ Status UnfusedScaledDotProductAttention( sequence_length); T* attention_score = scaled_qk + (bytes / element_size); + const bool broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0; + const bool broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1; + // Apply softmax and store result R to attention_score: BxNxSxS ORT_RETURN_IF_ERROR(ComputeSoftmaxWithCumSeqLength( scaled_qk, - data.relative_position_bias, - parameters.broadcast_res_pos_bias, + data.attention_bias, + broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1, data.cumulative_sequence_length, batch_size, sequence_length, + sequence_length, // total sequence length num_heads, attention_score, stream)); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h index 629ca59c73f16..1126c8a046da9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h @@ -33,7 +33,7 @@ template struct PackedAttentionData { T* gemm_buffer; const T* bias; - const T* relative_position_bias; + const T* attention_bias; const int32_t* token_offset; const int32_t* cumulative_sequence_length; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 53e96fc732a33..72a4c776d4fce 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -10,6 +10,7 @@ #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -54,7 +55,7 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, const Tensor* bias, const TensorShape& token_offset_shape, const TensorShape& cu_seq_len_shape, - const Tensor* relative_position_bias, + const Tensor* attention_bias, PackedAttentionParameters& parameters) const { // Shapes of inputs and output: // When Q, K and V are not packed: @@ -67,7 +68,7 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, // Input 'value': None // Input 'token_offset': (batch_size, sequence_length) // Input 'cumulative_sequence_length': (batch_size + 1) - // Input 'relative_position_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None + // Input 'attention_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None // Output 'output': (token_count, v_hidden_size) const auto& query_dims = query_shape.GetDims(); @@ -147,45 +148,16 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, "Input 'cumulative_sequence_length' should have 1 dimension with size equal to batch_size + 1"); } - // TODO(tianleiwu): move relative position bias shape checker to a helper function. It is shared by multiple ops. 
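Both packed attention operators now hand this validation to multihead_attention_helper::CheckAttentionBias instead of repeating it inline, and then derive the two broadcast flags from the bias dims. The required checks follow from the removed code and the updated schema; a standalone sketch of what such a helper has to verify (illustrative, not the actual implementation in multihead_attention_helper.h):

    #include <cstdint>
    #include <string>
    #include <vector>

    // Returns an empty string when dims describe a valid attention bias,
    // otherwise a human-readable error message.
    inline std::string ValidateAttentionBiasDims(const std::vector<int64_t>& dims,
                                                 int64_t batch_size, int64_t num_heads,
                                                 int64_t sequence_length,
                                                 int64_t total_sequence_length) {
      if (dims.size() != 4) {
        return "attention_bias is expected to have 4 dimensions";
      }
      if (dims[0] != batch_size && dims[0] != 1) {
        return "attention_bias dimension 0 should be batch_size or 1";
      }
      if (dims[1] != num_heads && dims[1] != 1) {
        return "attention_bias dimension 1 should be num_heads or 1";
      }
      if (dims[2] != sequence_length) {
        return "attention_bias dimension 2 should be sequence_length";
      }
      if (dims[3] != total_sequence_length) {
        return "attention_bias dimension 3 should be total_sequence_length";
      }
      return {};
    }

The callers then set broadcast_attn_bias_dim_0 = dims.size() > 0 && dims[0] == 1 and broadcast_attn_bias_dim_1 = dims.size() > 1 && dims[1] == 1, as the PackedAttention hunk above and the PackedMultiHeadAttention hunk below both do.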
const int num_heads = this->GetNumHeads(); - bool broadcast_res_pos_bias = false; - if (relative_position_bias != nullptr) { - const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - if (relative_position_bias_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' is expected to have 4 dimensions, got ", - relative_position_bias_dims.size()); - } - - if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 0 should be same as batch_size or 1, got ", - relative_position_bias_dims[0]); - } - if (relative_position_bias_dims[0] == 1 && 1 != batch_size) { - broadcast_res_pos_bias = true; - } - - if (relative_position_bias_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", - relative_position_bias_dims[1]); - } - - if (relative_position_bias_dims[2] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", - relative_position_bias_dims[2]); - } - - if (relative_position_bias_dims[3] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 3 should be same as sequence_length, got ", - relative_position_bias_dims[3]); - } + gsl::span attention_bias_dims; + if (attention_bias != nullptr) { + attention_bias_dims = attention_bias->Shape().GetDims(); + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckAttentionBias( + attention_bias_dims, batch_size, num_heads, sequence_length, sequence_length)); } + parameters.broadcast_attn_bias_dim_0 = attention_bias_dims.size() > 0 && attention_bias_dims[0] == 1; + parameters.broadcast_attn_bias_dim_1 = attention_bias_dims.size() > 1 && attention_bias_dims[1] == 1; parameters.batch_size = static_cast(batch_size); parameters.sequence_length = static_cast(sequence_length); @@ -197,8 +169,6 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, parameters.num_heads = num_heads; parameters.scale = this->GetScale(); parameters.token_count = static_cast(token_count); - parameters.has_relative_position_bias = (nullptr != relative_position_bias); - parameters.broadcast_res_pos_bias = broadcast_res_pos_bias; return Status::OK(); } @@ -211,7 +181,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co const Tensor* bias = context->Input(3); const Tensor* token_offset = context->Input(4); const Tensor* cumulative_sequence_length = context->Input(5); - const Tensor* relative_position_bias = context->Input(6); + const Tensor* attention_bias = context->Input(6); PackedAttentionParameters parameters; parameters.use_tf32 = this->UseTF32(); @@ -221,7 +191,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co bias, token_offset->Shape(), cumulative_sequence_length->Shape(), - relative_position_bias, + attention_bias, parameters)); TensorShapeVector output_shape{parameters.token_count, parameters.v_hidden_size}; @@ -232,7 +202,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co bool use_flash_attention = false; #if USE_FLASH_ATTENTION if (!disable_flash_attention_) { - use_flash_attention = !parameters.has_relative_position_bias && + use_flash_attention = nullptr == 
attention_bias && parameters.head_size == parameters.v_head_size && onnxruntime::flash::is_supported(device_prop, parameters.head_size, @@ -247,16 +217,17 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co } #endif - MHARunner* fused_runner = use_flash_attention ? nullptr : this->GetFusedRunner(device_prop, parameters); + MHARunner* fused_runner = use_flash_attention + ? nullptr + : this->GetFusedRunner(device_prop, attention_bias != nullptr, parameters); bool use_memory_efficient_attention = false; #if USE_MEMORY_EFFICIENT_ATTENTION if (!use_flash_attention && nullptr == fused_runner && !disable_memory_efficient_attention_) { int sm = device_prop.major * 10 + device_prop.minor; - bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0; use_memory_efficient_attention = - is_good_for_rpb && + (nullptr == attention_bias || parameters.sequence_length % (4 * sizeof(T)) == 0) && (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); } @@ -304,9 +275,9 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co data.key = (key == nullptr) ? nullptr : reinterpret_cast(key->Data()); data.value = (value == nullptr) ? nullptr : reinterpret_cast(value->Data()); data.bias = (bias == nullptr) ? nullptr : reinterpret_cast(bias->Data()); - data.relative_position_bias = (nullptr == relative_position_bias) - ? nullptr - : reinterpret_cast(relative_position_bias->Data()); + data.attention_bias = (nullptr == attention_bias) + ? nullptr + : reinterpret_cast(attention_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.token_offset = token_offset->Data(); data.cumulative_sequence_length = cumulative_sequence_length->Data(); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h index 9b52a70fc6181..3e59ce3dd229e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h @@ -23,7 +23,7 @@ class PackedMultiHeadAttention final : public TrtFusedAttention { const Tensor* bias, const TensorShape& token_offset_shape, const TensorShape& cu_seq_len_shape, - const Tensor* relative_position_bias, + const Tensor* attention_bias, PackedAttentionParameters& parameters) const; int GetNumHeads() const { return num_heads_; } float GetScale() const { return scale_; } diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index e5a4c54f48903..c00eefc8e49de 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -701,8 +701,11 @@ Status FusedAttentionCutlass( p.query = data.no_qkv_workspace ? data.query : data.workspace; p.key = data.no_qkv_workspace ? data.key : (data.workspace + elements_qk); p.value = data.no_qkv_workspace ? 
data.value : (data.workspace + elements_qk + elements_qk); - p.attn_bias = data.relative_position_bias; - p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; + + p.attn_bias = data.attention_bias; + p.broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0; + p.broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1; + p.output = data.output; p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) @@ -791,11 +794,13 @@ Status UnfusedAttention( // Apply softmax and store result R to attention_score: BxNxSxS ORT_RETURN_IF_ERROR(ComputeSoftmaxWithCumSeqLength( scaled_qk, - data.relative_position_bias, - parameters.broadcast_res_pos_bias, + data.attention_bias, + parameters.broadcast_attn_bias_dim_0, + parameters.broadcast_attn_bias_dim_1, data.cumulative_sequence_length, batch_size, sequence_length, + sequence_length, // total sequence length num_heads, attention_score, stream)); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h index eeca72f16e64e..9d0ff77e5fcaa 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h @@ -17,7 +17,8 @@ struct PackedMultiHeadAttentionData { const T* key; const T* value; const T* bias; - const T* relative_position_bias; + const T* attention_bias; + const int32_t* token_offset; const int32_t* cumulative_sequence_length; diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index b62e566d43f89..3a5fc401c53af 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -52,7 +52,7 @@ Status QAttention::CheckInputs(const Tensor* input, auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past_tensor, - nullptr, // relative_position_bias + nullptr, // attention_bias parameters, device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 12835978536e1..3e93a527877c5 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -199,7 +199,7 @@ Status QOrderedAttention::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), merged_weights_shape, merged_bias_shape, mask_index, nullptr, // past - nullptr, // relative_position_bias + nullptr, // attention_bias nullptr, // parameters device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h index b4b501856a52e..62c1679743429 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h @@ -17,4 +17,4 @@ DefineQOrderedAttentionInput(Input, input, 0), DefineQOrderedAttentionInput(Scale_Values_Gemm, scale_values_gemm, 16), DefineQOrderedAttentionInput(Mask_Index, 
mask_index, 17), DefineQOrderedAttentionInput(Past, past, 18), - DefineQOrderedAttentionInput(relative_position_bias, relative_position_bias, 19) + DefineQOrderedAttentionInput(attention_bias, attention_bias, 19) diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index 6d52ff7282799..5c39cf56dfd92 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -335,6 +335,29 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } } +void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} +void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + #else CudaTensorConsoleDumper::CudaTensorConsoleDumper() { } @@ -410,6 +433,25 @@ void CudaTensorConsoleDumper::Print(const char*, int, bool) const { void CudaTensorConsoleDumper::Print(const char*, const std::string&, bool) const { } + +void CudaTensorConsoleDumper::Print(const char*, const int32_t*, gsl::span&) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const int64_t*, gsl::span&) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const float*, gsl::span&) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const half*, gsl::span&) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, gsl::span&) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, gsl::span&) const { +} + #endif } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 4f41161cd4a31..631421b1623be 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -21,26 +21,32 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override; void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; + void Print(const char* name, const int32_t* tensor, gsl::span& dims) const override; void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; + void Print(const char* name, const int64_t* tensor, gsl::span& dims) const override; void Print(const char* name, const float* tensor, int dim0, int 
dim1) const override; void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override; void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const override; - - void Print(const char* name, const half* tensor, int dim0, int dim1) const; - void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const; - void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const; + void Print(const char* name, const float* tensor, gsl::span& dims) const override; void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const override; + void Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const override; + + void Print(const char* name, const half* tensor, int dim0, int dim1) const; + void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const; + void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const; + void Print(const char* name, const half* tensor, gsl::span& dims) const; void Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const; void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const; void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + void Print(const char* name, const BFloat16* tensor, gsl::span& dims) const; void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cu b/onnxruntime/contrib_ops/rocm/bert/attention.cu index 96cc17734874c..473ab8dd3ce4d 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cu @@ -53,7 +53,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); - const Tensor* relative_position_bias = context->Input(5); + const Tensor* attention_bias = context->Input(5); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); auto& device_prop = GetDeviceProp(); @@ -63,7 +63,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bias->Shape(), mask_index, past, - relative_position_bias, + attention_bias, &attn, device_prop.maxThreadsPerBlock, past_seq_len)); @@ -190,8 +190,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { params.v_buffer = v_buffer; params.out_buffer = reinterpret_cast(output->MutableDataRaw()); - if (relative_position_bias != nullptr) { - params.bias_buffer = reinterpret_cast(relative_position_bias->DataRaw()); + if (attention_bias != nullptr) { + params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); } if (mask_index != nullptr) { diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 54dda4bfa6d2c..e013f35e150c4 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -398,7 +398,8 @@ 
struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { const T* v_buffer; T* out_buffer; - // optional, bias [B,N,S,T] + // optional, attention bias [B,N,S,T] + // TODO: support shape [B,1,S,T], [1, N, S, T], [1, 1, S, T] with broadcast. const T* bias_buffer{nullptr}; // optional, mask value diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index 5997daaca6e8a..b07f9214e340e 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -87,7 +87,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias{}; const Tensor* key_padding_mask{}; - const Tensor* relative_position_bias{}; + const Tensor* attention_bias{}; const Tensor* past_key{}; const Tensor* past_value{}; const Tensor* past_seq_len{}; @@ -95,12 +95,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { if (attn_type_ == kMultiHeadAttention) { bias = context->Input(3); key_padding_mask = context->Input(4); - relative_position_bias = context->Input(5); + attention_bias = context->Input(5); past_key = context->Input(6); past_value = context->Input(7); } else if (attn_type_ == kDecoderMaskedMultiHeadAttention) { key_padding_mask = context->Input(3); - relative_position_bias = context->Input(4); + attention_bias = context->Input(4); past_key = context->Input(5); past_value = context->Input(6); past_seq_len = context->Input(kPastSequenceLengthInputIndex); @@ -120,7 +120,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR( multihead_attention_helper::CheckInputs( query, key, value, bias, - key_padding_mask, relative_position_bias, + key_padding_mask, attention_bias, past_key, past_value, past_seq_len, &attn, num_heads_, mask_filter_value_, scale_, false, /*is_unidirectional_*/ @@ -263,8 +263,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { params.mask_index_dims = key_padding_mask->Shape().AsShapeVector(); } - if (relative_position_bias != nullptr) { - params.bias_buffer = reinterpret_cast(relative_position_bias->DataRaw()); + if (attention_bias != nullptr) { + params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); } params.workspace_buffer = reinterpret_cast(workspace.get()); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 7272a949f7218..334090e8f305f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -421,8 +421,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T", OpSchema::Optional) .Input(5, - "relative_position_bias", - "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)", + "attention_bias", + "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) .Input(6, @@ -482,7 +482,7 @@ The operator only supports BERT like model with padding on right now. 
// Input 'bias': (hidden_size + hidden_size + v_hidden_size) // Input 'token_offset': (batch_size, sequence_length) // Input 'cumulative_sequence_length': (batch_size + 1) -// Input 'relative_position_bias': (batch_size, num_heads, sequence_length, sequence_length) +// Input 'attention_bias': (batch_size or 1, num_heads or 1, sequence_length, sequence_length) // Output 'output': (token_count, v_hidden_size) void PackedAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference @@ -560,9 +560,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.", "M") .Input(5, - "relative_position_bias", - "A tensor with shape (batch_size, num_heads, sequence_length, sequence_length)" - "or (1, num_heads, sequence_length, sequence_length)." + "attention_bias", + "A tensor with shape (batch_size or 1, num_heads or 1, sequence_length, sequence_length)." "It specifies the additional bias to QxK'", "T", OpSchema::Optional) @@ -616,7 +615,7 @@ The operator only supports BERT like model with padding on right now. // Input 'bias': (hidden_size + hidden_size + v_hidden_size) // Input 'token_offset': (batch_size, sequence_length) // Input 'cumulative_sequence_length': (batch_size + 1) -// Input 'relative_position_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None +// Input 'attention_bias': (batch_size or 1, num_heads or 1, sequence_length, sequence_length) or None // Output 'output': (token_count, v_hidden_size) void PackedMultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference @@ -694,9 +693,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.", "M") .Input(6, - "relative_position_bias", - "It specifies the additional bias to QxK'. The shape is (batch_size, num_heads, sequence_length, sequence_length)" - " or (1, num_heads, sequence_length, sequence_length)", + "attention_bias", + "It specifies the additional bias to QxK'. 
" + "The shape is (batch_size or 1, num_heads or 1, sequence_length, sequence_length)", "T", OpSchema::Optional) .Output(0, @@ -778,8 +777,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.", "T") .Input(5, - "relative_position_bias", - "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)", + "attention_bias", + "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) .Input(6, @@ -871,7 +870,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "M", OpSchema::Optional) .Input(4, - "relative_position_bias", + "attention_bias", "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)", "T", OpSchema::Optional) @@ -1006,9 +1005,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "M", OpSchema::Optional) .Input(5, - "relative_position_bias", - "relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)" - " or (1, num_heads, sequence_length, total_sequence_length)", + "attention_bias", + "bias added to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) .Input(6, diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 762d892c45ce8..6f1f1c831d191 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -1146,7 +1146,7 @@ where value of each element is the end position, or valid length of actual seque left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past and present state are optional. Present state could appear in output even when past state is not in input. -Current version does not support past/present, relative_position_bias and qkv_hidden_sizes. +Current version does not support past/present, attention_bias and qkv_hidden_sizes. TODO: Support them if needed in the future. 
)DOC"; @@ -1208,8 +1208,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(18, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "Q", OpSchema::Optional) - .Input(19, "relative_position_bias", - "additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).", "S", + .Input(19, "attention_bias", + "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length).", "S", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "Q") .TypeConstraint("Q", {"tensor(int8)"}, "Constrain input and output types to int8 tensors.") diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index fd7b19dea724d..ce9780031a250 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -225,12 +225,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) override { return p->contrib::AttentionBase::CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, - relative_position_bias, + attention_bias, parameters, max_threads_per_block, past_seq_len); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 840d6f8e3e7aa..eb1569c3e499e 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -163,7 +163,7 @@ struct ProviderHostCPU { const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) = 0; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 14a7383e67897..788293464d3b3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -2371,7 +2371,7 @@ constexpr DML_SCHEMA_FIELD DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA_FIELDS[18] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "StackedQueryKeyValueTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MaskTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "RelativePositionBiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AttentionBiasTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastKeyTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastValueTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, 
"OutputTensor", false }, @@ -2502,7 +2502,7 @@ constexpr DML_SCHEMA_FIELD DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA_FIELDS[20] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "StackedQueryKeyValueTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MaskTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "RelativePositionBiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AttentionBiasTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastKeyTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastValueTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastSequenceLengthsTensor", true }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index 73c2d57e984af..9b4a34622d460 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -47,7 +47,7 @@ class DmlOperatorAttention : public DmlOperator mhaStackedQueryKeyValueIndex, mhaBiasIndex, mhaMaskIndex, - mhaRelativePositionBiasIndex, + mhaAttentionBiasIndex, mhaPastKeyIndex, mhaPastValueIndex, mhaInputCount, @@ -60,7 +60,7 @@ class DmlOperatorAttention : public DmlOperator biasIndex, maskIndex, pastIndex, - relativePositionBiasIndex, + attentionBiasIndex, pastSequenceLengthIndex, inputCount, }; @@ -74,16 +74,16 @@ class DmlOperatorAttention : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 2); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1); - const uint32_t dmlInputIndex = inputIndex; - const uint32_t dmlWeightsIndex = weightsIndex; - const uint32_t dmlBiasIndex = biasIndex; - const uint32_t dmlMaskIndex = maskIndex; - const uint32_t dmlRelativePositionBiasIndex = relativePositionBiasIndex; + constexpr uint32_t dmlInputIndex = inputIndex; + constexpr uint32_t dmlWeightsIndex = weightsIndex; + constexpr uint32_t dmlBiasIndex = biasIndex; + constexpr uint32_t dmlMaskIndex = maskIndex; + constexpr uint32_t dmlAttentionBiasIndex = attentionBiasIndex; const bool hasBias = kernelCreationContext.IsInputValid(biasIndex); const bool hasMask = kernelCreationContext.IsInputValid(maskIndex); const bool hasUnpaddedBounds = hasMask && kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1; - const bool hasRelativePositionBias = kernelCreationContext.IsInputValid(relativePositionBiasIndex); + const bool hasAttentionBias = kernelCreationContext.IsInputValid(attentionBiasIndex); DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 1); @@ -188,13 +188,14 @@ class DmlOperatorAttention : public DmlOperator } } - if (hasRelativePositionBias) + if (hasAttentionBias) { - auto relativePositionBiasTensorShape = m_inputTensorDescs[dmlRelativePositionBiasIndex].GetSizes(); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasTensorShape.size() == 4); - 
ML_CHECK_VALID_ARGUMENT(relativePositionBiasTensorShape[0] == inputTensorShape[0]); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasTensorShape[1] == numHeads); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasTensorShape[2] == inputTensorShape[1]); + auto attentionBiasTensorShape = m_inputTensorDescs[dmlAttentionBiasIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(attentionBiasTensorShape.size() == 4); + // TODO: support broadcast of attention bias on the first and second dimensions. + ML_CHECK_VALID_ARGUMENT(attentionBiasTensorShape[0] == inputTensorShape[0]); + ML_CHECK_VALID_ARGUMENT(attentionBiasTensorShape[1] == numHeads); + ML_CHECK_VALID_ARGUMENT(attentionBiasTensorShape[2] == inputTensorShape[1]); } TensorDesc firstGemmOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape); @@ -346,7 +347,7 @@ class DmlOperatorAttention : public DmlOperator mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[dmlMaskIndex] : nullptr; } - mhaOperatorDesc.RelativePositionBiasTensor = hasRelativePositionBias ? &inputDescs[dmlRelativePositionBiasIndex] : nullptr; + mhaOperatorDesc.RelativePositionBiasTensor = hasAttentionBias ? &inputDescs[dmlAttentionBiasIndex] : nullptr; mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); @@ -452,13 +453,13 @@ class DmlOperatorAttention : public DmlOperator } } - if (hasRelativePositionBias) + if (hasAttentionBias) { - DML_INPUT_GRAPH_EDGE_DESC relativePositionBiasToMhaEdge = {}; - relativePositionBiasToMhaEdge.GraphInputIndex = dmlRelativePositionBiasIndex; - relativePositionBiasToMhaEdge.ToNodeIndex = mhaNodeIndex; - relativePositionBiasToMhaEdge.ToNodeInputIndex = mhaRelativePositionBiasIndex; - inputEdges.push_back(relativePositionBiasToMhaEdge); + DML_INPUT_GRAPH_EDGE_DESC attentionBiasToMhaEdge = {}; + attentionBiasToMhaEdge.GraphInputIndex = dmlAttentionBiasIndex; + attentionBiasToMhaEdge.ToNodeIndex = mhaNodeIndex; + attentionBiasToMhaEdge.ToNodeInputIndex = mhaAttentionBiasIndex; + inputEdges.push_back(attentionBiasToMhaEdge); } if (hasSlicedValue) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp index cde08864ca54e..d781aea8515a6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp @@ -18,7 +18,7 @@ class DmlOperatorMultiHeadAttention : public DmlOperator valueIndex, biasIndex, maskIndex, - relativePositionBiasIndex, + attentionBiasIndex, pastKeyIndex, pastValueIndex, inputCount, @@ -34,7 +34,7 @@ class DmlOperatorMultiHeadAttention : public DmlOperator dmlStackedQueryKeyValueIndex, dmlBiasIndex, dmlMaskIndex, - dmlRelativePositionBiasIndex, + dmlAttentionBiasIndex, dmlPastKeyIndex, dmlPastValueIndex, dmlInputCount, @@ -55,7 +55,7 @@ class DmlOperatorMultiHeadAttention : public DmlOperator const bool hasValue = kernelCreationContext.IsInputValid(valueIndex) && !keyValueIsPast; const bool hasBias = kernelCreationContext.IsInputValid(biasIndex); const bool hasMask = kernelCreationContext.IsInputValid(maskIndex); - const bool hasRelativePositionBias = 
kernelCreationContext.IsInputValid(relativePositionBiasIndex); + const bool hasAttentionBias = kernelCreationContext.IsInputValid(attentionBiasIndex); const bool hasPastKey = keyValueIsPast || (kernelCreationContext.IsInputValid(pastKeyIndex) && kernelCreationContext.GetInputTensorShape(pastKeyIndex)[2] != 0); const bool hasPastValue = keyValueIsPast || (kernelCreationContext.IsInputValid(pastValueIndex) && kernelCreationContext.GetInputTensorShape(pastValueIndex)[2] != 0); const bool hasPresentKeyOutput = kernelCreationContext.IsOutputValid(outputPresentKeyIndex); @@ -73,7 +73,7 @@ class DmlOperatorMultiHeadAttention : public DmlOperator stackedQkv ? std::optional(queryIndex) : std::nullopt, biasIndex, hasMask ? std::optional(maskIndex) : std::nullopt, - relativePositionBiasIndex, + attentionBiasIndex, hasPastKey ? std::optional(keyValueIsPast ? keyIndex : pastKeyIndex) : std::nullopt, hasPastValue ? std::optional(keyValueIsPast ? valueIndex : pastValueIndex) : std::nullopt, }; @@ -243,15 +243,16 @@ class DmlOperatorMultiHeadAttention : public DmlOperator } } - if (hasRelativePositionBias) + if (hasAttentionBias) { - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlRelativePositionBiasIndex].GetDimensionCount() == 4); - - auto relativePositionBiasSizes = m_inputTensorDescs[dmlRelativePositionBiasIndex].GetSizes(); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasSizes[0] == batchSize); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasSizes[1] == numHeads); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasSizes[2] == sequenceLength); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasSizes[3] == totalSequenceLength); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlAttentionBiasIndex].GetDimensionCount() == 4); + + auto attentionBiasSizes = m_inputTensorDescs[dmlAttentionBiasIndex].GetSizes(); + // TODO: support broadcast of attention bias on the first and second dimensions. + ML_CHECK_VALID_ARGUMENT(attentionBiasSizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(attentionBiasSizes[1] == numHeads); + ML_CHECK_VALID_ARGUMENT(attentionBiasSizes[2] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(attentionBiasSizes[3] == totalSequenceLength); } if (hasPastKey) @@ -283,7 +284,7 @@ class DmlOperatorMultiHeadAttention : public DmlOperator mhaDesc.StackedQueryKeyValueTensor = stackedQkv ? &inputDescs[dmlStackedQueryKeyValueIndex] : nullptr; mhaDesc.BiasTensor = hasBias ? &inputDescs[dmlBiasIndex] : nullptr; mhaDesc.MaskTensor = hasMask ? &inputDescs[dmlMaskIndex] : nullptr; - mhaDesc.RelativePositionBiasTensor = hasRelativePositionBias ? &inputDescs[dmlRelativePositionBiasIndex] : nullptr; + mhaDesc.RelativePositionBiasTensor = hasAttentionBias ? &inputDescs[dmlAttentionBiasIndex] : nullptr; mhaDesc.PastKeyTensor = hasPastKey ? &inputDescs[dmlPastKeyIndex] : nullptr; mhaDesc.PastValueTensor = hasPastValue ? 
&inputDescs[dmlPastValueIndex] : nullptr; mhaDesc.OutputTensor = &outputDescs[outputIndex]; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index f9519b26bb4e3..d6fd83fd583de 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -89,7 +89,7 @@ class DmlOperatorQAttention : public DmlOperator mhaStackedQueryKeyValueIndex, mhaBiasIndex, mhaMaskIndex, - mhaRelativePositionBiasIndex, + mhaAttentionBiasIndex, mhaPastKeyIndex, mhaPastValueIndex, mhaInputCount, @@ -418,7 +418,7 @@ class DmlOperatorQAttention : public DmlOperator mhaOperatorDesc.RelativePositionBiasTensor = nullptr; mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); - // Set MaskFilterValue to lowest float for Causal Mask + // Set MaskFilterValue to lowest float for Causal Mask mhaOperatorDesc.MaskFilterValue = unidirectional ? std::numeric_limits::lowest() : kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); mhaOperatorDesc.HeadCount = numHeads; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 7fb9fd3c8cfd5..252ce9298bda8 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -608,12 +608,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) const { return g_host_cpu.AttentionBase__CheckInputs(this, input_shape, weights_shape, bias_shape, - mask_index, past, relative_position_bias, parameters, + mask_index, past, attention_bias, parameters, max_threads_per_block, past_seq_len); } Tensor* AttentionBase::GetPresent(OpKernelContext* context, const Tensor* past, int batch_size, int head_size, diff --git a/onnxruntime/python/tools/transformers/constants.py b/onnxruntime/python/tools/transformers/constants.py index fc8f2cc2f58d3..0da22dc149968 100644 --- a/onnxruntime/python/tools/transformers/constants.py +++ b/onnxruntime/python/tools/transformers/constants.py @@ -21,7 +21,7 @@ class AttentionInputIDs: BIAS = 2 MASK_INDEX = 3 PAST = 4 - RELATIVE_POSITION_BIAS = 5 + ATTENTION_BIAS = 5 PAST_SEQUENCE_LENGTH = 6 @@ -36,7 +36,7 @@ class MultiHeadAttentionInputIDs: VALUE = 2 BIAS = 3 KEY_PADDING_MASK = 4 - RELATIVE_POSITION_BIAS = 5 + ATTENTION_BIAS = 5 PAST_KEY = 6 PAST_VALUE = 7 diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 894e11275056e..5a26fedb5287d 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1624,7 +1624,7 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: ModelP ] nis.extend([node.input[4] if len(node.input) > 4 else ""]) # 2D mask - nis.extend([node.input[5] if len(node.input) > 5 else ""]) # 
relative_position_bias + nis.extend([node.input[5] if len(node.input) > 5 else ""]) # attention_bias nis.extend([node.input[6] if len(node.input) > 6 else ""]) # past_key nis.extend([node.input[7] if len(node.input) > 7 else ""]) # past_value nis.extend(["past_sequence_length"]) # past_sequence_length diff --git a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py index 4da97f0de7bed..e854312cae826 100644 --- a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py +++ b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py @@ -184,9 +184,9 @@ def _are_attentions_supported(self) -> bool: def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None: for attention in self.attention_nodes: - relative_pos_bias = ( - attention.input[AttentionInputIDs.RELATIVE_POSITION_BIAS] - if len(attention.input) > AttentionInputIDs.RELATIVE_POSITION_BIAS + attention_bias = ( + attention.input[AttentionInputIDs.ATTENTION_BIAS] + if len(attention.input) > AttentionInputIDs.ATTENTION_BIAS else "" ) packed_attention = helper.make_node( @@ -197,7 +197,7 @@ def _replace_attention_with_packing_attention(self, token_offset: str, cumulativ attention.input[AttentionInputIDs.BIAS], token_offset, cumulative_sequence_length, - relative_pos_bias, + attention_bias, ], outputs=[attention.output[AttentionOutputIDs.OUTPUT]], name=self.model.create_node_name(Operators.PACKEDATTENTION), @@ -261,9 +261,9 @@ def _are_attentions_supported(self) -> bool: def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None: gated_relative_pos_bias_count = 0 for mha in self.attention_nodes: - relative_pos_bias = ( - mha.input[MultiHeadAttentionInputIDs.RELATIVE_POSITION_BIAS] - if len(mha.input) > MultiHeadAttentionInputIDs.RELATIVE_POSITION_BIAS + attention_bias = ( + mha.input[MultiHeadAttentionInputIDs.ATTENTION_BIAS] + if len(mha.input) > MultiHeadAttentionInputIDs.ATTENTION_BIAS else "" ) packed_mha = helper.make_node( @@ -275,7 +275,7 @@ def _replace_attention_with_packing_attention(self, token_offset: str, cumulativ mha.input[MultiHeadAttentionInputIDs.BIAS], token_offset, cumulative_sequence_length, - relative_pos_bias, + attention_bias, ], outputs=[mha.output[MultiHeadAttentionOutputIDs.OUTPUT]], name=self.model.create_node_name(Operators.PACKED_MULTI_HEAD_ATTENTION), @@ -293,8 +293,8 @@ def _replace_attention_with_packing_attention(self, token_offset: str, cumulativ self.node_name_to_graph_name[packed_mha.name] = self.this_graph_name # Append token_offset input to GatedRelativePositionBias - if relative_pos_bias: - rel_pos_bias_node = self.model.get_parent(mha, MultiHeadAttentionInputIDs.RELATIVE_POSITION_BIAS) + if attention_bias: + rel_pos_bias_node = self.model.get_parent(mha, MultiHeadAttentionInputIDs.ATTENTION_BIAS) if ( rel_pos_bias_node and rel_pos_bias_node.op_type == "GatedRelativePositionBias" diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index 7384cace21a67..efdcbcfb3dcdc 100644 --- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -68,7 +68,7 @@ def create_mha_node( v_matmul.output[0], "", # bias attn_mask, # key_padding_mask - add_qk, # relative_position_bias + add_qk, # attention_bias past_k, past_v, ] diff --git 
a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index a8e2fccdd0462..61e5fa05c66c1 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -60,7 +60,7 @@ static void RunAttentionTest( const bool disable_rocm = false, const bool disable_dml = false, std::vector qkv_sizes = {}, - const std::vector& relative_position_bias_data = {}, + const std::vector& attention_bias_data = {}, int kv_sequence_length = 0, bool past_present_share_buffer = false, bool use_scale = false, @@ -205,12 +205,12 @@ static void RunAttentionTest( } } - std::vector relative_position_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; - if (relative_position_bias_data.size() > 0) { + std::vector attention_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; + if (attention_bias_data.size() > 0) { if (use_float16) { - tester.AddInput("relative_position_bias", relative_position_bias_data_dims, ToFloat16(relative_position_bias_data)); + tester.AddInput("attention_bias", attention_bias_data_dims, ToFloat16(attention_bias_data)); } else { - tester.AddInput("relative_position_bias", relative_position_bias_data_dims, relative_position_bias_data); + tester.AddInput("attention_bias", attention_bias_data_dims, attention_bias_data); } } else { if (use_float16) { @@ -292,7 +292,7 @@ static void RunAttentionTest( const bool disable_rocm = false, const bool disable_dml = false, const std::vector qkv_sizes = {}, - const std::vector& relative_position_bias_data = {}, + const std::vector& attention_bias_data = {}, int kv_sequence_length = 0, bool past_present_share_buffer = false, bool use_scale = false, @@ -301,13 +301,13 @@ static void RunAttentionTest( batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, - disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, relative_position_bias_data, + disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias_data, kv_sequence_length, past_present_share_buffer, use_scale, do_neox_rotary); RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, - disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, relative_position_bias_data, + disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias_data, kv_sequence_length, past_present_share_buffer, use_scale, do_neox_rotary); } @@ -419,7 +419,7 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr2) { 0, false, false, disable_rocm, false, qkv_sizes); } -TEST(AttentionTest, AttentionBatch1RelativePositionBias) { +TEST(AttentionTest, AttentionBatch1AttentionBias) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -443,7 +443,7 @@ TEST(AttentionTest, AttentionBatch1RelativePositionBias) { std::vector mask_index_data = {2L}; - std::vector relative_position_bias = { + std::vector attention_bias = { 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; std::vector output_data = { @@ -457,10 +457,10 @@ TEST(AttentionTest, AttentionBatch1RelativePositionBias) { RunAttentionTest(input_data, weight_data, bias_data, 
mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, relative_position_bias); + 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias); } -TEST(AttentionTest, AttentionBatch2RelativePositionBias) { +TEST(AttentionTest, AttentionBatch2AttentionBias) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -486,7 +486,7 @@ TEST(AttentionTest, AttentionBatch2RelativePositionBias) { std::vector mask_index_data = {2L, 2L}; - std::vector relative_position_bias = { + std::vector attention_bias = { 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f, 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; @@ -503,7 +503,7 @@ TEST(AttentionTest, AttentionBatch2RelativePositionBias) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, relative_position_bias); + 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias); } TEST(AttentionTest, AttentionBatch1_Float16) { @@ -1679,7 +1679,7 @@ TEST(AttentionTest, AttentionWithNormFactor) { use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/, false /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, false /*disable_dml*/, {} /*qkv_sizes*/, - {} /*relative_position_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, + {} /*attention_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, true /*use_scale*/); } @@ -1713,7 +1713,7 @@ TEST(AttentionTest, AttentionWithNeoXRotaryEmbedding) { use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/, true /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, disable_dml, {} /*qkv_sizes*/, - {} /*relative_position_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, + {} /*attention_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, true /*use_scale*/, true /*use_neox_rotary_embedding*/); } diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc index 79e1a8f0fdc19..1ea67314f62d6 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc @@ -297,7 +297,7 @@ void GetCrossAttentionDataWithPast(AttentionTestData& data) { data.fp16_output_data = data.fp32_output_data; } -void GetSelfAttentionData_WithPast_WithRelPosBias_ForT5(AttentionTestData& data) { +void GetSelfAttentionData_WithPast_WithAttnBias_ForT5(AttentionTestData& data) { data.hidden_size = 8; data.v_hidden_size = 8; data.num_heads = 2; @@ -313,21 +313,21 @@ void GetSelfAttentionData_WithPast_WithRelPosBias_ForT5(AttentionTestData& data) AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention, }; - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.query_data", data.query_data); - 
LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.key_data", data.key_data); - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.value_data", data.value_data); - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.rel_pos_bias_data", data.rel_pos_bias_data); - data.broadcast_rel_pos_bias = false; - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.past_key_data", data.past_key_data); - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.past_value_data", data.past_value_data); - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.fp32_output_data", data.fp32_output_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.query_data", data.query_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.key_data", data.key_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.value_data", data.value_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.attention_bias_data", data.attention_bias_data); + data.broadcast_attention_bias = false; + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.past_key_data", data.past_key_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.past_value_data", data.past_value_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.present_key_data", data.present_key_data); - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.present_value_data", data.present_value_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.present_key_data", data.present_key_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.present_value_data", data.present_value_data); data.is_static_kv = false; } -void GetAttentionDataCutlassRelPosBias(AttentionTestData& data) { +void GetAttentionDataCutlassAttnBias(AttentionTestData& data) { data.hidden_size = 8; data.v_hidden_size = 8; data.num_heads = 2; @@ -343,13 +343,13 @@ void GetAttentionDataCutlassRelPosBias(AttentionTestData& data) { AttentionKernelType::AttentionKernel_TrtFusedCrossAttention, AttentionKernelType::AttentionKernel_TrtFusedAttention}; - LoadTensor("AttentionDataCutlassRelPosBias.query_data", data.query_data); - LoadTensor("AttentionDataCutlassRelPosBias.key_data", data.key_data); - LoadTensor("AttentionDataCutlassRelPosBias.value_data", data.value_data); - LoadTensor("AttentionDataCutlassRelPosBias.bias_data", data.bias_data); - LoadTensor("AttentionDataCutlassRelPosBias.rel_pos_bias_data", data.rel_pos_bias_data); - data.broadcast_rel_pos_bias = false; - LoadTensor("AttentionDataCutlassRelPosBias.fp16_output_data", data.fp16_output_data); + LoadTensor("AttentionDataCutlassAttnBias.query_data", data.query_data); + LoadTensor("AttentionDataCutlassAttnBias.key_data", data.key_data); + LoadTensor("AttentionDataCutlassAttnBias.value_data", data.value_data); + LoadTensor("AttentionDataCutlassAttnBias.bias_data", data.bias_data); + LoadTensor("AttentionDataCutlassAttnBias.attention_bias_data", data.attention_bias_data); + data.broadcast_attention_bias = false; + LoadTensor("AttentionDataCutlassAttnBias.fp16_output_data", data.fp16_output_data); data.fp32_output_data = {}; data.is_static_kv = false; } @@ -417,7 +417,7 @@ void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestDat data.is_static_kv = true; } -void GetSelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias(AttentionTestData& 
data) { +void GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(AttentionTestData& data) { data.hidden_size = 8; data.v_hidden_size = 8; data.num_heads = 2; @@ -433,19 +433,19 @@ void GetSelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias(AttentionTestDa AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention, }; - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.query_data", data.query_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.key_data", data.key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.value_data", data.value_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.bias_data", data.bias_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.past_key_data", data.past_key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.past_value_data", data.past_value_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.fp32_output_data", data.fp32_output_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.present_key_data", data.present_key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.present_value_data", data.present_value_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.query_data", data.query_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.key_data", data.key_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.value_data", data.value_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.bias_data", data.bias_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.past_key_data", data.past_key_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.past_value_data", data.past_value_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.fp32_output_data", data.fp32_output_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.present_key_data", data.present_key_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.present_value_data", data.present_value_data); data.is_static_kv = false; } -void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias(AttentionTestData& data) { +void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(AttentionTestData& data) { data.hidden_size = 16; data.v_hidden_size = 16; data.num_heads = 2; @@ -461,37 +461,37 @@ void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias(Atten AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention, }; - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.query_data", data.query_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.key_data", data.key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.value_data", data.value_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.bias_data", data.bias_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.past_key_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.query_data", data.query_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.key_data", data.key_data); + 
LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.value_data", data.value_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.bias_data", data.bias_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.past_key_data", data.past_key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.past_value_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.past_value_data", data.past_value_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.fp32_output_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.present_key_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.present_key_data", data.present_key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.present_value_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.present_value_data", data.present_value_data); data.is_static_kv = false; } -void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias(AttentionTestData& data) { - GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias(data); +void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(AttentionTestData& data) { + GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(data); data.bias_data.clear(); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.past_key_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.past_key_data", data.past_key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.past_value_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.past_value_data", data.past_value_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.fp32_output_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.present_key_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.present_key_data", data.present_key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.present_value_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.present_value_data", data.present_value_data); data.is_static_kv = false; } @@ -535,7 +535,7 @@ void GetAttentionDataWithNeoXRotaryEmbedding(std::vector& input, LoadTensor("AttentionDataWithNeoXRotaryEmbedding.output", output); } -void GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(PackedAttentionTestData& data) { +void GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias(PackedAttentionTestData& data) { data.hidden_size = 32; data.v_hidden_size = 32; data.num_heads = 1; @@ -550,19 +550,19 @@ void GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(PackedAttent data.skip_kernel_types = { 
AttentionKernelType::AttentionKernel_TrtFusedCrossAttention}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.query_data", data.query_data); - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.key_data", data.key_data); - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.value_data", data.value_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.query_data", data.query_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.key_data", data.key_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.value_data", data.value_data); data.bias_data = {}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.qkv_data", data.qkv_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.qkv_data", data.qkv_data); // Do not test fp32 data.fp32_output_data = {}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.fp16_output_data", data.fp16_output_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.fp16_output_data", data.fp16_output_data); } -void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias(PackedAttentionTestData& data) { +void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias(PackedAttentionTestData& data) { data.hidden_size = 16; data.v_hidden_size = 16; data.num_heads = 2; @@ -576,23 +576,23 @@ void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias(PackedAttention data.skip_kernel_types = { AttentionKernelType::AttentionKernel_TrtFusedCrossAttention}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.query_data", data.query_data); - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.key_data", data.key_data); - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.value_data", data.value_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.query_data", data.query_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.key_data", data.key_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.value_data", data.value_data); data.bias_data = {}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.qkv_data", data.qkv_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.qkv_data", data.qkv_data); // shape: batch_size, num_heads, sequence_length, sequence_length - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.rel_pos_bias_data", data.rel_pos_bias_data); - data.broadcast_rel_pos_bias = false; + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.attention_bias_data", data.attention_bias_data); + data.broadcast_attention_bias = false; // Do not test fp32 data.fp32_output_data = {}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.fp16_output_data", data.fp16_output_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.fp16_output_data", data.fp16_output_data); } -void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias(PackedAttentionTestData& data) { +void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias(PackedAttentionTestData& data) { data.hidden_size = 16; data.v_hidden_size = 16; data.num_heads = 2; @@ -606,21 +606,21 @@ void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias(Packed data.skip_kernel_types = { 
AttentionKernelType::AttentionKernel_TrtFusedCrossAttention}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.query_data", data.query_data); - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.key_data", data.key_data); - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.value_data", data.value_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.query_data", data.query_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.key_data", data.key_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.value_data", data.value_data); data.bias_data = {}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.qkv_data", data.qkv_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.qkv_data", data.qkv_data); // shape: 1, num_heads, sequence_length, sequence_length - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.rel_pos_bias_data", - data.rel_pos_bias_data); - data.broadcast_rel_pos_bias = true; + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.attention_bias_data", + data.attention_bias_data); + data.broadcast_attention_bias = true; // Do not test fp32 data.fp32_output_data = {}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.fp16_output_data", + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.fp16_output_data", data.fp16_output_data); } diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.h b/onnxruntime/test/contrib_ops/attention_op_test_helper.h index ee93cdca0cd82..b0dbe6e7b4ac7 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test_helper.h +++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.h @@ -27,8 +27,8 @@ struct BaseAttentionTestData { std::vector qkv_data; std::vector bias_data; - std::vector rel_pos_bias_data; - bool broadcast_rel_pos_bias; + std::vector attention_bias_data; + bool broadcast_attention_bias; std::vector past_key_data; std::vector past_value_data; @@ -76,29 +76,29 @@ void GetCrossAttentionData_HeadSize8(AttentionTestData& data); void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data); void GetCrossAttentionDataWithPast(AttentionTestData& data); -void GetSelfAttentionData_WithPast_WithRelPosBias_ForT5(AttentionTestData& data); +void GetSelfAttentionData_WithPast_WithAttnBias_ForT5(AttentionTestData& data); void GetCrossAttentionData_DiffSequenceLengths(AttentionTestData& data); void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data); void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestData& data); -void GetSelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias(AttentionTestData& data); -void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias(AttentionTestData& data); -void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias(AttentionTestData& data); +void GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(AttentionTestData& data); +void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(AttentionTestData& data); +void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(AttentionTestData& data); void GetCrossAttentionData_WithPastPassedInDirectly_NoMask(AttentionTestData& data); void GetCausal_EmptyPastState(std::vector& input, std::vector& 
output, std::vector& present); -void GetAttentionDataCutlassRelPosBias(AttentionTestData& data); +void GetAttentionDataCutlassAttnBias(AttentionTestData& data); void GetAttentionDataWithNeoXRotaryEmbedding(std::vector& input, std::vector& weights, std::vector& bias, std::vector& output); -void GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(PackedAttentionTestData& data); +void GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias(PackedAttentionTestData& data); -void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias(PackedAttentionTestData& data); +void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias(PackedAttentionTestData& data); -void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias(PackedAttentionTestData& data); +void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias(PackedAttentionTestData& data); bool SkipAttentionKernel(AttentionTestData& data, AttentionKernelType kernel_type); } // namespace test diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index f0255d7ece84e..3aaf710c33db4 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -31,7 +31,7 @@ static void RunMultiHeadAttentionTest( const std::vector& kv_data, // packed_kv: [batch_size, kv_sequence_length, num_heads, 2, head_size] const std::vector& qkv_data, // packed_qkv: [batch_size, sequence_length, num_heads, 3, head_size] const std::vector& bias_data, // bias: [hidden_size + hidden_size + v_hidden_size] or empty - const std::vector& rel_pos_bias_data, // relative_position_bias: [1, num_heads, sequence_length, total_sequence_length] + const std::vector& attention_bias_data, // attention_bias: [1, num_heads, sequence_length, total_sequence_length] const std::vector& past_key_data, // past_key: [batch_size, num_heads, kv_sequence_length, head_size] const std::vector& past_value_data, // past_value: [batch_size, num_heads, kv_sequence_length, head_size] const std::vector& present_key_data, // present_key: [batch_size, num_heads, total_sequence_length, head_size] @@ -80,7 +80,7 @@ static void RunMultiHeadAttentionTest( std::vector value_dims = {batch_size, is_static_kv ? kv_sequence_length : sequence_length, v_hidden_size}; std::vector bias_dims = {hidden_size + hidden_size + v_hidden_size}; // TODO(wy): Introduce past sequence length to avoid using kv_sequence_length. - std::vector rel_pos_bias_dims = + std::vector attention_bias_dims = {1, num_heads, sequence_length, past_key_data.size() ? 
sequence_length + kv_sequence_length : sequence_length}; std::vector past_key_dims = {batch_size, num_heads, kv_sequence_length, hidden_size / num_heads}; std::vector past_value_dims = past_key_dims; @@ -144,8 +144,8 @@ static void RunMultiHeadAttentionTest( tester.AddOptionalInputEdge(); } - if (rel_pos_bias_data.size()) { - tester.AddInput("relative_position_bias", rel_pos_bias_dims, ToFloat16(rel_pos_bias_data)); + if (attention_bias_data.size()) { + tester.AddInput("attention_bias", attention_bias_dims, ToFloat16(attention_bias_data)); } else { tester.AddOptionalInputEdge(); } @@ -208,8 +208,8 @@ static void RunMultiHeadAttentionTest( tester.AddOptionalInputEdge(); } - if (rel_pos_bias_data.size()) { - tester.AddInput("relative_position_bias", rel_pos_bias_dims, rel_pos_bias_data); + if (attention_bias_data.size()) { + tester.AddInput("attention_bias", attention_bias_dims, attention_bias_data); } else { tester.AddOptionalInputEdge(); } @@ -276,7 +276,7 @@ static void RunMultiHeadAttentionKernel( const std::vector& kv_data, // packed_kv: [batch_size, kv_sequence_length, num_heads, 2, head_size] const std::vector& qkv_data, // packed_qkv: [batch_size, sequence_length, num_heads, 3, head_size] const std::vector& bias_data, // bias: [hidden_size + hidden_size + v_hidden_size] - const std::vector& rel_pos_bias_data, // relative_position_bias: [1, num_heads, sequence_length, total_sequence_length] + const std::vector& attention_bias_data, // attention_bias: [1, num_heads, sequence_length, total_sequence_length] const std::vector& past_key_data, // past_key: [batch_size, num_heads, kv_sequence_length, head_size] const std::vector& past_value_data, // past_value: [batch_size, num_heads, kv_sequence_length, head_size] const std::vector& present_key_data, // present_key: [batch_size, num_heads, total_sequence_length, head_size] @@ -306,7 +306,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, kv_data, qkv_data, bias_data, rel_pos_bias_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); @@ -322,7 +322,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, kv_data, qkv_data, bias_data, rel_pos_bias_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); @@ -338,7 +338,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, kv_data, qkv_data, bias_data, 
rel_pos_bias_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); @@ -355,7 +355,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, kv_data, qkv_data, bias_data, rel_pos_bias_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); @@ -372,7 +372,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, kv_data, qkv_data, bias_data, rel_pos_bias_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); @@ -387,7 +387,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -400,7 +400,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -411,7 +411,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu kernel_type = AttentionKernelType::AttentionKernel_Default; RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - 
data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -423,7 +423,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -433,7 +433,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -444,7 +444,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -454,7 +454,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu kernel_type = AttentionKernelType::AttentionKernel_Default; RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -548,17 +548,17 @@ TEST(MultiHeadAttentionTest, CrossAttentionWithPast) { } #endif -TEST(MultiHeadAttentionTest, 
SelfAttention_WithPast_WithRelPosBias_ForT5) { +TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithAttnBias_ForT5) { ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; - GetSelfAttentionData_WithPast_WithRelPosBias_ForT5(data); + GetSelfAttentionData_WithPast_WithAttnBias_ForT5(data); RunMultiHeadAttentionTests(data, true); } -TEST(MultiHeadAttentionTest, AttentionCutlassRelPosBias) { +TEST(MultiHeadAttentionTest, AttentionCutlassAttnBias) { // ROCM_GTEST_SKIP("ROCm does not support cutlass"); AttentionTestData data; - GetAttentionDataCutlassRelPosBias(data); + GetAttentionDataCutlassAttnBias(data); RunMultiHeadAttentionTests(data); } @@ -575,16 +575,16 @@ TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) { RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); } -TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoRelPosBias) { +TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) { // Whisper decoder self attention with past_kv and present_kv AttentionTestData data; - GetSelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias(data); + GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(data); RunMultiHeadAttentionTests(data); - GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias(data); + GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(data); RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); - GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias(data); + GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(data); RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); } diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py b/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py index e0cfc9d0f8e25..bdb0ffc6c50db 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py @@ -502,7 +502,7 @@ def run_cross_diff_seqlen_headsize_8(): ) -def run_self_past_present_headsize_8_nomask_norelposbias(): +def run_self_past_present_headsize_8_nomask_no_attn_bias(): hidden_dim = 16 q_head_size = 8 v_head_size = 8 @@ -554,8 +554,8 @@ def create_test_data(): print("SelfAttention_Batch2_HeadSize32_PackedQKV") run_self_batch2_headsize_32_packed_qkv() - print("SelfAttention_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias") - run_self_past_present_headsize_8_nomask_norelposbias() + print("SelfAttention_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias") + run_self_past_present_headsize_8_nomask_no_attn_bias() print("CrossAttention_DiffSequenceLengths_HeadSize8") run_cross_diff_seqlen_headsize_8() diff --git a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc index 09baf8def05f6..96c629b4616d5 100644 --- a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc @@ -30,7 +30,7 @@ static void RunPackedAttentionTest( bool use_float16, bool use_scale, std::vector qkv_sizes, - const std::vector& relative_position_bias_data) { + const std::vector& attention_bias_data) { int min_cuda_architecture = use_float16 ? 
530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); @@ -62,7 +62,7 @@ static void RunPackedAttentionTest( std::vector bias_dims = {qkv_hidden_size_sum}; std::vector token_offset_dims = {batch_size, sequence_length}; std::vector cum_seq_len_dims = {batch_size + 1}; - std::vector relative_position_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; + std::vector attention_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; std::vector output_dims = {token_count, v_hidden_size}; if (use_float16) { tester.AddInput("input", input_dims, ToFloat16(input_data)); @@ -70,8 +70,8 @@ static void RunPackedAttentionTest( tester.AddInput("bias", bias_dims, ToFloat16(bias_data)); tester.AddInput("token_offset", token_offset_dims, token_offset); tester.AddInput("cumulative_sequence_length", cum_seq_len_dims, cumulative_sequence_length); - if (relative_position_bias_data.size() > 0) { - tester.AddInput("relative_position_bias", relative_position_bias_data_dims, ToFloat16(relative_position_bias_data)); + if (attention_bias_data.size() > 0) { + tester.AddInput("attention_bias", attention_bias_data_dims, ToFloat16(attention_bias_data)); } tester.AddOutput("output", output_dims, ToFloat16(output_data)); @@ -81,8 +81,8 @@ static void RunPackedAttentionTest( tester.AddInput("bias", bias_dims, bias_data); tester.AddInput("token_offset", token_offset_dims, token_offset); tester.AddInput("cumulative_sequence_length", cum_seq_len_dims, cumulative_sequence_length); - if (relative_position_bias_data.size() > 0) { - tester.AddInput("relative_position_bias", relative_position_bias_data_dims, relative_position_bias_data); + if (attention_bias_data.size() > 0) { + tester.AddInput("attention_bias", attention_bias_data_dims, attention_bias_data); } tester.AddOutput("output", output_dims, output_data); @@ -107,7 +107,7 @@ static void RunPackedAttentionTest( int number_of_heads, int token_count, std::vector qkv_sizes = {}, - const std::vector& relative_position_bias_data = {}) { + const std::vector& attention_bias_data = {}) { #define InvokePackedAttentionTest(use_float16, use_scale) \ RunPackedAttentionTest( \ input_data, \ @@ -124,7 +124,7 @@ static void RunPackedAttentionTest( use_float16, \ use_scale, \ qkv_sizes, \ - relative_position_bias_data); + attention_bias_data); InvokePackedAttentionTest(true, true); InvokePackedAttentionTest(true, false); @@ -172,7 +172,7 @@ TEST(PackedAttentionTest, NoPack) { batch_size * sequence_length); } -TEST(PackedAttentionTest, NoPackWithRelativePositionBias) { +TEST(PackedAttentionTest, NoPackWithAttentionBias) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -197,7 +197,7 @@ TEST(PackedAttentionTest, NoPackWithRelativePositionBias) { std::vector token_offset{0, 1, 2, 3}; std::vector cum_seq_len{0, 2, 4}; - std::vector relative_position_bias = { + std::vector attention_bias = { 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f, 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; @@ -220,10 +220,10 @@ TEST(PackedAttentionTest, NoPackWithRelativePositionBias) { number_of_heads, batch_size * sequence_length, {}, - relative_position_bias); + attention_bias); } -TEST(PackedAttentionTest, PackedWithRelativePositionBias) { +TEST(PackedAttentionTest, PackedWithAttentionBias) { int batch_size = 2; int sequence_length = 4; int hidden_size = 4; @@ -249,7 +249,7 @@ TEST(PackedAttentionTest, PackedWithRelativePositionBias) { std::vector token_offset{0, 1, 4, 5, 2, 3, 6, 7}; std::vector cum_seq_len{0, 
2, 4}; - std::vector relative_position_bias = { + std::vector attention_bias = { 0.2f, -0.1f, 0.f, 0.f, 0.4f, 2.5f, 0.f, 0.f, 1.6f, -1.1f, 0.f, 0.f, 0.4f, -2.5f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, @@ -279,7 +279,7 @@ TEST(PackedAttentionTest, PackedWithRelativePositionBias) { number_of_heads, 4, {}, - relative_position_bias); + attention_bias); } TEST(PackedAttentionTest, PackedBatch) { diff --git a/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc index 5f811c8cf35f6..17862c0aca6fa 100644 --- a/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc @@ -32,8 +32,8 @@ namespace test { token_count, \ use_float16, \ use_scale, \ - relative_position_bias_data, \ - broadcast_relative_position_bias); + attention_bias_data, \ + broadcast_attention_bias); static void RunPackedMultiHeadAttentionTest( const std::vector& query_data, // query: [token_count, num_heads, 3, head_size] @@ -52,8 +52,8 @@ static void RunPackedMultiHeadAttentionTest( int token_count, bool use_float16, bool use_scale, - const std::vector& relative_position_bias_data, - bool broadcast_relative_position_bias) { + const std::vector& attention_bias_data, + bool broadcast_attention_bias) { int min_cuda_architecture = use_float16 ? 530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); @@ -73,9 +73,9 @@ static void RunPackedMultiHeadAttentionTest( std::vector bias_dims = {hidden_size + hidden_size + v_hidden_size}; std::vector token_offset_dims = {batch_size, sequence_length}; std::vector cum_seq_len_dims = {batch_size + 1}; - std::vector relative_position_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; - std::vector broadcast_relative_position_bias_data_dims = {1, number_of_heads, sequence_length, sequence_length}; - auto& rel_pos_bias_dims = (broadcast_relative_position_bias ? broadcast_relative_position_bias_data_dims : relative_position_bias_data_dims); + std::vector attention_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; + std::vector broadcast_attention_bias_data_dims = {1, number_of_heads, sequence_length, sequence_length}; + auto& rel_pos_bias_dims = (broadcast_attention_bias ? 
broadcast_attention_bias_data_dims : attention_bias_data_dims); std::vector output_dims = {token_count, v_hidden_size}; @@ -100,10 +100,10 @@ static void RunPackedMultiHeadAttentionTest( tester.AddInput("token_offset", token_offset_dims, token_offset); tester.AddInput("cumulative_sequence_length", cum_seq_len_dims, cumulative_sequence_length); - if (relative_position_bias_data.size() > 0) { - tester.AddInput("relative_position_bias", + if (attention_bias_data.size() > 0) { + tester.AddInput("attention_bias", rel_pos_bias_dims, - ToFloat16(relative_position_bias_data)); + ToFloat16(attention_bias_data)); } tester.AddOutput("output", output_dims, ToFloat16(output_data)); @@ -127,8 +127,8 @@ static void RunPackedMultiHeadAttentionTest( tester.AddInput("token_offset", token_offset_dims, token_offset); tester.AddInput("cumulative_sequence_length", cum_seq_len_dims, cumulative_sequence_length); - if (relative_position_bias_data.size() > 0) { - tester.AddInput("relative_position_bias", rel_pos_bias_dims, relative_position_bias_data); + if (attention_bias_data.size() > 0) { + tester.AddInput("attention_bias", rel_pos_bias_dims, attention_bias_data); } tester.AddOutput("output", output_dims, output_data); @@ -157,8 +157,8 @@ static void RunPackedMultiHeadAttentionTest( int number_of_heads, int token_count, AttentionKernelType kernel_type, - const std::vector& relative_position_bias_data = {}, - bool broadcast_relative_position_bias = false) { + const std::vector& attention_bias_data = {}, + bool broadcast_attention_bias = false) { if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedAttention) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ @@ -310,9 +310,9 @@ TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_NoBias_trt) { AttentionKernelType::AttentionKernel_TrtFusedAttention); } -TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_Bias_RelPosBias_cutlass) { +TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_Bias_AttnBias_cutlass) { AttentionTestData data; - GetAttentionDataCutlassRelPosBias(data); + GetAttentionDataCutlassAttnBias(data); std::vector token_offset{0, 1, 2, 3, 4, 5, 6, 7}; std::vector cum_seq_len{0, 8}; @@ -331,13 +331,13 @@ TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_Bias_RelPosBias_cutlass) { data.num_heads, data.batch_size * data.sequence_length, AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention, - data.rel_pos_bias_data, - data.broadcast_rel_pos_bias); + data.attention_bias_data, + data.broadcast_attention_bias); } -TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_Bias_RelPosBias_unfused) { +TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_Bias_AttnBias_unfused) { AttentionTestData data; - GetAttentionDataCutlassRelPosBias(data); + GetAttentionDataCutlassAttnBias(data); std::vector token_offset{0, 1, 2, 3, 4, 5, 6, 7}; std::vector cum_seq_len{0, 8}; @@ -356,13 +356,13 @@ TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_Bias_RelPosBias_unfused) { data.num_heads, data.batch_size * data.sequence_length, AttentionKernelType::AttentionKernel_Unfused, - data.rel_pos_bias_data, - data.broadcast_rel_pos_bias); + data.attention_bias_data, + data.broadcast_attention_bias); } TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_trt) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -384,7 +384,7 @@ TEST(PackedMultiHeadAttentionTest, 
PackedQKV_Padding_NoBias_trt) { TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_cutlass) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -408,7 +408,7 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_cutlass) { TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_FlashAttention) { if (HasCudaEnvironment(800)) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -432,7 +432,7 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_FlashAttention) { TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_unfused) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -452,9 +452,9 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_unfused) { AttentionKernelType::AttentionKernel_Unfused); } -TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_RelPosBias) { +TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_AttnBias) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -472,13 +472,13 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_RelPosBias) { data.num_heads, data.token_count, AttentionKernelType::AttentionKernel_Default, - data.rel_pos_bias_data, - data.broadcast_rel_pos_bias); + data.attention_bias_data, + data.broadcast_attention_bias); } -TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_BroadcastRelPosBias_cutlass) { +TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_BroadcastAttnBias_cutlass) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -496,13 +496,13 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_BroadcastRelPosBias_ data.num_heads, data.token_count, AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention, - data.rel_pos_bias_data, - data.broadcast_rel_pos_bias); + data.attention_bias_data, + data.broadcast_attention_bias); } -TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_BroadcastRelPosBias_unfused) { +TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_BroadcastAttnBias_unfused) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -520,8 +520,8 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_BroadcastRelPosBias_ data.num_heads, data.token_count, AttentionKernelType::AttentionKernel_Unfused, - data.rel_pos_bias_data, - data.broadcast_rel_pos_bias); + data.attention_bias_data, + data.broadcast_attention_bias); } } // namespace test diff --git 
a/onnxruntime/test/contrib_ops/qordered_attention_test.cc b/onnxruntime/test/contrib_ops/qordered_attention_test.cc index 1dd0162ad722f..b7cd3948b0e76 100644 --- a/onnxruntime/test/contrib_ops/qordered_attention_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_attention_test.cc @@ -272,7 +272,7 @@ TEST(QOrderedTest, Attention_WithData_ROW_ORDER) { test_qorder.AddInput("scale_values_gemm", {}, {attn_out_scale}, true); test_qorder.AddInput("mask_index", {batch_size, sequence_len}, input_mask.data(), input_mask.size()); test_qorder.AddOptionalInputEdge(); // past - test_qorder.AddOptionalInputEdge(); // relative_position_bias + test_qorder.AddOptionalInputEdge(); // attention_bias test_qorder.AddOutput("output", {batch_size, sequence_len, hidden_size}, attn_out_q8.data(), attn_out_q8.size()); diff --git a/onnxruntime/test/python/transformers/benchmark_mha.cmd b/onnxruntime/test/python/transformers/benchmark_mha.cmd index 0a6d0c37b4a35..ba57ff40203b7 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.cmd +++ b/onnxruntime/test/python/transformers/benchmark_mha.cmd @@ -5,6 +5,14 @@ python benchmark_mha.py --use_gpu python benchmark_mha.py --use_gpu --use_cuda_graph python benchmark_mha.py --use_gpu --torch +echo "Benchmark performance on GPU without attention bias" +python benchmark_mha.py --use_gpu -b 16 + +echo "Benchmark performance on GPU with attention bias" +python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias +python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 +python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 --broadcast_attn_bias_dim_1 + type benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv echo "Benchmark performance on CPU with number of threads:" diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 0c52ee690af82..50b94e7af285e 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -89,7 +89,7 @@ def __init__( past_sequence_length: int = 0, kv_sequence_length=None, max_cache_sequence_length=None, - softmax_scale: float = 0.0, + scale: float = 0.0, provider="CPUExecutionProvider", device: Optional[torch.device] = None, enable_cuda_graph: bool = False, @@ -99,7 +99,10 @@ def __init__( share_past_present_buffer: bool = False, input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH, verbose: bool = False, - has_bias: bool = False, + has_bias: bool = False, # bias for input projection + has_attn_bias: bool = False, # bias added before softmax. For example,relative position bias. + broadcast_attn_bias_dim_0: bool = False, # broadcast attention bias dimension 0 + broadcast_attn_bias_dim_1: bool = False, # broadcast attention bias dimension 1 mask_format: int = AttentionMaskFormat.Mask_None, ): self.operator = "MultiHeadAttention" @@ -111,7 +114,7 @@ def __init__( self.num_heads = num_heads self.head_size = head_size self.causal = causal - self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5)) + self.scale = scale or (1.0 / (head_size**0.5)) # Support the case that there is no past but need present output (for prompt case). 
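The three broadcast-related flags added to the config above map one-to-one onto the shape of the optional attention_bias input, [batch_size or 1, num_heads or 1, sequence_length, total_sequence_length], which shape_dict() builds further below. A minimal standalone sketch of that rule (attn_bias_shape is an illustrative name, not a helper defined in benchmark_mha.py):

def attn_bias_shape(
    batch_size: int,
    num_heads: int,
    sequence_length: int,
    total_sequence_length: int,
    broadcast_attn_bias_dim_0: bool,
    broadcast_attn_bias_dim_1: bool,
):
    # A broadcast dimension collapses to 1; otherwise the full size is used.
    return (
        1 if broadcast_attn_bias_dim_0 else batch_size,
        1 if broadcast_attn_bias_dim_1 else num_heads,
        sequence_length,
        total_sequence_length,
    )

# Example: batch 16, 12 heads, bias broadcast over both batch and head dims.
assert attn_bias_shape(16, 12, 128, 128, True, True) == (1, 1, 128, 128)
assert attn_bias_shape(16, 12, 128, 128, False, True) == (16, 1, 128, 128)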
self.has_past_input = has_past_input @@ -151,6 +154,22 @@ def __init__( self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H self.verbose = verbose self.has_bias = has_bias + self.has_attn_bias = has_attn_bias + self.broadcast_attn_bias_dim_0 = broadcast_attn_bias_dim_0 + self.broadcast_attn_bias_dim_1 = broadcast_attn_bias_dim_1 + + assert mask_format in [ + AttentionMaskFormat.Mask_None, + AttentionMaskFormat.Mask_1D_Key_SeqLen, + AttentionMaskFormat.Mask_2D_Key_PaddingMask, + ] + self.mask_format = mask_format + + # mask_index_q and mask_index_kv will be updated in random_inputs() if mask_format is not Mask_None. + self.mask_index_kv = torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.sequence_length + self.mask_index_q = ( + torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.total_sequence_length + ) assert mask_format in [ AttentionMaskFormat.Mask_None, @@ -171,11 +190,14 @@ def __repr__(self): f"num_heads={self.num_heads}, head_size={self.head_size}, " f"kv_sequence_length={self.kv_sequence_length}, past_sequence_length={self.past_sequence_length}, " f"max_cache_sequence_length={self.max_cache_sequence_length}," - f"causal={self.causal}), softmax_scale={self.softmax_scale}, use_kv_cache={self.use_kv_cache}, " + f"causal={self.causal}), scale={self.scale}, use_kv_cache={self.use_kv_cache}, " f"share_past_present_buffer={self.share_past_present_buffer}, " f"provider={self.provider}, device={self.device}, enable_cuda_graph={self.enable_cuda_graph}, " f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}, " - f"has_bias={self.has_bias}, mask_format={self.mask_format}" + f"has_bias={self.has_bias}, mask_format={self.mask_format}, " + f"has_attn_bias={self.has_attn_bias}, " + f"broadcast_attn_bias_dim_0={self.broadcast_attn_bias_dim_0}, " + f"broadcast_attn_bias_dim_1={self.broadcast_attn_bias_dim_1}, " ) def shape_dict(self, input_format=None): @@ -235,6 +257,14 @@ def shape_dict(self, input_format=None): else: assert self.mask_format == AttentionMaskFormat.Mask_None + if self.has_attn_bias: + shapes["attn_bias"] = ( + 1 if self.broadcast_attn_bias_dim_0 else self.batch_size, + 1 if self.broadcast_attn_bias_dim_1 else self.num_heads, + self.sequence_length, + self.total_sequence_length, + ) + return shapes def symbolic_shape_dict(self, input_format=None): @@ -288,12 +318,15 @@ def symbolic_shape_dict(self, input_format=None): shapes["bias"] = (3 * self.num_heads * self.head_size,) if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen: - shapes["mask"] = (self.batch_size,) + shapes["mask"] = ("batch_size",) elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask: - shapes["mask"] = (self.batch_size, "total_sequence_length") + shapes["mask"] = ("batch_size", "total_sequence_length") else: assert self.mask_format == AttentionMaskFormat.Mask_None + if self.has_attn_bias: + shapes["attn_bias"] = ("batch_size_or_1", "num_heads_or_1", "sequence_length", "total_sequence_length") + return shapes def right_side_padding_masks(self): @@ -406,6 +439,19 @@ def random_inputs(self, seed: int = 123, no_bias_k_v: bool = False): if mask is not None: feeds = {**feeds, "mask": mask.to(dtype=torch.int32)} # mask is int32 (not bool) for MultiHeadAttention op. 
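Further below, get_input_output_names() and fill_optional_mha_inputs() add "attn_bias" to the MultiHeadAttention node's input list. They rely on the standard ONNX convention that an absent optional input keeps its positional slot as an empty string; a small self-contained sketch of that padding step (pad_optional_inputs is an illustrative name):

MHA_INPUT_ORDER = ["query", "key", "value", "bias", "mask", "attn_bias", "past_key", "past_value"]

def pad_optional_inputs(fed_names):
    # Keep every positional slot; optional inputs that are not fed become "".
    return [name if name in fed_names else "" for name in MHA_INPUT_ORDER]

print(pad_optional_inputs({"query", "key", "value", "attn_bias"}))
# ['query', 'key', 'value', '', '', 'attn_bias', '', '']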
+ if self.has_attn_bias: + attn_bias = torch.empty( + ( + 1 if self.broadcast_attn_bias_dim_0 else self.batch_size, + 1 if self.broadcast_attn_bias_dim_1 else self.num_heads, + self.sequence_length, + self.total_sequence_length, + ), + device=self.device, + dtype=dtype, + ).normal_(mean=0, std=0.1) + feeds["attn_bias"] = attn_bias + return feeds def get_input_output_names(self): @@ -425,6 +471,9 @@ def get_input_output_names(self): if self.mask_format != AttentionMaskFormat.Mask_None: inputs = [*inputs, "mask"] + if self.has_attn_bias: + inputs = [*inputs, "attn_bias"] + if self.has_past_input: inputs = [*inputs, "past_key", "past_value"] @@ -435,7 +484,7 @@ def get_input_output_names(self): def fill_optional_mha_inputs(input_names): - inputs = ["query", "key", "value", "bias", "mask", "relative_position_bias", "past_key", "past_value"] + inputs = ["query", "key", "value", "bias", "mask", "attn_bias", "past_key", "past_value"] # Remove optional inputs that are not in input_names with empty string inputs_with_optional = [input if input in input_names else "" for input in inputs] @@ -459,7 +508,7 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use "MultiHeadAttention_0", num_heads=config.num_heads, unidirectional=int(config.causal), - scale=config.softmax_scale, + scale=config.scale, mask_filter_value=float("-inf"), domain="com.microsoft", ), @@ -581,9 +630,8 @@ def get_cpu_kernel_name(config: MultiHeadAttentionConfig) -> str: # ------------------------------------------------------------------ # Functions for benchmarking PyTorch SDPA # ------------------------------------------------------------------ -def benchmark_torch_function(func: Callable, *args, **kwargs) -> float: +def benchmark_torch_function(repeats: int, func: Callable, *args, **kwargs) -> float: warmup = 5 - repeats = 100 for _ in range(warmup): func(*args, **kwargs) @@ -608,6 +656,7 @@ def run_torch_sdpa( mask_dim: int = 2, mask_dtype=torch.bool, backend: Optional[int] = None, + repeats: int = 100, ): q_shape = (batch_size, num_heads, q_seq_len, head_size) kv_shape = (batch_size, num_heads, kv_seq_len, head_size) @@ -624,6 +673,7 @@ def run_torch_sdpa( with context: average_latency = benchmark_torch_function( + repeats, scaled_dot_product_attention, q, k, @@ -634,7 +684,22 @@ def run_torch_sdpa( return average_latency -def get_test_configs(use_gpu: bool = True): +def get_test_configs(args: argparse.Namespace): + use_gpu: bool = args.use_gpu + + if args.batch_size > 0: + run_unfused = args.sequence_length + args.past_sequence_length <= (2048 if use_gpu else 1024) + return [ + ( + args.batch_size, + args.sequence_length, + args.past_sequence_length, + args.num_heads, + args.head_size, + run_unfused, + ), + ] + if use_gpu: # (batch_size, sequence_length, past_sequence_length, num_heads, head_size, run_unfused) configs = [ @@ -708,13 +773,14 @@ def get_compute_capability(): def run_tflops_test( csv_writer: csv.DictWriter, - use_gpu: bool = True, - enable_cuda_graph: bool = False, - causal: bool = False, - has_past: bool = False, - intra_op_num_threads: int = 0, - repeats: int = 100, + args: argparse.Namespace, ): + use_gpu: bool = args.use_gpu + enable_cuda_graph: bool = args.use_cuda_graph + causal: bool = args.causal + intra_op_num_threads: int = args.intra_op_num_threads + repeats: int = args.repeats + print(f"run_tflops_test: causal={causal}") if use_gpu: @@ -725,9 +791,9 @@ def run_tflops_test( # flash attention is available for sm >= 80 sm = get_compute_capability() if sm >= 80: - backends = 
[SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION] + backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH] else: - backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION] + backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH] else: device_id = 0 device = torch.device("cpu") @@ -736,30 +802,31 @@ def run_tflops_test( provider = "CPUExecutionProvider" backends = [SdpaKernel.DEFAULT] - configs = get_test_configs(use_gpu) - - print("\nformat\tcausal\tprompt\tbatch\tseqlen\theads\th_dim\tthreads\tms\tTFLOPS\tkernel") + configs = get_test_configs(args) + print("\nformat\tcausal\tattBias\tbatch\tseqlen\tpast\theads\th_dim\tthreads\tms\tTFLOPS\tkernel") for input_format in formats: for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: - for use_kv_cache in [False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - use_kv_cache=use_kv_cache, - past_sequence_length=past_sequence_length, - max_cache_sequence_length=None, - kv_sequence_length=None, - provider=provider, - enable_cuda_graph=enable_cuda_graph, - device=device, - dtype=torch.float16 if use_gpu else torch.float, - share_past_present_buffer=False, - input_format=input_format, - ) + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + use_kv_cache=past_sequence_length > 0, + past_sequence_length=past_sequence_length, + max_cache_sequence_length=None, + kv_sequence_length=None, + provider=provider, + enable_cuda_graph=enable_cuda_graph, + device=device, + dtype=torch.float16 if use_gpu else torch.float, + share_past_present_buffer=False, + input_format=input_format, + has_attn_bias=args.has_attn_bias, + broadcast_attn_bias_dim_0=args.broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1=args.broadcast_attn_bias_dim_1, + ) for attention_kernel in backends: sess_options = SessionOptions() sess_options.intra_op_num_threads = intra_op_num_threads @@ -786,7 +853,11 @@ def run_tflops_test( input_dict = config.random_inputs() # warm up session - _ = measure_latency(session, input_dict) + try: + _ = measure_latency(session, input_dict) + except Exception as e: + print(f"Failed to run {kernel=} for {config=}. 
Exception: {e}") + continue latency_list = [] for _ in range(repeats): @@ -815,6 +886,9 @@ def run_tflops_test( "past_sequence_length": past_sequence_length, "num_heads": num_heads, "head_size": head_size, + "has_attn_bias": args.has_attn_bias, + "broadcast_attn_bias_dim_0": args.broadcast_attn_bias_dim_0, + "broadcast_attn_bias_dim_1": args.broadcast_attn_bias_dim_1, "intra_op_num_threads": intra_op_num_threads, "average_latency": average_latency, "tflops": speed, @@ -824,17 +898,20 @@ def run_tflops_test( speed = f"{speed:.2f}" if speed is not None else "NA" print( - f"{format_str}\t{causal}\t{not has_past}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" + f"{format_str}\t{causal}\t{args.has_attn_bias}\t{batch_size}\t" + f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t" f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}" ) def run_torch_test( csv_writer: csv.DictWriter, - use_gpu: bool = True, - causal: bool = False, + args: argparse.Namespace, ): - configs = get_test_configs(use_gpu) + use_gpu: bool = args.use_gpu + causal: bool = args.causal + + configs = get_test_configs(args) if use_gpu: if not torch.cuda.is_available(): @@ -886,6 +963,7 @@ def run_torch_test( device=device, dtype=dtype, backend=backend, + repeats=args.repeats, ) except RuntimeError: continue @@ -893,8 +971,9 @@ def run_torch_test( speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), torch_latency) input_format = "Q,K,V" print( - f"{input_format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" - f"{0}\t{torch_latency * 1000:.2f}\t{speed:.2f}\t{backend_name}" + f"{input_format}\t{causal}\t{False}\t{batch_size}\t" + f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t" + f"{torch.get_num_threads()}\t{torch_latency * 1000:.2f}\t{speed}\t{backend_name}" ) row = { "use_gpu": use_gpu, @@ -906,6 +985,9 @@ def run_torch_test( "past_sequence_length": past_sequence_length, "num_heads": num_heads, "head_size": head_size, + "has_attn_bias": False, + "broadcast_attn_bias_dim_0": False, + "broadcast_attn_bias_dim_1": False, "intra_op_num_threads": torch.get_num_threads(), "average_latency": torch_latency, "tflops": speed, @@ -918,7 +1000,7 @@ def run_tflops_tests(args): features = "gpu" if args.use_gpu else "cpu" if args.causal: features += "_causal" - if args.has_past: + if args.past_sequence_length > 0: features += "_past" csv_filename = "benchmark_mha_{}_{}_{}.csv".format( features, @@ -936,6 +1018,9 @@ def run_tflops_tests(args): "past_sequence_length", "num_heads", "head_size", + "has_attn_bias", + "broadcast_attn_bias_dim_0", + "broadcast_attn_bias_dim_1", "intra_op_num_threads", "average_latency", "tflops", @@ -945,16 +1030,9 @@ def run_tflops_tests(args): csv_writer.writeheader() if args.torch: - run_torch_test(csv_writer, args.use_gpu, args.causal) + run_torch_test(csv_writer, args) else: - run_tflops_test( - csv_writer, - use_gpu=args.use_gpu, - enable_cuda_graph=args.use_cuda_graph, - causal=args.causal, - has_past=args.has_past, - intra_op_num_threads=args.intra_op_num_threads, - ) + run_tflops_test(csv_writer, args) def plot_prompt_performance( @@ -1013,7 +1091,7 @@ def benchmark( head_size=head_size, causal=False, past_sequence_length=0, - kv_sequence_length=sequence_length if input_format == InputFormats.get_name_list()[-1] else None, + kv_sequence_length=sequence_length if input_format == "Q,K',V'" else None, max_cache_sequence_length=max_seq_len, 
provider="CUDAExecutionProvider", enable_cuda_graph=False, @@ -1083,20 +1161,66 @@ def _parse_arguments(): ) parser.add_argument( - "--has_past", + "--causal", required=False, action="store_true", - help="whether past_sequence_length > 0", + help="test unidirectional", ) - parser.set_defaults(has_past=False) + parser.set_defaults(causal=False) parser.add_argument( - "--causal", + "-b", + "--batch_size", required=False, - action="store_true", - help="test unidirectional", + type=int, + default=0, + help="batch size", + ) + + parser.add_argument( + "-s", + "--sequence_length", + required=False, + type=int, + default=512, + help="sequence length", + ) + + parser.add_argument( + "-p", + "--past_sequence_length", + required=False, + type=int, + default=0, + help="past sequence length", + ) + + parser.add_argument( + "-n", + "--num_heads", + required=False, + type=int, + default=16, + help="number of attention heads", + ) + + parser.add_argument( + "-d", + "--head_size", + required=False, + type=int, + default=64, + help="hidden dimension per head", + ) + + parser.add_argument( + "-r", + "--repeats", + required=False, + type=int, + default=100, + help="number of repeats for performance test", ) - parser.set_defaults(causal=False) parser.add_argument( "--torch", @@ -1106,6 +1230,30 @@ def _parse_arguments(): ) parser.set_defaults(torch=False) + parser.add_argument( + "--has_attn_bias", + required=False, + action="store_true", + help="has attention bias", + ) + parser.set_defaults(has_attn_bias=False) + + parser.add_argument( + "--broadcast_attn_bias_dim_0", + required=False, + action="store_true", + help="broadcast attention bias dimension 0", + ) + parser.set_defaults(broadcast_attn_bias_dim_0=False) + + parser.add_argument( + "--broadcast_attn_bias_dim_1", + required=False, + action="store_true", + help="broadcast attention bias dimension 1", + ) + parser.set_defaults(broadcast_attn_bias_dim_1=False) + args = parser.parse_args() return args @@ -1115,9 +1263,6 @@ def _parse_arguments(): args = _parse_arguments() print(f"arguments:{args}") - if args.has_past: - assert args.causal, "--has_past need --causal specified" - if args.use_gpu: assert args.torch or not args.causal, "no causal cuda kernel in MHA op" assert torch.cuda.is_available() @@ -1126,9 +1271,9 @@ def _parse_arguments(): if args.torch: assert Version(torch.__version__) >= Version("2.3.0") - assert args.has_past is False + assert args.past_sequence_length == 0 - if args.use_gpu and not args.torch: + if args.use_gpu and args.batch_size == 0 and not args.torch: if platform.system() == "Linux": s = torch.cuda.Stream() with torch.cuda.stream(s), torch.no_grad(): diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh index 613543d0172dd..ff6dd16e698df 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.sh +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -9,6 +9,15 @@ echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" export CUDA_VISIBLE_DEVICES=0 python benchmark_mha.py --use_gpu + +echo "Benchmark BERT-Large performance on GPU without attention bias" +python benchmark_mha.py --use_gpu -b 16 + +echo "Benchmark BERT-Large performance on GPU with attention bias" +python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias +python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 +python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 
--broadcast_attn_bias_dim_1 + python benchmark_mha.py --use_gpu --use_cuda_graph python benchmark_mha.py --use_gpu --torch diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 5948f8b1ccfc1..5ebc02c84acb2 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -68,13 +68,26 @@ def get_bias_support(format: InputFormats): raise RuntimeError(f"Unknown format: {format}") +def get_atten_bias_support(): + atten_bias_options = [ + # (has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1) + (False, False, False), + (True, False, False), # [b, n, s_q, s_kv] + (True, True, False), # [1, n, s_q, s_kv] + (True, False, True), # [b, 1, s_q, s_kv] + (True, True, True), # [1, 1, s_q, s_kv] + ] + return atten_bias_options + + def attention_reference( head_size: int, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, + attn_bias: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, verbose: bool = False, ) -> torch.Tensor: """Reference implementation of SDPA @@ -84,8 +97,9 @@ def attention_reference( query (torch.Tensor): query in BNSH format key (torch.Tensor): key in BNSH format value (torch.Tensor): value in BNSH format - scale (Optional[float], optional): scale applied before softmax. Defaults to None. - mask (Optional[torch.Tensor], optional): attention mask. Defaults to None. + scale (Optional[float], optional): scale applied on QxK'. Defaults to None. + attn_bias : attention bias tensor added before softmax. Defaults to None. + masks : attention masks. Defaults to None. Returns: torch.Tensor: result of SDPA @@ -100,25 +114,30 @@ def attention_reference( if verbose: torch.set_printoptions(precision=6, linewidth=200, sci_mode=False) - print("query(SDPA)", query) - print("key(SDPA)", key) - print("value(SDPA)", value) + print("query(ref)", query) + print("key(ref)", key) + print("value(ref)", value) if mask is not None: print("mask", mask) # Apply multi-head attention. 
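The torch code that follows computes softmax(Q K^T * scale + attention_bias) V, with the bias added before the softmax and the padding mask applied as -inf. The same formula as a self-contained NumPy sketch, independent of the test helper (function and variable names here are illustrative):

import numpy as np

def sdpa_reference(q, k, v, scale=None, attn_bias=None, mask=None):
    # q, k, v are in BNSH layout: [batch, num_heads, seq_len, head_size].
    scale = scale or 1.0 / np.sqrt(q.shape[-1])
    attn = np.einsum("bhmd,bhnd->bhmn", q, k) * scale
    if attn_bias is not None:
        attn = attn + attn_bias                      # bias is added before softmax
    if mask is not None:
        attn = np.where(mask, attn, -np.inf)         # boolean mask: False -> -inf
    attn = np.exp(attn - attn.max(axis=-1, keepdims=True))
    attn = attn / attn.sum(axis=-1, keepdims=True)
    return np.einsum("bhmn,bhnd->bhmd", attn, v)

# Smoke test: batch 1, 1 head, 2 query tokens, 3 key/value tokens, head_size 4.
rng = np.random.default_rng(0)
q = rng.standard_normal((1, 1, 2, 4), dtype=np.float32)
k = rng.standard_normal((1, 1, 3, 4), dtype=np.float32)
v = rng.standard_normal((1, 1, 3, 4), dtype=np.float32)
bias = np.zeros((1, 1, 2, 3), dtype=np.float32)
assert sdpa_reference(q, k, v, attn_bias=bias).shape == (1, 1, 2, 4)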
attn = torch.einsum("bhmd,bhnd->bhmn", query, key).float() * scale if verbose: - print("QK(SDPA)", attn) + print("QK(ref)", attn) + + if attn_bias is not None: + attn = attn + attn_bias + if verbose: + print("QK+AttnBias(ref)", attn) if mask is not None: attn = attn.masked_fill((1 - mask.int()).bool(), float("-inf")) if verbose: - print("masked QK(SDPA)", attn) + print("masked QK(ref)", attn) attn = attn.softmax(-1) if verbose: - print("Softmax(SDPA)", attn) + print("Softmax(ref)", attn) attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value) @@ -128,7 +147,7 @@ def attention_reference( torch.cuda.synchronize() if verbose: - print("result(SDPA)", result) + print("result(ref)", result) return result @@ -141,6 +160,7 @@ def mha_with_past_reference( k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None, + attn_bias: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, ): assert config.kv_sequence_length == config.sequence_length @@ -157,7 +177,7 @@ def mha_with_past_reference( present_k = torch.cat((past_k, k), dim=2) if past_k is not None else k present_v = torch.cat((past_v, v), dim=2) if past_v is not None else v - out = attention_reference(config.head_size, q, present_k, present_v, scale=scale, mask=mask) + out = attention_reference(config.head_size, q, present_k, present_v, scale=scale, attn_bias=attn_bias, mask=mask) return out, present_k, present_v @@ -185,6 +205,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): AttentionMaskFormat.Mask_1D_Key_SeqLen, AttentionMaskFormat.Mask_2D_Key_PaddingMask, ] + atten_bias_options = get_atten_bias_support() device, dtype, formats = get_provider_support_info(provider, False) if comprehensive: @@ -197,25 +218,33 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): for causal in [True, False]: for mask_format in mask_formats: for has_bias in get_bias_support(format): - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=0, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=False, - share_past_present_buffer=False, - input_format=format, - has_bias=has_bias, - mask_format=mask_format, - ) - yield config + for ( + has_attn_bias, + broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1, + ) in atten_bias_options: + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=0, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=False, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + mask_format=mask_format, + has_attn_bias=has_attn_bias, + broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -224,6 +253,9 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] mask_format = mask_formats[i % len(mask_formats)] + has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[ + i % len(atten_bias_options) + ] for causal in 
[True, False]: for format in formats: for has_bias in get_bias_support(format): @@ -244,6 +276,9 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): input_format=format, has_bias=has_bias, mask_format=mask_format, + has_attn_bias=has_attn_bias, + broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1, ) yield config @@ -264,6 +299,8 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): AttentionMaskFormat.Mask_2D_Key_PaddingMask, ] + atten_bias_options = get_atten_bias_support() + if comprehensive: sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory for batch_size in batch_sizes: @@ -275,28 +312,36 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): for has_past_input in [True, False]: for mask_format in mask_formats: for has_bias in get_bias_support(format): - sequence_length = 1 if has_past_input else past_sequence_length - past_seq_len = past_sequence_length if has_past_input else 0 - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=past_seq_len, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=True, - has_past_input=has_past_input, - share_past_present_buffer=False, - input_format=format, - has_bias=has_bias, - mask_format=mask_format, - ) - yield config + for ( + has_attn_bias, + broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1, + ) in atten_bias_options: + sequence_length = 1 if has_past_input else past_sequence_length + past_seq_len = past_sequence_length if has_past_input else 0 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=past_seq_len, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=has_past_input, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + mask_format=mask_format, + has_attn_bias=has_attn_bias, + broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -305,6 +350,9 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] mask_format = mask_formats[i % len(mask_formats)] + has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[ + i % len(atten_bias_options) + ] for causal in [True, False]: for format in formats: for has_past_input in [True, False]: @@ -329,6 +377,9 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): input_format=format, has_bias=has_bias, mask_format=mask_format, + has_attn_bias=has_attn_bias, + broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1, ) yield config @@ -470,6 +521,10 @@ def parity_check_mha( k = k + bias_k v = v + bias_v + attn_bias = None + if config.has_attn_bias: + attn_bias = ref_inputs["attn_bias"] + q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) @@ -480,11 +535,13 @@ def parity_check_mha( if 
config.use_kv_cache: past_k = ref_inputs.get("past_key", None) past_v = ref_inputs.get("past_value", None) - out_ref, k_cache, v_cache = mha_with_past_reference(config, past_k, past_v, q, k, v, mask=mask) + out_ref, k_cache, v_cache = mha_with_past_reference( + config, past_k, past_v, q, k, v, scale=config.scale, attn_bias=attn_bias, mask=mask + ) else: - out_ref = attention_reference(config.head_size, q, k, v, mask=mask) + out_ref = attention_reference(config.head_size, q, k, v, scale=config.scale, attn_bias=attn_bias, mask=mask) - # Fill zeros for the padded kens for comparison. + # Fill zeros for the padded tokens for comparison. if config.mask_index_q is not None: for i, m in enumerate(config.mask_index_q): out[i, m:, :, :] = 0 @@ -584,35 +641,69 @@ def check_parity_with_config(i: int): ) # Create reference inputs + old_format = config.input_format config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH ref_inputs = test_inputs[i]["ref_inputs"] if verbose: print(f"Thread {i} ref inputs: {ref_inputs}") - q = ( - ref_inputs["query"] - .reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) - k = ( - ref_inputs["key"] - .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) + + q = ref_inputs["query"].reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + k = ref_inputs["key"].reshape( + (config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size) ) - v = ( - ref_inputs["value"] - .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) + v = ref_inputs["value"].reshape( + (config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size) ) + if "bias" in ref_inputs: + bias = ref_inputs["bias"] + bias = bias.reshape((3, config.num_heads, config.head_size)) + bias_q = bias[0, :, :].reshape(1, 1, config.num_heads, config.head_size) + bias_k = bias[1, :, :].reshape(1, 1, config.num_heads, config.head_size) + bias_v = bias[2, :, :].reshape(1, 1, config.num_heads, config.head_size) + q = q + bias_q + k = k + bias_k + v = v + bias_v + + attn_bias = None + if config.has_attn_bias: + attn_bias = ref_inputs["attn_bias"] + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + mask = merge_padding_and_causal_masks(config) k_cache = None v_cache = None if config.use_kv_cache: - past_k = ref_inputs["past_key"] - past_v = ref_inputs["past_value"] - out_ref, k_cache, v_cache = mha_with_past_reference(config, past_k, past_v, q, k, v, mask=mask) + past_k = ref_inputs.get("past_key", None) + past_v = ref_inputs.get("past_value", None) + out_ref, k_cache, v_cache = mha_with_past_reference( + config, past_k, past_v, q, k, v, scale=config.scale, attn_bias=attn_bias, mask=mask + ) else: - out_ref = attention_reference(config.head_size, q, k, v, mask=mask) + out_ref = attention_reference(config.head_size, q, k, v, scale=config.scale, attn_bias=attn_bias, mask=mask) + + # Fill zeros for the padded tokens for comparison. 
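The zero-fill block below clears positions beyond each sequence's valid length in both the ORT output and the reference output, so that values in padded positions (which the kernels may leave undefined) cannot fail the parity check. A standalone sketch of the same idea (zero_padded_positions is an illustrative name; valid_lengths plays the role of config.mask_index_q):

import numpy as np

def zero_padded_positions(out, out_ref, valid_lengths):
    # out / out_ref: [batch, seq_len, num_heads, head_size]; zero everything
    # at and beyond each sequence's valid length before comparing.
    for i, valid_len in enumerate(valid_lengths):
        out[i, valid_len:, :, :] = 0
        out_ref[i, valid_len:, :, :] = 0

out = np.random.rand(2, 4, 2, 8).astype(np.float32)
ref = out.copy()
out[1, 2:] = 123.0                         # pretend the kernel left garbage in padding
zero_padded_positions(out, ref, valid_lengths=[4, 2])
np.testing.assert_allclose(out, ref)       # passes: padded garbage was zeroed on both sides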
+ if config.mask_index_q is not None: + for i, m in enumerate(config.mask_index_q): + out[i, m:, :, :] = 0 + out_ref[i, m:, :, :] = 0 + + if config.mask_index_kv is not None and config.use_kv_cache: + assert k_cache is not None + assert v_cache is not None + present_key = ort_outputs[1] + present_value = ort_outputs[2] + for i, n in enumerate(config.mask_index_kv): + k_cache[i, :, n:, :] = 0 + present_key[i, :, n:, :] = 0 + v_cache[i, :, n:, :] = 0 + present_value[i, :, n:, :] = 0 + + # Restore the input format so that it shows up in the error message correctly. + config.input_format = old_format try: numpy.testing.assert_allclose( diff --git a/onnxruntime/test/python/transformers/test_parity_neox_attention.py b/onnxruntime/test/python/transformers/test_parity_neox_attention.py index d0a308987d888..300de19dd34c2 100644 --- a/onnxruntime/test/python/transformers/test_parity_neox_attention.py +++ b/onnxruntime/test/python/transformers/test_parity_neox_attention.py @@ -89,7 +89,7 @@ def create_neox_decoder_masked_self_attention_graph( "bias", "mask_index", "past", - "", # relative_position_bias + "", # attention_bias "past_sequence_length", ], ["output", "present"], diff --git a/onnxruntime/test/python/transformers/test_parity_t5_mha.py b/onnxruntime/test/python/transformers/test_parity_t5_mha.py index c7fb398dde82e..e4f65b07c552e 100644 --- a/onnxruntime/test/python/transformers/test_parity_t5_mha.py +++ b/onnxruntime/test/python/transformers/test_parity_t5_mha.py @@ -57,7 +57,7 @@ def create_t5_mha_graph( "value" if use_present or is_static_kv else "", "", # bias "key_padding_mask" if use_mask else "", - "relative_position_bias" if use_rpb else "", + "attention_bias" if use_rpb else "", "past_key" if use_past and not is_static_kv else "", "past_value" if use_past and not is_static_kv else "", ], @@ -93,9 +93,7 @@ def create_t5_mha_graph( if use_rpb: graph_inputs.append( - helper.make_tensor_value_info( - "relative_position_bias", TensorProto.FLOAT, [1, num_heads, seq_len, rpb_length] - ) + helper.make_tensor_value_info("attention_bias", TensorProto.FLOAT, [1, num_heads, seq_len, rpb_length]) ) if use_past and not is_static_kv: @@ -170,7 +168,7 @@ def create_t5_decoder_masked_mha_graph( "key", "value", "mask_index" if is_cross_attention else "", - "relative_position_bias" if not is_cross_attention else "", + "attention_bias" if not is_cross_attention else "", "past_key" if not is_cross_attention else "", "past_value" if not is_cross_attention else "", "past_sequence_length" if not is_cross_attention else "", @@ -220,7 +218,7 @@ def create_t5_decoder_masked_mha_graph( graph_inputs.append(helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, 1, hidden_size])) graph_inputs.append( helper.make_tensor_value_info( - "relative_position_bias", TensorProto.FLOAT, [1, num_heads, 1, past_sequence_length + 1] + "attention_bias", TensorProto.FLOAT, [1, num_heads, 1, past_sequence_length + 1] ) ) # use past_sequence_length + 1 to simulate max_sequence_length @@ -558,7 +556,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if torch_key_padding_mask is not None: ort_inputs["key_padding_mask"] = np.ascontiguousarray(torch_key_padding_mask.detach().numpy()) if torch_position_bias is not None: - ort_inputs["relative_position_bias"] = np.ascontiguousarray(torch_position_bias.detach().numpy()) + ort_inputs["attention_bias"] = np.ascontiguousarray(torch_position_bias.detach().numpy()) else: torch_past_key = past_key_value[0] torch_past_value = past_key_value[1] @@ 
-617,7 +615,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): else: ort_inputs["key_padding_mask"] = np.ascontiguousarray(torch_key_padding_mask.detach().numpy()) if torch_position_bias is not None: - ort_inputs["relative_position_bias"] = np.ascontiguousarray(torch_position_bias.detach().numpy()) + ort_inputs["attention_bias"] = np.ascontiguousarray(torch_position_bias.detach().numpy()) ort_output = ort_session.run(None, ort_inputs) diff --git a/onnxruntime/test/testdata/attention/attention_test_data.txt b/onnxruntime/test/testdata/attention/attention_test_data.txt index c52dd4ef1988b..7c60efea1f0f6 100644 --- a/onnxruntime/test/testdata/attention/attention_test_data.txt +++ b/onnxruntime/test/testdata/attention/attention_test_data.txt @@ -2812,26 +2812,26 @@ name:CrossAttentionDataWithPast.fp32_output_data 0.4291,0.5276,0.4818,0.4645,0.4768,0.4083,0.3377,0.4315, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.query_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.query_data 0.00403503,0.08716156,-0.0358175,-0.08171791, 0.48912194,-0.22679007,-0.09093101,-0.5939322, 0.00878838,0.03355761,-0.08080226,-0.06677517, 0.55038965,-0.2720567,-0.12977877,-0.634123, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.key_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.key_data 0.2808786,0.10041683,0.15880886,0.45283064, 0.39884242,0.12596075,0.4198916,-0.0651141, 0.31678027,0.11010794,0.21594375,0.4975329, 0.436772,0.20940652,0.44072092,-0.05601776, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.value_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.value_data 0.26421773,-0.16541699,-0.0599675,0.27200517, -0.1074627,-0.4493224,-0.03694462,0.17997989, 0.27960598,-0.16643806,-0.07019104,0.29006317, -0.11640988,-0.47876123,-0.01979145,0.11468418, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.rel_pos_bias_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.attention_bias_data 0.4781123,0.82420444,0.654424,0.3995186,0.5482078, 0.55570245,0.4216576,0.46001542,0.67183703,0.41973996, @@ -2839,7 +2839,7 @@ name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.rel_pos_bias_data 0.5460559,0.31994605,0.5470492,0.5433419,0.60349935, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.past_key_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.past_key_data 0.34734827,0.5592256,0.5333037,0.5122027, 0.5940516,0.44744077,0.43128848,0.55360645, 0.57874715,0.29512063,0.2780432,0.4693917, @@ -2849,7 +2849,7 @@ name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.past_key_data 0.5352153,0.5157861,0.39744973,0.5441864, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.past_value_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.past_value_data 0.48998538,0.5493853,0.556647,0.7011929, 0.543909,0.5630743,0.5087797,0.3901024, 0.53116417,0.4086225,0.5320247,0.5145377, @@ -2858,12 +2858,12 @@ name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.past_value_data 0.52980417,0.5243695,0.6046111,0.53555113, 0.44936907,0.6010697,0.38031512,0.427301, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.fp32_output_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.fp32_output_data 0.4358,0.2708,0.3201,0.4347,0.1886,0.0845,0.2479,0.3289, 0.4157,0.2247,0.2826,0.4321,0.1874,0.1021,0.2427,0.3305, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.present_key_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.present_key_data 0.3473,0.5592,0.5333,0.5122, 0.5941,0.4474,0.4313,0.5536, 
0.5787,0.2951,0.2780,0.4694, @@ -2877,7 +2877,7 @@ name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.present_key_data 0.4368,0.2094,0.4407,-0.0560, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.present_value_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.present_value_data 0.4900,0.5494,0.5566,0.7012, 0.5439,0.5631,0.5088,0.3901, 0.5312,0.4086,0.5320,0.5145, @@ -2891,7 +2891,7 @@ name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.present_value_data -0.1164,-0.4788,-0.0198,0.1147, === -name:AttentionDataCutlassRelPosBias.query_data +name:AttentionDataCutlassAttnBias.query_data -0.029273793,0.079709493,0.064531095,0.24270254,-0.28326464,0.20984903,-0.10173888,0.18373983, 0.089472905,-0.0063416883,-0.049477674,0.36512995,-0.23620239,0.1464397,0.068258412,0.31627196, @@ -2909,7 +2909,7 @@ name:AttentionDataCutlassRelPosBias.query_data 0.002485469,0.029660821,-0.043821491,0.3892332,-0.26994205,0.14530671,0.12950704,0.36185294, === -name:AttentionDataCutlassRelPosBias.key_data +name:AttentionDataCutlassAttnBias.key_data -0.32538497,0.34121913,-0.18170178,-0.015152611,0.20429322,0.25979176,0.21269324,0.0025638193, -0.24246037,0.21112341,-0.36959589,-0.16091451,0.24183474,0.18856162,0.094487116,-0.3053959, @@ -2921,7 +2921,7 @@ name:AttentionDataCutlassRelPosBias.key_data -0.35736683,0.29276621,-0.4217523,-0.20031664,0.33148992,0.26928401,0.19360018,-0.39494509, -0.28043351,0.24279942,-0.29154932,-0.13657911,0.31932494,0.3500579,0.027172565,-0.19327414, === -name:AttentionDataCutlassRelPosBias.value_data +name:AttentionDataCutlassAttnBias.value_data 0.56916672,-0.2443777,0.47111356,-0.52134115,0.010381341,0.0696759,-0.071910433,-0.35201436, 0.70809275,-0.24479815,0.41633749,-0.34744334,-0.0044222325,0.25929695,-0.087832771,-0.281232, 0.90039468,-0.28931504,0.56394172,-0.43948689,-0.05856207,0.33713666,-0.10320446,-0.38833332, @@ -2931,7 +2931,7 @@ name:AttentionDataCutlassRelPosBias.value_data 0.90039468,-0.28931504,0.56394172,-0.43948689,-0.05856207,0.33713666,-0.10320446,-0.38833332, 0.76054728,-0.29080144,0.50414616,-0.42371163,-0.047198489,0.31959397,-0.22683662,-0.30321664, === -name:AttentionDataCutlassRelPosBias.bias_data +name:AttentionDataCutlassAttnBias.bias_data -0.38124341,0.02696526,-0.11914945,-0.43795273, 0.04772711,-0.03419551,-0.30606642,0.42656231, -0.25891554,0.13431972,0.22861153,0.06360734, @@ -2939,7 +2939,7 @@ name:AttentionDataCutlassRelPosBias.bias_data 0.27079183,0.42074734,-0.40314156,-0.43726659, -0.40546918,0.06927037,0.16979086,0.41458064, === -name:AttentionDataCutlassRelPosBias.rel_pos_bias_data +name:AttentionDataCutlassAttnBias.attention_bias_data -10.808288,-10.887209,7.8799553,-4.6565766,-1.6700006,-0.033962168,7.4929152,10.944146,8.640254,-18.862164,-3.1202927,-6.3049207,3.4508536,11.722519,3.3550568,-5.4888172, -2.0828252,-13.241742,2.9868939,1.4455698,-15.262972,-10.457437,-8.4519463,-4.4281874,10.212368,-0.28622282,12.087646,6.5218501,8.1785011,13.985523,-8.2068987,5.4260745, -10.808288,-10.887209,7.8799553,-4.6565766,-1.6700006,-0.033962168,7.4929152,10.944146,8.640254,-18.862164,-3.1202927,-6.3049207,3.4508536,11.722519,3.3550568,-5.4888172, @@ -2949,7 +2949,7 @@ name:AttentionDataCutlassRelPosBias.rel_pos_bias_data -10.808288,-10.887209,7.8799553,-4.6565766,-1.6700006,-0.033962168,7.4929152,10.944146,8.640254,-18.862164,-3.1202927,-6.3049207,3.4508536,11.722519,3.3550568,-5.4888172, 
-2.0828252,-13.241742,2.9868939,1.4455698,-15.262972,-10.457437,-8.4519463,-4.4281874,10.212368,-0.28622282,12.087646,6.5218501,8.1785011,13.985523,-8.2068987,5.4260745, === -name:AttentionDataCutlassRelPosBias.fp16_output_data +name:AttentionDataCutlassAttnBias.fp16_output_data 1.0419922,0.13000488,0.10528564,-0.86230469,-0.45336914,0.39013672,-0.048858643,0.10571289, 0.97265625,0.17590332,0.015625,-0.79248047,-0.40917969,0.31933594,0.082763672,0.12976074, 1.1455078,0.13134766,0.15014648,-0.87451172,-0.46142578,0.40161133,0.04309082,0.042663574, @@ -3095,61 +3095,61 @@ name:CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.present_value_data 1.20772719,-0.99407929,-0.15339416,0.54562038,1.29705775,-0.28651321,-0.90150839,-1.09473300, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.query_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.query_data 0.19646919,-0.21386067,-0.27314855,0.05131477,0.21946897,-0.07689354,0.4807642,0.18482974,-0.0190681,-0.10788248,-0.15682198,0.22904971,-0.06142776,-0.4403221,-0.10195574,0.23799541, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.key_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.key_data -0.31750827,-0.32454824,0.03155137,0.03182759,0.13440096,0.34943179,0.22445532,0.11102351,0.22244338,-0.17704109,-0.13821134,-0.27173677,-0.20628595,0.13097612,-0.40789506,-0.06629883, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.value_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.value_data -0.06913724,-0.0063149,-0.07416971,-0.18773878,-0.07364869,0.39338916,0.44416002,0.00183668,0.12395295,-0.3843816,-0.18271452,-0.08517379,0.36630916,-0.24954463,-0.01696574,0.48555979, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.bias_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.bias_data 0.01948512,0.11289453,-0.37937133,0.3263408,0.10306013,0.04506801,-0.15723617,-0.19587921,-0.08297779,0.18130077,0.37545684,0.01042234,0.16931378,0.08593655,0.1249035,0.17468905,0.34234244,-0.41680501,0.26368284,-0.25633363,-0.30577704,0.07245696,-0.40428748,0.38532683, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.past_key_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.past_key_data 0.12724897,0.22341636,-0.48387079,0.09443188,0.05678519,-0.34104036,-0.34692948,0.19552953,-0.18123357,0.1919703,0.05438325,-0.11104943,0.42513249,0.34167,-0.14260243,-0.45640854, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.past_value_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.past_value_data -0.19523193,-0.10181432,0.20495883,0.49535848,-0.14408513,0.26254781,0.09317692,0.1917018,-0.34887255,-0.10112371,-0.2591441,-0.15654399,0.01312815,0.16662455,-0.39409151,-0.36910505, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.fp32_output_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.fp32_output_data -0.00033577532,-0.23549549,0.19853255,0.10450245,-0.26995566,0.37128073,0.064667389,0.29624334,0.040147364,-0.43521237,-0.096833363,-0.24481347,0.037364807,-0.0091082826,-0.40797871,0.26487666, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.present_key_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.present_key_data 0.12724897,0.22341636,-0.4838708,0.094431877,-0.40048605,-0.14324747,0.4070082,0.042249933, 0.056785189,-0.34104037,-0.34692949,0.19552954,0.30371475,0.43536833,0.34935883,0.28571257, 
-0.18123357,0.1919703,0.054383252,-0.11104943,0.1394656,0.0042596906,0.2372455,-0.26131442, 0.42513248,0.34167001,-0.14260243,-0.45640853,-0.03697218,0.21691267,-0.28299156,0.10839023, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.present_value_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.present_value_data -0.19523193,-0.10181432,0.20495883,0.49535847,0.27320519,-0.4231199,0.18951313,-0.4440724, -0.14408512,0.26254782,0.093176924,0.1917018,-0.37942573,0.46584612,0.039872527,0.38716352, -0.34887254,-0.10112371,-0.2591441,-0.15654399,0.46629539,-0.80118656,0.08096832,-0.34150741, 0.01312815,0.16662455,-0.39409152,-0.36910504,0.060532123,-0.17708766,-0.42125323,0.87088662, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.query_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.query_data 1.29534733,2.14051294,1.09895217,1.39164531,-0.01471180,-1.40148544,-0.50825417,0.26134527, -0.70491123,0.63738143,2.13708138,0.05667466,-0.44220763,0.85254443,2.00844359,-1.23413038, -0.08030051,-1.25450790,-0.89664006,-0.69433510,0.20943037,1.41880298,1.42875051,0.79920006, 1.57896936,-1.13204634,-0.61002654,0.43365243,0.22888106,-0.38688308,-0.45924744,0.99473029, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.key_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.key_data 0.37680483,0.15317714,0.05767500,0.37780648,-2.27755547,0.89294612,-0.85582626,0.54963046, 1.67390800,-1.06330085,-2.99566054,0.68927419,1.66056263,-0.77022851,0.15417719,0.94860524, -1.84928346,-0.52135336,0.70491475,0.37400877,0.55338752,0.52915680,0.52876079,-0.55780333, -1.49814773,0.18675917,0.31246936,-1.32707596,0.42132780,-1.69121027,0.20342645,-0.34370381, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.value_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.value_data 0.60890561,-0.88021755,1.63002241,0.86171651,1.80559230,1.26110435,-0.97890180,-1.60215497, -0.79229754,1.07830989,-0.85298145,2.76264572,0.01659799,-1.49499071,0.85316724,-2.56763911, 0.53017867,1.31909978,-1.10940945,0.68858552,-1.07115889,-2.34016919,0.48310637,-0.05351824, -0.08850761,-0.56362265,0.05224326,-2.47377181,0.44249821,-0.10389519,-0.46113095,2.81619215, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.bias_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.bias_data -0.38124341,0.02696526,-0.11914945,-0.43795273,-0.34948170,-0.19608477,0.19725692,0.39987487, 0.04772711,-0.03419551,-0.30606642,0.42656231,-0.23178342,-0.13692456,-0.04889601,0.48739988, 0.27079183,0.42074734,-0.40314156,-0.43726659,0.27376485,-0.38174152,-0.43700469,0.38040614, @@ -3158,28 +3158,28 @@ name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.bias_dat 0.34785229,0.00531715,-0.35168743,-0.11641458,0.39196932,0.44535065,0.43545735,0.15593112, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.past_key_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.past_key_data -1.00657940,-0.46509427,-1.65118766,-0.17705369,1.71204090,0.53921354,-1.67056096,0.42517155, -2.00129080,1.26244307,0.28864837,1.38792157,-0.59647840,-1.18904924,0.58950418,-2.26774645, 1.88496518,0.59231639,0.33360308,-1.23532701,0.10543400,-1.77481365,-0.79397631,-0.22495472, -0.26800078,-0.20456636,1.43141091,1.55566478,-0.22702518,1.75312757,-1.29037595,-0.95538902, === 
-name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.past_value_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.past_value_data 3.18056512,0.13370860,-2.20253444,2.30826044,0.86762893,-1.91499686,2.18277764,0.53384149, -0.43230706,0.49148068,-0.29957789,-3.56583714,-1.46747136,-0.40299624,1.78018796,2.84104395, -0.68692255,1.25688624,-0.42734757,-1.03185725,0.47858545,1.18466282,-1.06095874,-0.63918531, 1.41408277,0.74389833,0.89590931,1.06388271,1.29734015,0.42640167,-0.99740052,-2.79366398, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.fp32_output_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.fp32_output_data 0.72723210,-0.54989153,1.22711349,1.26993895,1.78235006,1.12648177,-0.42493403,-1.27518260, -0.43240935,0.49647018,-0.30720428,-3.51349354,-1.45166361,-0.40844491,1.77604592,2.79678369, 0.25752395,1.53741217,-1.08321750,0.69643497,-0.78710371,-1.68901348,0.51954043,-0.00401744, 1.11207914,0.40332735,0.58328331,0.10821819,1.17628312,0.40418532,-0.74326056,-1.28571272, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.present_key_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.present_key_data -1.00657940,-0.46509427,-1.65118766,-0.17705369,1.71204090,0.53921354,-1.67056096,0.42517155, 0.64759666,0.57392448,-0.34546655,-0.05946010,-2.00379062,0.51120460,-1.29283094,0.93003660, -2.00129080,1.26244307,0.28864837,1.38792157,-0.59647840,-1.18904924,0.58950418,-2.26774645, @@ -3190,7 +3190,7 @@ name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.present_ -1.90361691,0.25602955,0.48226023,-0.91249532,0.49253359,-1.77176893,0.32437757,-0.62359041, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.present_value_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.present_value_data 3.18056512,0.13370860,-2.20253444,2.30826044,0.86762893,-1.91499686,2.18277764,0.53384149, 0.50323355,-0.61230683,1.54025340,1.17513633,1.86586761,1.40418029,-0.66302794,-1.44035339, -0.43230706,0.49148068,-0.29957789,-3.56583714,-1.46747136,-0.40299624,1.78018796,2.84104395, @@ -3201,28 +3201,28 @@ name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.present_ 0.25934470,-0.55830550,-0.29944417,-2.59018636,0.83446753,0.34145546,-0.02567360,2.97212315, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.past_key_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.past_key_data -1.27737117,-0.88584161,-1.24804604,0.26021290,1.43827605,0.92095506,-1.23355627,0.04476542, -1.59582162,1.19317269,0.11885749,0.97334087,-0.66768420,-1.10849059,0.46855307,-1.98785996, 1.61417341,0.17156902,0.73674464,-0.79806042,-0.16833085,-1.39307213,-0.35697165,-0.60536087, 0.13746840,-0.27383673,1.26162004,1.14108407,-0.29823098,1.83368623,-1.41132712,-0.67550242, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.past_value_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.past_value_data 3.28623724,-0.13420212,-2.11276555,1.99484074,0.80735362,-2.05807281,1.86690378,0.37204000, -0.78015935,0.48616353,0.05210955,-3.44942260,-1.85944068,-0.84834689,1.34473062,2.68511271, -0.58125055,0.98897558,-0.33757859,-1.34527707,0.41831014,1.04158688,-1.37683260,-0.80098683, 1.06623054,0.73858118,1.24759674,1.18029726,0.90537083,-0.01894896,-1.43285787,-2.94959521, === 
-name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.fp32_output_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.fp32_output_data 0.89556247,-0.80034304,1.22928894,0.98303795,1.69871271,0.90572613,-0.67420667,-1.39078152, -0.78021139,0.48869953,0.04823331,-3.42281842,-1.85140634,-0.85111630,1.34262550,2.66261697, 0.34449580,1.26394701,-0.98046219,0.34879467,-0.82231814,-1.77519011,0.17237240,-0.17839541, 0.72679031,0.35579273,0.89621741,0.10616791,0.76930743,-0.04391927,-1.14721453,-1.25471735, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.present_key_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.present_key_data -1.27737117,-0.88584161,-1.24804604,0.26021290,1.43827605,0.92095506,-1.23355627,0.04476542, 0.37680483,0.15317714,0.05767500,0.37780648,-2.27755547,0.89294612,-0.85582626,0.54963046, -1.59582162,1.19317269,0.11885749,0.97334087,-0.66768420,-1.10849059,0.46855307,-1.98785996, @@ -3233,7 +3233,7 @@ name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.p -1.49814773,0.18675917,0.31246936,-1.32707596,0.42132780,-1.69121027,0.20342645,-0.34370381, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.present_value_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.present_value_data 3.28623724,-0.13420212,-2.11276555,1.99484074,0.80735362,-2.05807281,1.86690378,0.37204000, 0.60890561,-0.88021755,1.63002241,0.86171651,1.80559230,1.26110435,-0.97890180,-1.60215497, -0.78015935,0.48616353,0.05210955,-3.44942260,-1.85944068,-0.84834689,1.34473062,2.68511271, diff --git a/onnxruntime/test/testdata/attention/packed_multihead_attention_test_data.txt b/onnxruntime/test/testdata/attention/packed_multihead_attention_test_data.txt index 2e91cf46ce5f1..5bb83e7daa1ca 100644 --- a/onnxruntime/test/testdata/attention/packed_multihead_attention_test_data.txt +++ b/onnxruntime/test/testdata/attention/packed_multihead_attention_test_data.txt @@ -1,4 +1,4 @@ -name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.query_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.query_data -0.35420692,1.31206024,-2.80201197,2.42258096,-0.86031514,-1.44535458,-0.10832444,-2.00132895, 1.62475216,0.10978927,1.84596729,0.48908550,1.44369888,0.87542874,-1.16434252,0.52133209, 1.54848897,-2.21174526,-0.28574878,0.70815033,1.18327498,3.14097571,-0.25795099,1.89341247, @@ -14,7 +14,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.query_data -0.93303132,-0.84753871,-4.32799959,-1.94716609,-1.16980326,1.62631667,2.41053247,3.78186774, 0.26432252,-0.40396988,2.04414082,0.65150046,0.47777444,-2.57569051,0.99004912,2.47947693, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.key_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.key_data -0.04407793,1.29459429,1.05810797,1.92067695,-0.65047157,0.99029726,-1.69796586,1.15320420, -1.66444266,1.78305888,1.20582056,1.69975281,0.34572244,-0.60833001,2.59864879,-1.05330181, -1.16554165,-0.03781542,-1.13475525,0.71595150,-0.91169560,1.26686060,1.60492957,-0.53510487, @@ -30,7 +30,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.key_data 2.42824388,1.56369960,1.69934130,-0.42460468,-2.25951004,-1.18074155,3.51091242,-0.30183151, -1.83517075,-0.56233191,2.35561657,-3.63751698,-3.20001125,-1.66120780,3.23455381,-1.86251283, === 
-name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.value_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.value_data -0.89167893,0.02633595,-0.84866279,1.43489110,-2.91941142,-0.20650116,1.85965109,0.45669034, 0.07678832,0.04492294,0.67326981,0.97103029,1.53470886,-1.10242307,0.86584085,-0.34770033, -1.24311507,-1.80293822,-1.01317739,-0.71518499,0.77814674,-0.59236068,-2.00310278,3.13277125, @@ -46,7 +46,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.value_data 1.14034331,-1.41539204,0.13379651,3.47018123,1.53924727,1.50004411,2.87318921,1.62624204, 0.64942807,-4.54302311,-1.50294220,-1.75212634,0.27900690,-3.05124855,3.30960631,-0.07991691, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.qkv_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.qkv_data -0.35420692,1.31206024,-2.80201197,2.42258096,-0.86031514,-1.44535458,-0.10832444,-2.00132895, 1.62475216,0.10978927,1.84596729,0.48908550,1.44369888,0.87542874,-1.16434252,0.52133209, 1.54848897,-2.21174526,-0.28574878,0.70815033,1.18327498,3.14097571,-0.25795099,1.89341247, @@ -86,7 +86,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.qkv_data 1.14034331,-1.41539204,0.13379651,3.47018123,1.53924727,1.50004411,2.87318921,1.62624204, 0.64942807,-4.54302311,-1.50294220,-1.75212634,0.27900690,-3.05124855,3.30960631,-0.07991691, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.fp16_output_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.fp16_output_data -0.89160156,0.02633667,-0.84863281,1.4345703,-2.9199219,-0.20654297,1.859375,0.45678711, 0.076782227,0.044921875,0.67333984,0.97119141,1.5351562,-1.1025391,0.86572266,-0.34765625, -1.2431641,-1.8027344,-1.0126953,-0.71533203,0.77832031,-0.59228516,-2.0039062,3.1328125, @@ -102,7 +102,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.fp16_output_dat 1.08301103,-1.26178384,0.16304730,3.16210985,1.36142719,1.32916999,2.69524455,1.45106804, 0.67150640,-4.31703520,-1.34025633,-1.59496248,0.37821823,-2.85797405,3.11096096,-0.17414713f === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.query_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.query_data -1.83615911,0.08698978,0.05601556,-1.14510250,-2.30377889,-0.39893439,0.73342341,-0.09851928, -0.45148617,-0.16055907,-1.48271382,-0.07961921,-0.65701288,-0.25778309,-0.72851723,0.86755788, @@ -111,7 +111,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.query_data -0.20033565,-1.51847255,0.95205748,0.54009491,1.19315910,0.81655478,0.87503016,0.09732430, -0.53218621,-0.11167067,0.67364228,-0.59705222,-0.24946509,0.20462716,-0.56092483,-0.65660709, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.key_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.key_data 0.86949563,-0.10868365,-0.37917313,-1.23103046,0.25640076,-1.50652349,0.71594471,0.49057019, -1.41292810,-0.19686662,1.25451696,-1.59823179,-1.16262913,0.84965342,0.61178929,-1.26162946, @@ -120,7 +120,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.key_data 0.47295785,0.65468878,-1.44158995,-0.05122741,-0.34755200,0.66963655,0.72664660,1.59155345, -1.13806772,0.70947856,-0.65793264,-0.50718778,-1.20698619,0.32613355,0.61786091,-0.34040576, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.value_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.value_data 
-1.19203627,0.38844836,0.68121153,0.21624038,-1.77549291,0.18574584,0.90408206,-0.22868094, -0.95558548,1.38712502,0.81038797,0.14359820,0.15352470,0.00469783,0.03943123,0.53865469, @@ -129,7 +129,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.value_data -0.15860432,-0.24945745,0.67483073,0.18782829,-0.56960964,1.16764832,-0.72244978,0.55027384, -0.37327161,1.19222152,-0.23447749,0.06147140,0.32951999,1.06427121,2.26385999,0.23828916, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.qkv_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.qkv_data -1.83615911,0.08698978,0.05601556,-1.14510250,-2.30377889,-0.39893439,0.73342341,-0.09851928, 0.86949563,-0.10868365,-0.37917313,-1.23103046,0.25640076,-1.50652349,0.71594471,0.49057019, -1.19203627,0.38844836,0.68121153,0.21624038,-1.77549291,0.18574584,0.90408206,-0.22868094, @@ -154,14 +154,14 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.qkv_data -1.13806772,0.70947856,-0.65793264,-0.50718778,-1.20698619,0.32613355,0.61786091,-0.34040576, -0.37327161,1.19222152,-0.23447749,0.06147140,0.32951999,1.06427121,2.26385999,0.23828916, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.rel_pos_bias_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.attention_bias_data 0.4781123,0.82420444,0.654424,0.3995186, 0.5482078,0.55570245,0.4216576,0.46001542, 0.4781123,0.82420444,0.654424,0.3995186, 0.5482078,0.55570245,0.4216576,0.46001542, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.fp16_output_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.fp16_output_data -1.1923828,0.38842773,0.68115234,0.21618652,-1.7753906,0.18579102,0.90429688,-0.2286377, -0.95556641,1.3867188,0.81054688,0.14355469,0.15356445,0.004699707,0.039428711,0.53857422, @@ -172,7 +172,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.fp16_output_data -0.17407227,0.57763672,-0.3046875,0.51025391,-0.097045898,0.98974609,1.0234375,0.47949219, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.query_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.query_data -1.83615911,0.08698978,0.05601556,-1.14510250,-2.30377889,-0.39893439,0.73342341,-0.09851928, -0.45148617,-0.16055907,-1.48271382,-0.07961921,-0.65701288,-0.25778309,-0.72851723,0.86755788, @@ -194,7 +194,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.query_dat -0.16418101,0.30182290,0.76461935,0.89762378,-0.70261180,1.31333566,0.86440170,-0.55341989, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.key_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.key_data 0.86949563,-0.10868365,-0.37917313,-1.23103046,0.25640076,-1.50652349,0.71594471,0.49057019, -1.41292810,-0.19686662,1.25451696,-1.59823179,-1.16262913,0.84965342,0.61178929,-1.26162946, @@ -216,7 +216,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.key_data -1.74471772,0.38858974,0.77225429,-0.47355813,0.59074765,-0.50501788,-1.72981727,-1.25862873, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.value_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.value_data -1.19203627,0.38844836,0.68121153,0.21624038,-1.77549291,0.18574584,0.90408206,-0.22868094, -0.95558548,1.38712502,0.81038797,0.14359820,0.15352470,0.00469783,0.03943123,0.53865469, @@ -238,7 +238,7 @@ 
name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.value_dat -1.04708695,1.04990900,0.61408597,0.48327276,0.61544299,-0.57864964,-0.80768973,0.39645281, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.qkv_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.qkv_data -1.83615911,0.08698978,0.05601556,-1.14510250,-2.30377889,-0.39893439,0.73342341,-0.09851928, 0.86949563,-0.10868365,-0.37917313,-1.23103046,0.25640076,-1.50652349,0.71594471,0.49057019, -1.19203627,0.38844836,0.68121153,0.21624038,-1.77549291,0.18574584,0.90408206,-0.22868094, @@ -312,7 +312,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.qkv_data -1.04708695,1.04990900,0.61408597,0.48327276,0.61544299,-0.57864964,-0.80768973,0.39645281, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.rel_pos_bias_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.attention_bias_data 0.09734076,-0.01747033,0.008497253,-0.03361112,-0.028750911,-0.017142132,-0.11563814,0.10432467, 0.057628587,0.030893803,-0.096876964,0.11924802,-0.009177148,0.05799888,-0.030559167,0.034150958, 0.07427484,0.028848544,-0.031371966,0.07186346,-0.093020484,-0.066411436,0.06858949,0.07350862, @@ -332,7 +332,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.rel_pos_b 0.013226762,-0.07403794,0.06855075,-0.06551643,-0.084110215,0.11237715,0.07026932,-0.014076158, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.fp16_output_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.fp16_output_data -1.1923828,0.38842773,0.68115234,0.21618652,-1.7753906,0.18579102,0.90429688,-0.2286377, -0.95556641,1.3867188,0.81054688,0.14355469,0.15356445,0.004699707,0.039428711,0.53857422,